This is an automated email from the ASF dual-hosted git repository. aglinxinyuan pushed a commit to branch xinyuan-state-only in repository https://gitbox.apache.org/repos/asf/texera.git
commit d21a25271f8bf86a2ed3e19e351b6709142027e6 Author: Xinyuan Lin <[email protected]> AuthorDate: Sun Apr 19 01:27:00 2026 -0700 refactor: isolate state serialization and materialization changes --- .../core/architecture/packaging/output_manager.py | 38 +++++++- amber/src/main/python/core/models/operator.py | 38 +++++++- amber/src/main/python/core/models/state.py | 97 ++++++++++--------- .../main/python/core/runnables/data_processor.py | 1 + amber/src/main/python/core/runnables/main_loop.py | 49 +++++++++- .../main/python/core/runnables/network_receiver.py | 15 ++- .../main/python/core/runnables/network_sender.py | 23 +++-- .../main/python/core/storage/document_factory.py | 107 ++++++++++++--------- .../input_port_materialization_reader_runnable.py | 30 +++++- .../main/python/core/storage/vfs_uri_factory.py | 1 + .../messaginglayer/OutputManager.scala | 20 ++++ .../pythonworker/PythonProxyClient.scala | 7 +- .../pythonworker/PythonProxyServer.scala | 5 +- .../scheduling/RegionExecutionCoordinator.scala | 40 ++++++-- .../InputPortMaterializationReaderThread.scala | 26 ++++- .../amber/core/executor/OperatorExecutor.scala | 8 +- .../org/apache/texera/amber/core/state/State.scala | 83 +++++++++++----- .../apache/texera/amber/core/state/package.scala | 24 +++++ .../amber/core/storage/DocumentFactory.scala | 2 + .../texera/amber/core/storage/VFSURIFactory.scala | 1 + 20 files changed, 452 insertions(+), 163 deletions(-) diff --git a/amber/src/main/python/core/architecture/packaging/output_manager.py b/amber/src/main/python/core/architecture/packaging/output_manager.py index bf4afbf396..065b063f7d 100644 --- a/amber/src/main/python/core/architecture/packaging/output_manager.py +++ b/amber/src/main/python/core/architecture/packaging/output_manager.py @@ -17,6 +17,7 @@ import threading import typing +import uuid from collections import OrderedDict from itertools import chain from loguru import logger @@ -43,7 +44,12 @@ from core.architecture.sendsemantics.round_robin_partitioner import ( ) from core.models import Tuple, Schema, StateFrame from core.models.payload import DataPayload, DataFrame -from core.models.state import State +from core.models.state import ( + State, + STATE_SCHEMA, + serialize_state, + state_uri_from_result_uri, +) from core.storage.document_factory import DocumentFactory from core.storage.runnables.port_storage_writer import ( PortStorageWriter, @@ -87,6 +93,8 @@ class OutputManager: PortIdentity, typing.Tuple[Queue, PortStorageWriter, Thread] ] = dict() + self._storage_uris: typing.Dict[PortIdentity, str] = dict() + def is_missing_output_ports(self): """ This method is only used for ensuring correct region execution. @@ -126,6 +134,7 @@ class OutputManager: Create a separate thread for saving output tuples of a port to storage in batch. """ + self._storage_uris[port_id] = storage_uri document, _ = DocumentFactory.open_document(storage_uri) buffered_item_writer = document.writer(str(get_worker_index(self.worker_id))) writer_queue = Queue() @@ -171,6 +180,31 @@ class OutputManager: PortStorageWriterElement(data_tuple=tuple_) ) + def save_state_to_storage_if_needed(self, state: State, port_id=None) -> None: + if port_id is None: + uris = self._storage_uris.values() + elif port_id in self._storage_uris: + uris = [self._storage_uris[port_id]] + else: + return + + for uri in uris: + state_uri = state_uri_from_result_uri(uri) + try: + document = DocumentFactory.open_document(state_uri)[0] + except ValueError: + document = DocumentFactory.create_document(state_uri, STATE_SCHEMA) + writer = document.writer(str(uuid.uuid4())) + writer.put_one(serialize_state(state)) + writer.close() + + def reset_output_storage(self) -> None: + port_id = self.get_port_ids()[0] + storage_uri = self._storage_uris[port_id] + self.close_port_storage_writers() + DocumentFactory.create_document(storage_uri, self._ports[port_id].get_schema()) + self.set_up_port_storage_writer(port_id, storage_uri) + def close_port_storage_writers(self) -> None: """ Flush the buffers of port storage writers and wait for all the @@ -248,7 +282,7 @@ class OutputManager: receiver, ( StateFrame(payload) - if isinstance(payload, State) + if isinstance(payload, dict) else self.tuple_to_frame(payload) ), ) diff --git a/amber/src/main/python/core/models/operator.py b/amber/src/main/python/core/models/operator.py index 7905083995..5b9672988a 100644 --- a/amber/src/main/python/core/models/operator.py +++ b/amber/src/main/python/core/models/operator.py @@ -108,14 +108,12 @@ class Operator(ABC): def process_state(self, state: State, port: int) -> Optional[State]: """ Process an input State from the given link. - The default implementation is to pass the State to all downstream operators - if the State has pass_to_all_downstream set to True. + The default implementation is to pass the State to downstream operators. :param state: State, a State from an input port to be processed. :param port: int, input port index of the current exhausted port. :return: State, producing one State object """ - if state.passToAllDownstream: - return state + return state def produce_state_on_start(self, port: int) -> State: """ @@ -293,3 +291,35 @@ class TableOperator(TupleOperatorV2): time, or None. """ yield + + +class LoopStartOperator(TableOperator): + @overrides.final + def process_state(self, state: State, port: int) -> Optional[State]: + if "LoopStartStateURI" in state: + state["loop_counter"] += 1 + return state + self.state.update(state) + return None + + @overrides.final + def produce_state_on_finish(self, port: int) -> State: + from pickle import dumps + + self.state["table"] = dumps(Table(self._TableOperator__table_data[port])) + return dict(self.state) + + +class LoopEndOperator(TableOperator): + @overrides.final + def process_table(self, table: Table, port: int) -> Iterator[Optional[TableLike]]: + yield table + + @abstractmethod + def condition(self) -> None: + pass + + def loop_start_id(self) -> str: + del self.state["table"] + del self.state["output"] + return self.state["LoopStartId"] diff --git a/amber/src/main/python/core/models/state.py b/amber/src/main/python/core/models/state.py index 2c8a268dfb..e5726cc3c2 100644 --- a/amber/src/main/python/core/models/state.py +++ b/amber/src/main/python/core/models/state.py @@ -15,61 +15,64 @@ # specific language governing permissions and limitations # under the License. -from dataclasses import dataclass -from pandas import DataFrame -from pyarrow import Table -from typing import Optional +import base64 +import json +from typing import Any, Dict, TypeAlias -from .schema import Schema, AttributeType -from .schema.attribute_type import FROM_PYOBJECT_MAPPING +from .schema import Schema +from .tuple import Tuple +State: TypeAlias = Dict[str, Any] -@dataclass -class State: - def __init__( - self, table: Optional[Table] = None, pass_to_all_downstream: bool = False - ): - self.schema = Schema() - self.passToAllDownstream = pass_to_all_downstream - if table is not None: - self.__dict__.update(table.to_pandas().iloc[0].to_dict()) - self.schema = Schema(table.schema) +STATE_CONTENT = "content" +_TYPE_MARKER = "__texera_type__" +_PAYLOAD_MARKER = "payload" +_BYTES_TYPE = "bytes" - def add( - self, key: str, value: any, value_type: Optional[AttributeType] = None - ) -> None: - self.__dict__[key] = value - if value_type is not None: - self.schema.add(key, value_type) - elif key != "schema": - self.schema.add(key, FROM_PYOBJECT_MAPPING[type(value)]) +STATE_SCHEMA = Schema(raw_schema={STATE_CONTENT: "STRING"}) - def get(self, key: str) -> any: - return self.__dict__[key] - def to_table(self) -> Table: - return Table.from_pandas( - df=DataFrame([self.__dict__]), - schema=self.schema.as_arrow_schema(), - ) +def state_uri_from_result_uri(result_uri: str) -> str: + return result_uri.replace("/result", "/state") - def __setattr__(self, key: str, value: any) -> None: - self.add(key, value) - def __setitem__(self, key: str, value: any) -> None: - self.add(key, value) +def serialize_state(state: State) -> Tuple: + return Tuple( + { + STATE_CONTENT: json.dumps( + _to_json_value(state), separators=(",", ":") + ) + }, + schema=STATE_SCHEMA, + ) - def __getitem__(self, key: str) -> any: - return self.get(key) - def __str__(self) -> str: - content = ", ".join( - [ - repr(key) + ": " + repr(value) - for key, value in self.__dict__.items() - if key != "schema" - ] - ) - return f"State[{content}]" +def deserialize_state(row: Tuple) -> State: + return _from_json_value(json.loads(row[STATE_CONTENT])) - __repr__ = __str__ + +def _to_json_value(value: Any) -> Any: + if value is None or isinstance(value, (bool, int, float, str)): + return value + if isinstance(value, bytes): + return { + _TYPE_MARKER: _BYTES_TYPE, + _PAYLOAD_MARKER: base64.b64encode(value).decode("ascii"), + } + if isinstance(value, dict): + return {str(key): _to_json_value(inner) for key, inner in value.items()} + if isinstance(value, (list, tuple)): + return [_to_json_value(inner) for inner in value] + raise TypeError( + f"State value of type {type(value).__name__} is not JSON serializable" + ) + + +def _from_json_value(value: Any) -> Any: + if isinstance(value, list): + return [_from_json_value(inner) for inner in value] + if isinstance(value, dict): + if value.get(_TYPE_MARKER) == _BYTES_TYPE: + return base64.b64decode(value[_PAYLOAD_MARKER]) + return {key: _from_json_value(inner) for key, inner in value.items()} + return value diff --git a/amber/src/main/python/core/runnables/data_processor.py b/amber/src/main/python/core/runnables/data_processor.py index 4399b1a3a2..815e85a644 100644 --- a/amber/src/main/python/core/runnables/data_processor.py +++ b/amber/src/main/python/core/runnables/data_processor.py @@ -100,6 +100,7 @@ class DataProcessor(Runnable, Stoppable): self._context.worker_id, self._context.console_message_manager.print_buf, ): + self._switch_context() self._set_output_state(executor.process_state(state, port_id)) except Exception as err: diff --git a/amber/src/main/python/core/runnables/main_loop.py b/amber/src/main/python/core/runnables/main_loop.py index d73c655734..ece5cf8e10 100644 --- a/amber/src/main/python/core/runnables/main_loop.py +++ b/amber/src/main/python/core/runnables/main_loop.py @@ -38,8 +38,15 @@ from core.models.internal_queue import ( ECMElement, InternalQueueElement, ) -from core.models.state import State +from core.models.operator import LoopEndOperator, LoopStartOperator +from core.models.state import ( + State, + STATE_SCHEMA, + serialize_state, + state_uri_from_result_uri, +) from core.runnables.data_processor import DataProcessor +from core.storage.document_factory import DocumentFactory from core.util import StoppableQueueBlockingRunnable, get_one_of from core.util.console_message.timestamp import current_time_in_local_timezone from core.util.customized_queue.queue_base import QueueElement @@ -48,6 +55,7 @@ from proto.org.apache.texera.amber.core import ( PortIdentity, ChannelIdentity, EmbeddedControlMessageIdentity, + OperatorIdentity, ) from proto.org.apache.texera.amber.engine.architecture.rpc import ( ConsoleMessage, @@ -61,6 +69,7 @@ from proto.org.apache.texera.amber.engine.architecture.rpc import ( EmbeddedControlMessage, AsyncRpcContext, ControlRequest, + IterationCompletedRequest, ) from proto.org.apache.texera.amber.engine.architecture.worker import ( WorkerState, @@ -87,6 +96,29 @@ class MainLoop(StoppableQueueBlockingRunnable): target=self.data_processor.run, daemon=True, name="data_processor_thread" ).start() + def _attach_loop_start_id(self, output_state: State) -> None: + if "LoopStartId" in output_state: + return + output_state["LoopStartId"] = self.context.worker_id.split("-", 1)[1].rsplit( + "-main-0", 1 + )[0] + output_state["LoopStartStateURI"] = state_uri_from_result_uri( + self.context.input_manager.get_input_state_result_uri() + ) + + def _next_iteration( + self, executor: LoopEndOperator, controller_interface + ) -> None: + controller_interface.iteration_completed( + IterationCompletedRequest(OperatorIdentity(executor.loop_start_id())) + ) + uri = executor.state["LoopStartStateURI"] + del executor.state["LoopStartStateURI"] + del executor.state["LoopStartId"] + writer = DocumentFactory.create_document(uri, STATE_SCHEMA).writer("0") + writer.put_one(serialize_state(executor.state)) + writer.close() + def complete(self) -> None: """ Complete the DataProcessor, marking state to COMPLETED, and notify the @@ -94,12 +126,15 @@ class MainLoop(StoppableQueueBlockingRunnable): """ # flush the buffered console prints self._check_and_report_console_messages(force_flush=True) - self.context.executor_manager.executor.close() + controller_interface = self._async_rpc_client.controller_stub() + executor = self.context.executor_manager.executor + if isinstance(executor, LoopEndOperator) and executor.condition(): + self._next_iteration(executor, controller_interface) + executor.close() # stop the data processing thread self.data_processor.stop() self.context.state_manager.transit_to(WorkerState.COMPLETED) self.context.statistics_manager.update_total_execution_time(time.time_ns()) - controller_interface = self._async_rpc_client.controller_stub() controller_interface.worker_execution_completed(EmptyRequest()) self.context.close() @@ -188,6 +223,10 @@ class MainLoop(StoppableQueueBlockingRunnable): output_state = self.context.state_processing_manager.get_output_state() self._switch_context() if output_state is not None: + if isinstance(self.context.executor_manager.executor, LoopEndOperator): + self.context.output_manager.reset_output_storage() + if isinstance(self.context.executor_manager.executor, LoopStartOperator): + self._attach_loop_start_id(output_state) for to, batch in self.context.output_manager.emit_state(output_state): self._output_queue.put( DataElement( @@ -197,6 +236,7 @@ class MainLoop(StoppableQueueBlockingRunnable): payload=batch, ) ) + self.context.output_manager.save_state_to_storage_if_needed(output_state) def process_tuple_with_udf(self) -> Iterator[Optional[Tuple]]: """ @@ -241,6 +281,7 @@ class MainLoop(StoppableQueueBlockingRunnable): def _process_state(self, state_: State) -> None: self.context.state_processing_manager.current_input_state = state_ + self._switch_context() self.process_input_state() self._check_and_process_control() @@ -329,7 +370,7 @@ class MainLoop(StoppableQueueBlockingRunnable): if ecm.ecm_type != EmbeddedControlMessageType.NO_ALIGNMENT: self.context.pause_manager.resume(PauseType.ECM_PAUSE) - + self._switch_context() if self.context.tuple_processing_manager.current_internal_marker: { StartChannel: self._process_start_channel, diff --git a/amber/src/main/python/core/runnables/network_receiver.py b/amber/src/main/python/core/runnables/network_receiver.py index fd42a8f589..e1815b08f7 100644 --- a/amber/src/main/python/core/runnables/network_receiver.py +++ b/amber/src/main/python/core/runnables/network_receiver.py @@ -32,6 +32,7 @@ from core.architecture.handlers.actorcommand.credit_update_handler import ( ) from core.models import ( DataFrame, + Tuple, StateFrame, ) from core.models.internal_queue import ( @@ -40,8 +41,8 @@ from core.models.internal_queue import ( InternalQueue, ECMElement, ) -from core.models.state import State from core.proxy import ProxyServer +from core.models.state import STATE_SCHEMA, deserialize_state from core.util import Stoppable, get_one_of from core.util.runnable.runnable import Runnable from proto.org.apache.texera.amber.engine.architecture.rpc import EmbeddedControlMessage @@ -96,7 +97,17 @@ class NetworkReceiver(Runnable, Stoppable): "Data", lambda _: DataFrame(table), "State", - lambda _: StateFrame(State(table)), + lambda _: StateFrame( + deserialize_state( + Tuple( + { + name: table[name][0].as_py() + for name in STATE_SCHEMA.get_attr_names() + }, + schema=STATE_SCHEMA, + ) + ) + ), "ECM", lambda _: EmbeddedControlMessage().parse(table["payload"][0].as_py()), ) diff --git a/amber/src/main/python/core/runnables/network_sender.py b/amber/src/main/python/core/runnables/network_sender.py index 9595433fb7..f1bd8659ee 100644 --- a/amber/src/main/python/core/runnables/network_sender.py +++ b/amber/src/main/python/core/runnables/network_sender.py @@ -20,13 +20,18 @@ from loguru import logger from overrides import overrides from typing import Optional -from core.models import DataPayload, InternalQueue, DataFrame, StateFrame, State +from core.models import DataPayload, InternalQueue, DataFrame, StateFrame from core.models.internal_queue import ( InternalQueueElement, DataElement, DCMElement, ECMElement, ) +from core.models.state import ( + STATE_CONTENT, + STATE_SCHEMA, + serialize_state, +) from core.proxy import ProxyClient from core.util import StoppableQueueBlockingRunnable from proto.org.apache.texera.amber.core import ChannelIdentity @@ -98,13 +103,15 @@ class NetworkSender(StoppableQueueBlockingRunnable): data_header = PythonDataHeader(tag=to, payload_type="Data") self._proxy_client.send_data(bytes(data_header), data_payload.frame) elif isinstance(data_payload, StateFrame): - data_header = PythonDataHeader( - tag=to, payload_type=data_payload.frame.__class__.__name__ - ) - table = ( - data_payload.frame.to_table() - if isinstance(data_payload.frame, State) - else None + data_header = PythonDataHeader(tag=to, payload_type="State") + serialized_state = serialize_state(data_payload.frame) + table = pa.Table.from_pydict( + { + STATE_CONTENT: [ + serialized_state[STATE_CONTENT] + ], + }, + schema=STATE_SCHEMA.as_arrow_schema(), ) self._proxy_client.send_data(bytes(data_header), table) else: diff --git a/amber/src/main/python/core/storage/document_factory.py b/amber/src/main/python/core/storage/document_factory.py index 9b686ab66b..8a4d6fe3c5 100644 --- a/amber/src/main/python/core/storage/document_factory.py +++ b/amber/src/main/python/core/storage/document_factory.py @@ -61,30 +61,35 @@ class DocumentFactory: if parsed_uri.scheme == VFSURIFactory.VFS_FILE_URI_SCHEME: _, _, _, resource_type = VFSURIFactory.decode_uri(uri) - if resource_type in {VFSResourceType.RESULT}: - storage_key = DocumentFactory.sanitize_uri_path(parsed_uri) - - # Convert Amber Schema to Iceberg Schema with LARGE_BINARY - # field name encoding - iceberg_schema = amber_schema_to_iceberg_schema(schema) - - create_table( - IcebergCatalogInstance.get_instance(), - StorageConfig.ICEBERG_TABLE_RESULT_NAMESPACE, - storage_key, - iceberg_schema, - override_if_exists=True, - ) - - return IcebergDocument[Tuple]( - StorageConfig.ICEBERG_TABLE_RESULT_NAMESPACE, - storage_key, - iceberg_schema, - amber_tuples_to_arrow_table, - arrow_table_to_amber_tuples, - ) - else: - raise ValueError(f"Resource type {resource_type} is not supported") + match resource_type: + case VFSResourceType.RESULT: + namespace = StorageConfig.ICEBERG_TABLE_RESULT_NAMESPACE + case VFSResourceType.STATE: + namespace = "state" + case _: + raise ValueError(f"Resource type {resource_type} is not supported") + + storage_key = DocumentFactory.sanitize_uri_path(parsed_uri) + # Convert Amber Schema to Iceberg Schema with LARGE_BINARY + # field name encoding + iceberg_schema = amber_schema_to_iceberg_schema(schema) + + create_table( + IcebergCatalogInstance.get_instance(), + namespace, + storage_key, + iceberg_schema, + override_if_exists=True, + ) + + return IcebergDocument[Tuple]( + namespace, + storage_key, + iceberg_schema, + amber_tuples_to_arrow_table, + arrow_table_to_amber_tuples, + ) + else: raise NotImplementedError( f"Unsupported URI scheme: {parsed_uri.scheme} for creating the document" @@ -96,30 +101,36 @@ class DocumentFactory: if parsed_uri.scheme == "vfs": _, _, _, resource_type = VFSURIFactory.decode_uri(uri) - if resource_type in {VFSResourceType.RESULT}: - storage_key = DocumentFactory.sanitize_uri_path(parsed_uri) - - table = load_table_metadata( - IcebergCatalogInstance.get_instance(), - StorageConfig.ICEBERG_TABLE_RESULT_NAMESPACE, - storage_key, - ) - - if table is None: - raise ValueError("No storage is found for the given URI") - - amber_schema = Schema(table.schema().as_arrow()) - - document = IcebergDocument( - StorageConfig.ICEBERG_TABLE_RESULT_NAMESPACE, - storage_key, - table.schema(), - amber_tuples_to_arrow_table, - arrow_table_to_amber_tuples, - ) - return document, amber_schema - else: - raise ValueError(f"Resource type {resource_type} is not supported") + match resource_type: + case VFSResourceType.RESULT: + namespace = StorageConfig.ICEBERG_TABLE_RESULT_NAMESPACE + case VFSResourceType.STATE: + namespace = "state" + case _: + raise ValueError(f"Resource type {resource_type} is not supported") + + storage_key = DocumentFactory.sanitize_uri_path(parsed_uri) + + table = load_table_metadata( + IcebergCatalogInstance.get_instance(), + namespace, + storage_key, + ) + + if table is None: + raise ValueError("No storage is found for the given URI") + + amber_schema = Schema(table.schema().as_arrow()) + + document = IcebergDocument( + namespace, + storage_key, + table.schema(), + amber_tuples_to_arrow_table, + arrow_table_to_amber_tuples, + ) + return document, amber_schema + else: raise NotImplementedError( f"Unsupported URI scheme: {parsed_uri.scheme} for opening the document" diff --git a/amber/src/main/python/core/storage/runnables/input_port_materialization_reader_runnable.py b/amber/src/main/python/core/storage/runnables/input_port_materialization_reader_runnable.py index e49c0316cc..493ecf0a41 100644 --- a/amber/src/main/python/core/storage/runnables/input_port_materialization_reader_runnable.py +++ b/amber/src/main/python/core/storage/runnables/input_port_materialization_reader_runnable.py @@ -17,8 +17,8 @@ import typing from loguru import logger -from pyarrow import Table from typing import Union +from pyarrow import Table from core.architecture.sendsemantics.broad_cast_partitioner import ( BroadcastPartitioner, @@ -34,8 +34,9 @@ from core.architecture.sendsemantics.range_based_shuffle_partitioner import ( from core.architecture.sendsemantics.round_robin_partitioner import ( RoundRobinPartitioner, ) -from core.models import Tuple, InternalQueue, DataFrame, DataPayload +from core.models import Tuple, InternalQueue, DataFrame, DataPayload, State, StateFrame from core.models.internal_queue import DataElement, ECMElement +from core.models.state import deserialize_state, state_uri_from_result_uri from core.storage.document_factory import DocumentFactory from core.util import Stoppable, get_one_of from core.util.runnable.runnable import Runnable @@ -125,6 +126,15 @@ class InputPortMaterializationReaderRunnable(Runnable, Stoppable): if receiver == self.worker_actor_id: yield self.tuples_to_data_frame(tuples) + def emit_state_with_filter(self, state: State) -> typing.Iterator[StateFrame]: + for receiver, payload in self.partitioner.flush_state(state): + if receiver == self.worker_actor_id: + yield ( + StateFrame(payload) + if isinstance(payload, dict) + else self.tuples_to_data_frame(payload) + ) + def run(self) -> None: """ Main execution logic that reads tuples from the materialized storage and @@ -138,8 +148,21 @@ class InputPortMaterializationReaderRunnable(Runnable, Stoppable): self.uri ) self.emit_ecm("StartChannel", EmbeddedControlMessageType.NO_ALIGNMENT) - storage_iterator = self.materialization.get() + try: + state_document, _ = DocumentFactory.open_document( + state_uri_from_result_uri(self.uri) + ) + state_iterator = state_document.get() + for state in state_iterator: + for state_frame in self.emit_state_with_filter( + deserialize_state(state) + ): + self.emit_payload(state_frame) + except ValueError: + pass + + storage_iterator = self.materialization.get() # Iterate and process tuples. for tup in storage_iterator: if self._stopped: @@ -149,6 +172,7 @@ class InputPortMaterializationReaderRunnable(Runnable, Stoppable): tup.cast_to_schema(self.tuple_schema) for data_frame in self.tuple_to_batch_with_filter(tup): self.emit_payload(data_frame) + self.emit_ecm("EndChannel", EmbeddedControlMessageType.PORT_ALIGNMENT) self._finished = True except Exception as err: diff --git a/amber/src/main/python/core/storage/vfs_uri_factory.py b/amber/src/main/python/core/storage/vfs_uri_factory.py index de0c5db56e..0e23e60705 100644 --- a/amber/src/main/python/core/storage/vfs_uri_factory.py +++ b/amber/src/main/python/core/storage/vfs_uri_factory.py @@ -34,6 +34,7 @@ class VFSResourceType(str, Enum): RESULT = "result" RUNTIME_STATISTICS = "runtimeStatistics" CONSOLE_MESSAGES = "consoleMessages" + STATE = "state" class VFSURIFactory: diff --git a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/messaginglayer/OutputManager.scala b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/messaginglayer/OutputManager.scala index 4ab3d18056..53755b780c 100644 --- a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/messaginglayer/OutputManager.scala +++ b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/messaginglayer/OutputManager.scala @@ -124,6 +124,8 @@ class OutputManager( : mutable.HashMap[PortIdentity, OutputPortResultWriterThread] = mutable.HashMap() + private val storageUris: mutable.HashMap[Int, URI] = mutable.HashMap() + /** * Add down stream operator and its corresponding Partitioner. * @@ -232,6 +234,23 @@ class OutputManager( }) } + def saveStateToStorageIfNeeded(state: State): Unit = { + try { + storageUris.foreach { + case (_, uri) => + val writer = DocumentFactory + .openDocument(State.stateUriFromResultUri(uri)) + ._1 + .writer(VirtualIdentityUtils.getWorkerIndex(actorId).toString) + .asInstanceOf[BufferedItemWriter[Tuple]] + writer.putOne(State.serialize(state)) + writer.close() + } + } catch { + case _: Exception => () + } + } + /** * Singal the port storage writer to flush the remaining buffer and wait for commits to finish so that * the output port is properly completed. If the output port does not need storage, no action will be done. @@ -280,6 +299,7 @@ class OutputManager( } private def setupOutputStorageWriterThread(portId: PortIdentity, storageUri: URI): Unit = { + this.storageUris(portId.id) = storageUri val bufferedItemWriter = DocumentFactory .openDocument(storageUri) ._1 diff --git a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyClient.scala b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyClient.scala index 6618e857b1..e53fccf8c0 100644 --- a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyClient.scala +++ b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyClient.scala @@ -21,6 +21,7 @@ package org.apache.texera.amber.engine.architecture.pythonworker import com.twitter.util.{Await, Promise} import org.apache.texera.amber.core.WorkflowRuntimeException +import org.apache.texera.amber.core.state.State import org.apache.texera.amber.core.tuple.{Schema, Tuple} import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity} import org.apache.texera.amber.engine.architecture.pythonworker.WorkerBatchInternalQueue.{ @@ -125,7 +126,11 @@ class PythonProxyClient(portNumberPromise: Promise[Int], val actorId: ActorVirtu case DataFrame(frame) => writeArrowStream(mutable.Queue(ArraySeq.unsafeWrapArray(frame): _*), from, "Data") case StateFrame(state) => - writeArrowStream(mutable.Queue(state.toTuple), from, "State") + writeArrowStream( + mutable.Queue(State.serialize(state)), + from, + "State" + ) } } diff --git a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyServer.scala b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyServer.scala index c904e436bc..2a1e212ac8 100644 --- a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyServer.scala +++ b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/pythonworker/PythonProxyServer.scala @@ -128,7 +128,10 @@ private class AmberProducer( dataHeader.payloadType match { case "State" => assert(root.getRowCount == 1) - outputPort.sendTo(to, StateFrame(State(Some(ArrowUtils.getTexeraTuple(0, root))))) + outputPort.sendTo( + to, + StateFrame(State.deserialize(ArrowUtils.getTexeraTuple(0, root))) + ) case "ECM" => assert(root.getRowCount == 1) outputPort.sendTo( diff --git a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/RegionExecutionCoordinator.scala b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/RegionExecutionCoordinator.scala index a0c73b6506..a384f383e1 100644 --- a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/RegionExecutionCoordinator.scala +++ b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/scheduling/RegionExecutionCoordinator.scala @@ -21,6 +21,7 @@ package org.apache.texera.amber.engine.architecture.scheduling import org.apache.pekko.pattern.gracefulStop import com.twitter.util.{Duration => TwitterDuration, Future, JavaTimer, Return, Throw, Timer} +import org.apache.texera.amber.core.state.State import org.apache.texera.amber.core.storage.DocumentFactory import org.apache.texera.amber.core.storage.VFSURIFactory.decodeURI import org.apache.texera.amber.core.virtualidentity.ActorVirtualIdentity @@ -181,6 +182,8 @@ class RegionExecutionCoordinator( val actorRef = actorRefService.getActorRef(workerId) // Remove the actorRef so that no other actors can find the worker and send messages. actorRefService.removeActorRef(workerId) + asyncRPCClient.inputGateway.removeControlChannel(workerId) + asyncRPCClient.outputGateway.removeControlChannel(workerId) gracefulStop(actorRef, ScalaDuration(5, TimeUnit.SECONDS)).asTwitter() } }.toSeq @@ -209,14 +212,15 @@ class RegionExecutionCoordinator( regionExecution: RegionExecution, attempt: Int = 1 ): Future[Unit] = { - terminateWorkers(regionExecution).rescue { case err => - logger.warn( - s"Failed to terminate region ${region.id.id} on attempt $attempt. Retrying in ${killRetryDelay.inMilliseconds} ms.", - err - ) - Future - .sleep(killRetryDelay)(killRetryTimer) - .flatMap(_ => terminateWorkersWithRetry(regionExecution, attempt + 1)) + terminateWorkers(regionExecution).rescue { + case err => + logger.warn( + s"Failed to terminate region ${region.id.id} on attempt $attempt. Retrying in ${killRetryDelay.inMilliseconds} ms.", + err + ) + Future + .sleep(killRetryDelay)(killRetryTimer) + .flatMap(_ => terminateWorkersWithRetry(regionExecution, attempt + 1)) } } @@ -563,12 +567,30 @@ class RegionExecutionCoordinator( portConfigs.foreach { case (outputPortId, portConfig) => val storageUriToAdd = portConfig.storageURI + val stateUriToAdd = State.stateUriFromResultUri(storageUriToAdd) val (_, eid, _, _) = decodeURI(storageUriToAdd) val schemaOptional = region.getOperator(outputPortId.opId).outputPorts(outputPortId.portId)._3 val schema = schemaOptional.getOrElse(throw new IllegalStateException("Schema is missing")) - DocumentFactory.createDocument(storageUriToAdd, schema) + if (region.getOperators.exists(_.id.logicalOpId.id.startsWith("LoopEnd-operator-"))) { + try { + DocumentFactory.openDocument(storageUriToAdd) + } catch { + case _: Exception => + DocumentFactory.createDocument(storageUriToAdd, schema) + } + try { + DocumentFactory.openDocument(stateUriToAdd) + } catch { + case _: Exception => + DocumentFactory.createDocument(stateUriToAdd, State.schema) + } + } else { + DocumentFactory.createDocument(storageUriToAdd, schema) + DocumentFactory.createDocument(stateUriToAdd, State.schema) + } + WorkflowExecutionsResource.insertOperatorPortResultUri( eid = eid, globalPortId = outputPortId, diff --git a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/worker/managers/InputPortMaterializationReaderThread.scala b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/worker/managers/InputPortMaterializationReaderThread.scala index 10fbbc44a2..acada743bc 100644 --- a/amber/src/main/scala/org/apache/texera/amber/engine/architecture/worker/managers/InputPortMaterializationReaderThread.scala +++ b/amber/src/main/scala/org/apache/texera/amber/engine/architecture/worker/managers/InputPortMaterializationReaderThread.scala @@ -21,6 +21,7 @@ package org.apache.texera.amber.engine.architecture.worker.managers import io.grpc.MethodDescriptor import org.apache.texera.amber.config.ApplicationConfig +import org.apache.texera.amber.core.state.State import org.apache.texera.amber.core.storage.DocumentFactory import org.apache.texera.amber.core.storage.model.VirtualDocument import org.apache.texera.amber.core.tuple.Tuple @@ -45,7 +46,11 @@ import org.apache.texera.amber.engine.architecture.worker.WorkflowWorker.{ DPInputQueueElement, FIFOMessageElement } -import org.apache.texera.amber.engine.common.ambermessage.{DataFrame, WorkflowFIFOMessage} +import org.apache.texera.amber.engine.common.ambermessage.{ + DataFrame, + StateFrame, + WorkflowFIFOMessage +} import org.apache.texera.amber.util.VirtualIdentityUtils.getFromActorIdForInputPortStorage import java.net.URI @@ -106,6 +111,25 @@ class InputPortMaterializationReaderThread( } // Flush any remaining tuples in the buffer. if (buffer.nonEmpty) flush() + + try { + val state_document = + DocumentFactory + .openDocument(State.stateUriFromResultUri(uri)) + ._1 + .asInstanceOf[VirtualDocument[Tuple]] + val stateReadIterator = state_document.get() + + while (stateReadIterator.hasNext) { + val state = State.deserialize(stateReadIterator.next()) + inputMessageQueue.put( + FIFOMessageElement(WorkflowFIFOMessage(channelId, getSequenceNumber, StateFrame(state))) + ) + } + } catch { + case _: Exception => + } + emitECM(METHOD_END_CHANNEL, PORT_ALIGNMENT) isFinished.set(true) } catch { diff --git a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/executor/OperatorExecutor.scala b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/executor/OperatorExecutor.scala index f99739acc0..9837213abb 100644 --- a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/executor/OperatorExecutor.scala +++ b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/executor/OperatorExecutor.scala @@ -29,13 +29,7 @@ trait OperatorExecutor { def produceStateOnStart(port: Int): Option[State] = None - def processState(state: State, port: Int): Option[State] = { - if (state.isPassToAllDownstream) { - Some(state) - } else { - None - } - } + def processState(state: State, port: Int): Option[State] = Some(state) def processTupleMultiPort( tuple: Tuple, diff --git a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/state/State.scala b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/state/State.scala index 3226c9d2fe..f76a314b7a 100644 --- a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/state/State.scala +++ b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/state/State.scala @@ -19,39 +19,70 @@ package org.apache.texera.amber.core.state +import com.fasterxml.jackson.databind.JsonNode import org.apache.texera.amber.core.tuple.{Attribute, AttributeType, Schema, Tuple} +import org.apache.texera.amber.util.JSONUtils.objectMapper -import scala.collection.mutable +import java.net.URI +import java.util.Base64 +import scala.jdk.CollectionConverters.IteratorHasAsScala -final case class State(tuple: Option[Tuple] = None, passToAllDownstream: Boolean = false) { - val data: mutable.Map[String, (AttributeType, Any)] = mutable.LinkedHashMap() - add("passToAllDownstream", passToAllDownstream, AttributeType.BOOLEAN) - if (tuple.isDefined) { - tuple.get.getSchema.getAttributes.foreach { attribute => - add(attribute.getName, tuple.get.getField(attribute.getName), attribute.getType) - } - } +object State { + private val StateContent = "content" + private val BytesTypeMarker = "__texera_type__" + private val BytesValue = "bytes" + private val PayloadMarker = "payload" - def add(key: String, value: Any, valueType: AttributeType): Unit = - data.put(key, (valueType, value)) + val schema: Schema = new Schema( + new Attribute(StateContent, AttributeType.STRING) + ) - def get(key: String): Any = data(key)._2 + def stateUriFromResultUri(resultUri: URI): URI = + new URI(resultUri.toString.replace("/result", "/state")) - def isPassToAllDownstream: Boolean = get("passToAllDownstream").asInstanceOf[Boolean] + def serialize(state: State): Tuple = { + val payloadJson = objectMapper.writeValueAsString(toJsonValue(state)) + Tuple.builder(schema).addSequentially(Array(payloadJson)).build() + } - def apply(key: String): Any = get(key) + def deserialize(tuple: Tuple): State = { + val payload = tuple.getField[String](StateContent) + objectMapper.readTree(payload).fields().asScala.map(entry => entry.getKey -> fromJsonValue(entry.getValue)).toMap + } - def toTuple: Tuple = - Tuple - .builder( - Schema(data.map { - case (name, (attrType, _)) => - new Attribute(name, attrType) - }.toList) - ) - .addSequentially(data.values.map(_._2).toArray) - .build() + private def toJsonValue(value: Any): Any = + value match { + case null => null + case bytes: Array[Byte] => + Map(BytesTypeMarker -> BytesValue, PayloadMarker -> Base64.getEncoder.encodeToString(bytes)) + case map: State => + map.iterator.map { case (k, v) => k -> toJsonValue(v) }.toMap + case iterable: Iterable[_] => + iterable.map(toJsonValue).toList + case other => other + } - override def toString: String = - data.map { case (key, (_, value)) => s"$key: $value" }.mkString(", ") + private def fromJsonValue(node: JsonNode): Any = { + if (node == null || node.isNull) { + null + } else if (node.isObject) { + val fields = node.fields().asScala.map(entry => entry.getKey -> entry.getValue).toMap + fields.get(BytesTypeMarker) match { + case Some(typeNode) if typeNode.isTextual && typeNode.asText() == BytesValue => + Base64.getDecoder.decode(fields(PayloadMarker).asText()) + case _ => + fields.view.mapValues(fromJsonValue).toMap + } + } else if (node.isArray) { + node.elements().asScala.map(fromJsonValue).toList + } else if (node.isBoolean) { + node.asBoolean() + } else if (node.isIntegralNumber) { + node.longValue() + } else if (node.isFloatingPointNumber) { + node.doubleValue() + } else { + node.asText() + } + } } diff --git a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/state/package.scala b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/state/package.scala new file mode 100644 index 0000000000..c110f9d814 --- /dev/null +++ b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/state/package.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.texera.amber.core + +package object state { + type State = Map[String, Any] +} diff --git a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/DocumentFactory.scala b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/DocumentFactory.scala index 15949ef471..ae37def667 100644 --- a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/DocumentFactory.scala +++ b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/DocumentFactory.scala @@ -72,6 +72,7 @@ object DocumentFactory { case RESULT => StorageConfig.icebergTableResultNamespace case CONSOLE_MESSAGES => StorageConfig.icebergTableConsoleMessagesNamespace case RUNTIME_STATISTICS => StorageConfig.icebergTableRuntimeStatisticsNamespace + case STATE => "state" case _ => throw new IllegalArgumentException(s"Resource type $resourceType is not supported") } @@ -119,6 +120,7 @@ object DocumentFactory { case RESULT => StorageConfig.icebergTableResultNamespace case CONSOLE_MESSAGES => StorageConfig.icebergTableConsoleMessagesNamespace case RUNTIME_STATISTICS => StorageConfig.icebergTableRuntimeStatisticsNamespace + case STATE => "state" case _ => throw new IllegalArgumentException(s"Resource type $resourceType is not supported") } diff --git a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/VFSURIFactory.scala b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/VFSURIFactory.scala index 3513ac5ecd..990776a69f 100644 --- a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/VFSURIFactory.scala +++ b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/VFSURIFactory.scala @@ -34,6 +34,7 @@ object VFSResourceType extends Enumeration { val RESULT: Value = Value("result") val RUNTIME_STATISTICS: Value = Value("runtimeStatistics") val CONSOLE_MESSAGES: Value = Value("consoleMessages") + val STATE: Value = Value("state") } object VFSURIFactory {
