This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch rework-tasksdk-supervisor-comms-protocol
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 38d74f3232a0784c1ee852e25affbfbf67a2c704
Author: Ash Berlin-Taylor <a...@apache.org>
AuthorDate: Tue Jun 10 11:55:46 2025 +0100

    Switch the Supervisor/task process from line-based to length-prefixed

    The existing JSON Lines based approach had three major drawbacks:

    1. In the case of really large lines (in the region of 10 or 20MB) the
       Python line buffering could _sometimes_ result in a partial read.
    2. The JSON-based approach didn't have the ability to add any metadata
       (such as errors).
    3. Not every message type/call-site waited for a response, which meant
       those client functions could never be told about an error.

    This changes the communications protocol in a few ways.

    First, at the Python level, the separate send and receive methods on
    the client/task side have been removed and replaced with a single
    `send()` that sends the request, reads the response, and raises an
    error if one is returned.

    Secondly, the JSON Lines approach has been changed from a line-based
    protocol to a binary "frame" one. The protocol (which is the same
    whichever side is sending) is length-prefixed, i.e. we first send the
    length of the data as a 4-byte big-endian integer, followed by the data
    itself. This should remove the possibility of JSON parse errors due to
    reading incomplete lines.

    Finally, the last change made in this PR is to remove the "extra"
    requests socket/channel. Upon closer examination of this comms path I
    realised that this socket is unnecessary: since we are in 100% control
    of the client side, we can make use of the bi-directional nature of
    `socketpair` and save file handles.
---
 .../src/airflow/dag_processing/processor.py | 16 +-
 .../src/airflow/jobs/triggerer_job_runner.py | 6 +-
 .../tests/unit/dag_processing/test_manager.py | 11 +-
 .../tests/unit/dag_processing/test_processor.py | 14 +-
 airflow-core/tests/unit/jobs/test_triggerer_job.py | 1 -
 devel-common/src/tests_common/pytest_plugin.py | 9 +-
 task-sdk/src/airflow/sdk/bases/xcom.py | 27 +-
 .../airflow/sdk/definitions/asset/decorators.py | 12 +-
 task-sdk/src/airflow/sdk/execution_time/comms.py | 165 ++++++++++-
 task-sdk/src/airflow/sdk/execution_time/context.py | 32 +--
 .../airflow/sdk/execution_time/lazy_sequence.py | 18 +-
 .../src/airflow/sdk/execution_time/supervisor.py | 248 ++++++++++------
 .../src/airflow/sdk/execution_time/task_runner.py | 165 +++--------
 .../tests/task_sdk/execution_time/test_comms.py | 83 ++++++
 .../task_sdk/execution_time/test_supervisor.py | 312 +++++++++++----------
 .../task_sdk/execution_time/test_task_runner.py | 51 ----
 16 files changed, 675 insertions(+), 495 deletions(-)

diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index 5f69082c758..1d1ad25f00b 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -68,7 +68,6 @@ class DagFileParseRequest(BaseModel): bundle_path: Path """Passing bundle path around lets us figure out relative file path.""" - requests_fd: int callback_requests: list[CallbackRequest] = Field(default_factory=list) type: Literal["DagFileParseRequest"] = "DagFileParseRequest" @@ -102,18 +101,16 @@ ToDagProcessor = Annotated[ def _parse_file_entrypoint(): import structlog - from airflow.sdk.execution_time import task_runner + from airflow.sdk.execution_time import comms, task_runner # Parse DAG file, send JSON back up!
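A minimal, self-contained sketch of the length-prefixed framing described in the commit message above, assuming a connected `socket.socket` on each side (illustrative only; the real implementation lives in `comms.py` and `supervisor.py` below):

# Sketch, not part of the patch: every frame is a 4-byte big-endian length, then the payload.
import socket


def write_frame(sock: socket.socket, payload: bytes) -> None:
    if len(payload) >= 2**32:
        raise OverflowError("Cannot send messages larger than 4GiB")
    # length prefix first, then the body itself
    sock.sendall(len(payload).to_bytes(4, byteorder="big") + payload)


def read_frame(sock: socket.socket) -> bytes:
    header = sock.recv(4)
    if header == b"":
        raise EOFError("socket closed before length prefix")
    length = int.from_bytes(header, byteorder="big")
    buffer = bytearray(length)
    view = memoryview(buffer)
    pos = 0
    while pos < length:
        # a single recv may return a partial frame, so keep reading until it is complete
        n = sock.recv_into(view[pos:])
        if n == 0:
            raise EOFError("socket closed mid-frame")
        pos += n
    return bytes(buffer)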
- comms_decoder = task_runner.CommsDecoder[ToDagProcessor, ToManager]( - input=sys.stdin, - decoder=TypeAdapter[ToDagProcessor](ToDagProcessor), + comms_decoder = comms.CommsDecoder[ToDagProcessor, ToManager]( + body_decoder=TypeAdapter[ToDagProcessor](ToDagProcessor), ) - msg = comms_decoder.get_message() + msg = comms_decoder._get_response() if not isinstance(msg, DagFileParseRequest): raise RuntimeError(f"Required first message to be a DagFileParseRequest, it was {msg}") - comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) task_runner.SUPERVISOR_COMMS = comms_decoder log = structlog.get_logger(logger_name="task") @@ -125,7 +122,7 @@ def _parse_file_entrypoint(): result = _parse_file(msg, log) if result is not None: - comms_decoder.send_request(log, result) + comms_decoder.send(result) def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult | None: @@ -266,10 +263,9 @@ class DagFileProcessorProcess(WatchedSubprocess): msg = DagFileParseRequest( file=os.fspath(path), bundle_path=bundle_path, - requests_fd=self._requests_fd, callback_requests=callbacks, ) - self.send_msg(msg) + self.send_msg(msg, in_response_to=0) def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: # type: ignore[override] from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 5df6a032615..0a516abb1a6 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -181,7 +181,6 @@ class messages: class StartTriggerer(BaseModel): """Tell the async trigger runner process to start, and where to send status update messages.""" - requests_fd: int type: Literal["StartTriggerer"] = "StartTriggerer" class TriggerStateChanges(BaseModel): @@ -342,8 +341,8 @@ class TriggerRunnerSupervisor(WatchedSubprocess): ): proc = super().start(id=job.id, job=job, target=cls.run_in_process, logger=logger, **kwargs) - msg = messages.StartTriggerer(requests_fd=proc._requests_fd) - proc.send_msg(msg) + msg = messages.StartTriggerer() + proc.send_msg(msg, in_response_to=0) return proc @functools.cached_property @@ -819,7 +818,6 @@ class TriggerRunner: if not isinstance(msg, messages.StartTriggerer): raise RuntimeError(f"Required first message to be a messages.StartTriggerer, it was {msg}") - comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) writer_transport, writer_protocol = await loop.connect_write_pipe( lambda: asyncio.streams.FlowControlMixin(loop=loop), comms_decoder.request_socket, diff --git a/airflow-core/tests/unit/dag_processing/test_manager.py b/airflow-core/tests/unit/dag_processing/test_manager.py index c9e974b8cdb..799d2ad51c8 100644 --- a/airflow-core/tests/unit/dag_processing/test_manager.py +++ b/airflow-core/tests/unit/dag_processing/test_manager.py @@ -30,7 +30,7 @@ from collections import deque from datetime import datetime, timedelta from logging.config import dictConfig from pathlib import Path -from socket import socket +from socket import socket, socketpair from unittest import mock from unittest.mock import MagicMock @@ -54,7 +54,6 @@ from airflow.models.dag_version import DagVersion from airflow.models.dagbundle import DagBundleModel from airflow.models.dagcode import DagCode from airflow.models.serialized_dag import SerializedDagModel -from airflow.sdk.execution_time.supervisor import 
mkpipe from airflow.utils import timezone from airflow.utils.net import get_hostname from airflow.utils.session import create_session @@ -138,20 +137,19 @@ class TestDagFileProcessorManager: logger_filehandle = MagicMock() proc.create_time.return_value = time.time() proc.wait.return_value = 0 - read_end, write_end = mkpipe(remote_read=True) + read_end, write_end = socketpair() ret = DagFileProcessorProcess( process_log=MagicMock(), id=uuid7(), pid=1234, process=proc, stdin=write_end, - requests_fd=123, logger_filehandle=logger_filehandle, client=MagicMock(), ) if start_time: ret.start_time = start_time - ret._num_open_sockets = 0 + ret._open_sockets.clear() return ret, read_end @pytest.fixture @@ -560,7 +558,6 @@ class TestDagFileProcessorManager: b"{" b'"file":"/opt/airflow/dags/test_dag.py",' b'"bundle_path":"/opt/airflow/dags",' - b'"requests_fd":123,' b'"callback_requests":[],' b'"type":"DagFileParseRequest"' b"}\n", @@ -580,7 +577,7 @@ class TestDagFileProcessorManager: b"{" b'"file":"/opt/airflow/dags/dag_callback_dag.py",' b'"bundle_path":"/opt/airflow/dags",' - b'"requests_fd":123,"callback_requests":' + b'"callback_requests":' b"[" b"{" b'"filepath":"dag_callback_dag.py",' diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 28ce7a8c23f..e673848a915 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -87,7 +87,6 @@ class TestDagFileProcessor: DagFileParseRequest( file=file_path, bundle_path=TEST_DAG_FOLDER, - requests_fd=1, callback_requests=callback_requests or [], ), log=structlog.get_logger(), @@ -395,13 +394,10 @@ def disable_capturing(): @pytest.mark.usefixtures("disable_capturing") def test_parse_file_entrypoint_parses_dag_callbacks(spy_agency): r, w = socketpair() - # Create a valid FD for the decoder to open - _, w2 = socketpair() w.makefile("wb").write( - b'{"file":"/files/dags/wait.py","bundle_path":"/files/dags","requests_fd":' - + str(w2.fileno()).encode("ascii") - + b',"callback_requests": [{"filepath": "wait.py", "bundle_name": "testing", "bundle_version": null, ' + b'{"file":"/files/dags/wait.py","bundle_path":"/files/dags",' + b'"callback_requests": [{"filepath": "wait.py", "bundle_name": "testing", "bundle_version": null, ' b'"msg": "task_failure", "dag_id": "wait_to_fail", "run_id": ' b'"manual__2024-12-30T21:02:55.203691+00:00", ' b'"is_failure_callback": true, "type": "DagCallbackRequest"}], "type": "DagFileParseRequest"}\n' @@ -455,7 +451,7 @@ def test_parse_file_with_dag_callbacks(spy_agency): ) ] _parse_file( - DagFileParseRequest(file="A", bundle_path="no matter", requests_fd=1, callback_requests=requests), + DagFileParseRequest(file="A", bundle_path="no matter", callback_requests=requests), log=structlog.get_logger(), ) @@ -489,8 +485,6 @@ def test_parse_file_with_task_callbacks(spy_agency): bundle_version=None, ) ] - _parse_file( - DagFileParseRequest(file="A", requests_fd=1, callback_requests=requests), log=structlog.get_logger() - ) + _parse_file(DagFileParseRequest(file="A", callback_requests=requests), log=structlog.get_logger()) assert called is True diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index d3c9ae6f27d..dd30db7cc1a 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -172,7 +172,6 @@ def supervisor_builder(mocker, session): 
pid=process.pid, stdin=mocker.Mock(), process=process, - requests_fd=-1, capacity=10, ) # Mock the selector diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index b5ae53625e2..73d6c82bcaf 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -1963,8 +1963,13 @@ def mock_supervisor_comms(): if not AIRFLOW_V_3_0_PLUS: yield None return + + from airflow.sdk.execution_time.comms import CommsDecoder + with mock.patch( - "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", + create=True, + spec=CommsDecoder, ) as supervisor_comms: yield supervisor_comms @@ -1991,7 +1996,6 @@ def mocked_parse(spy_agency): id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1 ), file="", - requests_fd=0, ), "example_dag_id", CustomOperator(task_id="hello"), @@ -2198,7 +2202,6 @@ def create_runtime_ti(mocked_parse): ), dag_rel_path="", bundle_info=BundleInfo(name="anything", version="any"), - requests_fd=0, ti_context=ti_context, start_date=start_date, # type: ignore ) diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index 0c330652956..aa20c7d3b31 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -70,9 +70,8 @@ class BaseXCom: map_index=map_index, ) - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetXCom( + SUPERVISOR_COMMS.send( + SetXCom( key=key, value=value, dag_id=dag_id, @@ -107,9 +106,8 @@ class BaseXCom: """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetXCom( + SUPERVISOR_COMMS.send( + SetXCom( key=key, value=value, dag_id=dag_id, @@ -186,9 +184,8 @@ class BaseXCom: # back so that two triggers don't end up interleaving requests and create a possible # race condition where the wrong trigger reads the response. with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXCom( + msg = SUPERVISOR_COMMS.send( + GetXCom( key=key, dag_id=dag_id, task_id=task_id, @@ -197,7 +194,6 @@ class BaseXCom: ), ) - msg = SUPERVISOR_COMMS.get_message() if not isinstance(msg, XComResult): raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") @@ -246,9 +242,8 @@ class BaseXCom: # back so that two triggers don't end up interleaving requests and create a possible # race condition where the wrong trigger reads the response. 
with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXCom( + msg = SUPERVISOR_COMMS.send( + GetXCom( key=key, dag_id=dag_id, task_id=task_id, @@ -257,7 +252,6 @@ class BaseXCom: include_prior_dates=include_prior_dates, ), ) - msg = SUPERVISOR_COMMS.get_message() if not isinstance(msg, XComResult): raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") @@ -322,9 +316,8 @@ class BaseXCom: map_index=map_index, ) cls.purge(xcom_result) # type: ignore[call-arg] - SUPERVISOR_COMMS.send_request( - log=log, - msg=DeleteXCom( + SUPERVISOR_COMMS.send( + DeleteXCom( key=key, dag_id=dag_id, task_id=task_id, diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py index d899dd37183..bfe5ac2fc96 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -69,18 +69,14 @@ class _AssetMainOperator(PythonOperator): ) def _iter_kwargs(self, context: Mapping[str, Any]) -> Iterator[tuple[str, Any]]: - import structlog - from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetByName from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - log = structlog.get_logger(logger_name=self.__class__.__qualname__) - def _fetch_asset(name: str) -> Asset: - SUPERVISOR_COMMS.send_request(log, GetAssetByName(name=name)) - if isinstance(msg := SUPERVISOR_COMMS.get_message(), ErrorResponse): - raise AirflowRuntimeError(msg) - return Asset(**msg.model_dump(exclude={"type"})) + resp = SUPERVISOR_COMMS.send(GetAssetByName(name=name)) + if isinstance(resp, ErrorResponse): + raise AirflowRuntimeError(resp) + return Asset(**resp.model_dump(exclude={"type"})) value: Any for key, param in inspect.signature(self.python_callable).parameters.items(): diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index d0622cf6ffc..4041323540a 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -43,15 +43,20 @@ Execution API server is because: from __future__ import annotations +import itertools from collections.abc import Iterator from datetime import datetime from functools import cached_property -from typing import Annotated, Any, Literal, Union +from socket import socket +from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TypeVar, Union from uuid import UUID +import aiologic import attrs +import msgspec +import structlog from fastapi import Body -from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, field_serializer +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_serializer from airflow.sdk.api.datamodels._generated import ( AssetEventDagRunReference, @@ -80,6 +85,156 @@ from airflow.sdk.api.datamodels._generated import ( ) from airflow.sdk.exceptions import ErrorType +if TYPE_CHECKING: + from structlog.typing import FilteringBoundLogger as Logger + +SendMsgType = TypeVar("SendMsgType", bound=BaseModel) +ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel) + + +class _RequestFrame(msgspec.Struct, array_like=True, frozen=True, omit_defaults=True): + id: int + """ + The request id, set by the sender. + + This is used to allow "pipeling" of requests and to be able to tie response to requests, which is + particularly useful in the Triggerer where multiple async tasks can send a requests concurrently. 
+ """ + body: dict[str, Any] + + +class _ResponseFrame(msgspec.Struct, array_like=True, frozen=True, omit_defaults=True): + id: int + """ + The id of the request this is a response to + """ + body: dict[str, Any] | None = None + error: dict[str, Any] | None = None + + +def _msgpack_enc_hook(obj: Any) -> Any: + import pendulum + + if isinstance(obj, pendulum.DateTime): + # convert the complex to a tuple of real, imag + return datetime( + obj.year, obj.month, obj.day, obj.hour, obj.minute, obj.second, obj.microsecond, tzinfo=obj.tzinfo + ) + if isinstance(obj, BaseModel): + return obj.model_dump(exclude_unset=True) + + # Raise a NotImplementedError for other types + raise NotImplementedError(f"Objects of type {type(obj)} are not supported") + + +def _new_encoder() -> msgspec.msgpack.Encoder: + return msgspec.msgpack.Encoder(enc_hook=_msgpack_enc_hook) + + +@attrs.define() +class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]): + """Handle communication between the task in this process and the supervisor parent process.""" + + log: Logger = attrs.field(repr=False, factory=structlog.get_logger) + request_socket: socket = attrs.field(factory=lambda: socket(fileno=0)) + + # Do we still need the context lock? Keeps the change small for now + lock: aiologic.Lock = attrs.field(factory=aiologic.Lock, repr=False) + + resp_decoder: msgspec.msgpack.Decoder[_ResponseFrame] = attrs.field( + factory=lambda: msgspec.msgpack.Decoder(_ResponseFrame), repr=False + ) + req_encoder: msgspec.msgpack.Encoder = attrs.field(factory=_new_encoder, repr=False) + + id_counter: Iterator[int] = attrs.field(factory=itertools.count) + + # We could be "clever" here and set the default to this based type parameters and a custom + # `__class_getitem__`, but that's a lot of code the one subclass we've got currently. So we'll just use a + # "sort of wrong default" + body_decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False) + + err_decoder: TypeAdapter[ErrorResponse] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False) + + def send(self, msg: SendMsgType) -> ReceiveMsgType: + """Send a request to the parent and block until the response is received.""" + bytes = self._encode(msg) + + # print( + # f"Subp: sending {type(msg)} request on {self.request_socket.fileno()}, total len={len(bytes)}", + # file=__import__("sys")._ash_out, + # ) + nsent = self.request_socket.send(bytes) + # print(f"Subp: {nsent=}", file=__import__("sys")._ash_out) + + return self._get_response() + + def _encode(self, msg: SendMsgType) -> bytearray: + # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing for inspiration + buffer = bytearray(256) + + frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump()) + self.req_encoder.encode_into(frame, buffer, 4) + + n = len(buffer) - 4 + if n > 2**32: + raise OverflowError("Cannot send messages larger than 4GiB") + buffer[:4] = n.to_bytes(4, byteorder="big") + + return buffer + + def _read_frame(self): + """ + Get a message from the parent. + + This will block until the message has been received. 
+ """ + if self.request_socket: + self.request_socket.setblocking(True) + # print("Subp: reading length prefix", file=__import__("sys")._ash_out) + len_bytes = self.request_socket.recv(4) + + if len_bytes == b"": + raise EOFError("Request socket closed before length") + + len = int.from_bytes(len_bytes, byteorder="big") + # print(f"Subp: frame {len=} ({len_bytes=})", file=__import__("sys")._ash_out) + + buffer = bytearray(len) + nread = self.request_socket.recv_into(buffer) + if nread != len: + raise RuntimeError( + f"unable to read full response in child. (We read {nread}, but expected {len})" + ) + if nread == 0: + # print("Subp: EOF when trying to read frame", file=__import__("sys")._ash_out) + raise EOFError("Request socket closed before response was complete") + + try: + return self.resp_decoder.decode(buffer) + except Exception as e: + raise e + + def _from_frame(self, frame): + from airflow.sdk.exceptions import AirflowRuntimeError + + # print(f"Subp: {frame.body=}", file=__import__("sys")._ash_out) + if frame.error is not None: + err = self.err_decoder.validate_python(frame.error) + raise AirflowRuntimeError(error=err) + + if frame.body is None: + return None + + try: + return self.body_decoder.validate_python(frame.body) + except Exception: + self.log.exception("Unable to decode message") + raise + + def _get_response(self) -> ReceiveMsgType: + frame = self._read_frame() + return self._from_frame(frame) + class StartupDetails(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) @@ -87,13 +242,7 @@ class StartupDetails(BaseModel): ti: TaskInstance dag_rel_path: str bundle_info: BundleInfo - requests_fd: int start_date: datetime - """ - The channel for the task to send requests over. - - Responses will come back on stdin - """ ti_context: TIRunContext type: Literal["StartupDetails"] = "StartupDetails" diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 438687bdeb8..3ef99fec248 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -155,8 +155,7 @@ def _get_connection(conn_id: str) -> Connection: # back so that two triggers don't end up interleaving requests and create a possible # race condition where the wrong trigger reads the response. with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=GetConnection(conn_id=conn_id)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(GetConnection(conn_id=conn_id)) if isinstance(msg, ErrorResponse): raise AirflowRuntimeError(msg) @@ -208,8 +207,7 @@ def _get_variable(key: str, deserialize_json: bool) -> Any: # back so that two triggers don't end up interleaving requests and create a possible # race condition where the wrong trigger reads the response. with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key=key)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(GetVariable(key=key)) if isinstance(msg, ErrorResponse): raise AirflowRuntimeError(msg) @@ -263,7 +261,7 @@ def _set_variable(key: str, value: Any, description: str | None = None, serializ # primarily added for triggers but it doesn't make sense to have it in some places # and not in the rest. 
A lot of this will be simplified by https://github.com/apache/airflow/issues/46426 with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=PutVariable(key=key, value=value, description=description)) + SUPERVISOR_COMMS.send(PutVariable(key=key, value=value, description=description)) def _delete_variable(key: str) -> None: @@ -279,8 +277,7 @@ def _delete_variable(key: str) -> None: # primarily added for triggers but it doesn't make sense to have it in some places # and not in the rest. A lot of this will be simplified by https://github.com/apache/airflow/issues/46426 with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=DeleteVariable(key=key)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(DeleteVariable(key=key)) if TYPE_CHECKING: assert isinstance(msg, OKResponse) @@ -387,13 +384,13 @@ class _AssetRefResolutionMixin: from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS if name: - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByName(name=name)) + msg = GetAssetByName(name=name) elif uri: - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByUri(uri=uri)) + msg = GetAssetByUri(uri=uri) else: raise ValueError("Either name or uri must be provided") - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(msg) if isinstance(msg, ErrorResponse): raise AirflowRuntimeError(msg) @@ -545,24 +542,26 @@ class InletEventsAccessors( if isinstance(obj, Asset): asset = self._assets[AssetUniqueKey.from_asset(obj)] - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAsset(name=asset.name, uri=asset.uri)) + msg = GetAssetEventByAsset(name=asset.name, uri=asset.uri) elif isinstance(obj, AssetNameRef): try: asset = next(a for k, a in self._assets.items() if k.name == obj.name) except StopIteration: raise KeyError(obj) from None - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAsset(name=asset.name, uri=None)) + msg = GetAssetEventByAsset(name=asset.name, uri=None) elif isinstance(obj, AssetUriRef): try: asset = next(a for k, a in self._assets.items() if k.uri == obj.uri) except StopIteration: raise KeyError(obj) from None - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAsset(name=None, uri=asset.uri)) + msg = GetAssetEventByAsset(name=None, uri=asset.uri) elif isinstance(obj, AssetAlias): asset_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)] - SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAssetAlias(alias_name=asset_alias.name)) + msg = GetAssetEventByAssetAlias(alias_name=asset_alias.name) + else: + raise TypeError(f"`key` is of unknown type ({type(key).__name__})") - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(msg) if isinstance(msg, ErrorResponse): raise AirflowRuntimeError(msg) @@ -626,8 +625,7 @@ def get_previous_dagrun_success(ti_id: UUID) -> PrevSuccessfulDagRunResponse: ) from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - SUPERVISOR_COMMS.send_request(log=log, msg=GetPrevSuccessfulDagRun(ti_id=ti_id)) - msg = SUPERVISOR_COMMS.get_message() + msg = SUPERVISOR_COMMS.send(GetPrevSuccessfulDagRun(ti_id=ti_id)) if TYPE_CHECKING: assert isinstance(msg, PrevSuccessfulDagRunResult) diff --git a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py index 9cf9acfac81..536555a3e39 100644 --- a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py +++ b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py @@ -91,16 +91,14 @@ class 
LazyXComSequence(Sequence[T]): task = self._xcom_arg.operator - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXComCount( + msg = SUPERVISOR_COMMS.send( + GetXComCount( key=self._xcom_arg.key, dag_id=task.dag_id, run_id=self._ti.run_id, task_id=task.task_id, ), ) - msg = SUPERVISOR_COMMS.get_message() if isinstance(msg, ErrorResponse): raise RuntimeError(msg) if not isinstance(msg, XComCountResponse): @@ -129,9 +127,8 @@ class LazyXComSequence(Sequence[T]): start, stop, step = _coerce_slice(key) with SUPERVISOR_COMMS.lock: source = (xcom_arg := self._xcom_arg).operator - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXComSequenceSlice( + msg = SUPERVISOR_COMMS.send( + GetXComSequenceSlice( key=xcom_arg.key, dag_id=source.dag_id, task_id=source.task_id, @@ -141,7 +138,6 @@ class LazyXComSequence(Sequence[T]): step=step, ), ) - msg = SUPERVISOR_COMMS.get_message() if not isinstance(msg, XComSequenceSliceResult): raise TypeError(f"Got unexpected response to GetXComSequenceSlice: {msg!r}") return [XCom.deserialize_value(_XComWrapper(value)) for value in msg.root] @@ -153,9 +149,8 @@ class LazyXComSequence(Sequence[T]): with SUPERVISOR_COMMS.lock: source = (xcom_arg := self._xcom_arg).operator - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXComSequenceItem( + msg = SUPERVISOR_COMMS.send( + GetXComSequenceItem( key=xcom_arg.key, dag_id=source.dag_id, task_id=source.task_id, @@ -163,7 +158,6 @@ class LazyXComSequence(Sequence[T]): offset=key, ), ) - msg = SUPERVISOR_COMMS.get_message() if isinstance(msg, ErrorResponse): raise IndexError(key) if not isinstance(msg, XComSequenceIndexResult): diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 7e6b89043d2..a3c9bcbb3df 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -27,12 +27,13 @@ import selectors import signal import sys import time +import weakref from collections import deque from collections.abc import Generator from contextlib import contextmanager, suppress from datetime import datetime, timezone from http import HTTPStatus -from socket import SO_SNDBUF, SOL_SOCKET, SocketIO, socket, socketpair +from socket import SO_SNDBUF, SOL_SOCKET, socket, socketpair from typing import ( TYPE_CHECKING, Callable, @@ -63,6 +64,7 @@ from airflow.sdk.api.datamodels._generated import ( XComSequenceIndexResponse, ) from airflow.sdk.exceptions import ErrorType +from airflow.sdk.execution_time import comms from airflow.sdk.execution_time.comms import ( AssetEventsResult, AssetResult, @@ -108,6 +110,8 @@ from airflow.sdk.execution_time.comms import ( XComResult, XComSequenceIndexResult, XComSequenceSliceResult, + _RequestFrame, + _ResponseFrame, ) from airflow.sdk.execution_time.secrets_masker import mask_secret @@ -192,6 +196,8 @@ def mkpipe( local.setsockopt(SO_SNDBUF, SOL_SOCKET, BUFFER_SIZE) # set nonblocking to True so that send or sendall waits till all data is sent local.setblocking(True) + else: + local.setblocking(False) return remote, local @@ -223,14 +229,13 @@ def _configure_logs_over_json_channel(log_fd: int): def _reopen_std_io_handles(child_stdin, child_stdout, child_stderr): # Ensure that sys.stdout et al (and the underlying filehandles for C libraries etc) are connected to the # pipes from the supervisor - for handle_name, fd, sock, mode in ( - ("stdin", 0, child_stdin, "r"), + # Yes, we want to re-open stdin in write mode! 
This is cause it is a bi-directional socket, so we can + # read and write to it. + ("stdin", 0, child_stdin, "w"), ("stdout", 1, child_stdout, "w"), ("stderr", 2, child_stderr, "w"), ): - handle = getattr(sys, handle_name) - handle.close() os.dup2(sock.fileno(), fd) del sock @@ -318,7 +323,7 @@ def block_orm_access(): def _fork_main( - child_stdin: socket, + requests: socket, child_stdout: socket, child_stderr: socket, log_fd: int, @@ -345,10 +350,12 @@ def _fork_main( # Store original stderr for last-chance exception handling last_chance_stderr = _get_last_chance_stderr() + # os.environ["_AIRFLOW_SUPERVISOR_FD"] = str(requests.fileno()) + _reset_signals() if log_fd: _configure_logs_over_json_channel(log_fd) - _reopen_std_io_handles(child_stdin, child_stdout, child_stderr) + _reopen_std_io_handles(requests, child_stdout, child_stderr) def exit(n: int) -> NoReturn: with suppress(ValueError, OSError): @@ -359,11 +366,11 @@ def _fork_main( last_chance_stderr.flush() # Explicitly close the child-end of our supervisor sockets so - # the parent sees EOF on both "requests" and "logs" channels. + # the parent sees EOF on "logs" channel. with suppress(OSError): os.close(log_fd) with suppress(OSError): - os.close(child_stdin.fileno()) + os.close(requests.fileno()) os._exit(n) if hasattr(atexit, "_clear"): @@ -431,16 +438,18 @@ class WatchedSubprocess: """The decoder to use for incoming messages from the child process.""" _process: psutil.Process = attrs.field(repr=False) - _requests_fd: int """File descriptor for request handling.""" - _num_open_sockets: int = 4 _exit_code: int | None = attrs.field(default=None, init=False) _process_exit_monotonic: float | None = attrs.field(default=None, init=False) - _fd_to_socket_type: dict[int, str] = attrs.field(factory=dict, init=False) + _open_sockets: weakref.WeakKeyDictionary[socket, str] = attrs.field( + factory=weakref.WeakKeyDictionary, init=False + ) selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector, repr=False) + _frame_encoder: msgspec.msgpack.Encoder = attrs.field(factory=comms._new_encoder, repr=False) + process_log: FilteringBoundLogger = attrs.field(repr=False) subprocess_logs_to_stdout: bool = False @@ -459,18 +468,19 @@ class WatchedSubprocess: ) -> Self: """Fork and start a new subprocess with the specified target function.""" # Create socketpairs/"pipes" to connect to the stdin and out from the subprocess - child_stdin, feed_stdin = mkpipe(remote_read=True) - child_stdout, read_stdout = mkpipe() - child_stderr, read_stderr = mkpipe() + child_stdout, read_stdout = socketpair() + child_stderr, read_stderr = socketpair() + + # Place for child to send requests/read responses, and the server side to read/respond + child_requests, read_requests = socketpair() - # Open these socketpair before forking off the child, so that it is open when we fork. - child_comms, read_msgs = mkpipe() - child_logs, read_logs = mkpipe() + # Open the socketpair before forking off the child, so that it is open when we fork. + child_logs, read_logs = socketpair() pid = os.fork() if pid == 0: # Close and delete of the parent end of the sockets. - cls._close_unused_sockets(feed_stdin, read_stdout, read_stderr, read_msgs, read_logs) + cls._close_unused_sockets(read_requests, read_stdout, read_stderr, read_logs) # Python GC should delete these for us, but lets make double sure that we don't keep anything # around in the forked processes, especially things that might involve open files or sockets! 
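The `start()` changes here lean on `socketpair()` being bi-directional: the child writes requests and reads responses on one descriptor, which is why the separate requests channel (and its file handles) can be dropped. A rough standalone illustration, not the supervisor code itself:

# Sketch, not part of the patch: one socketpair() carries traffic both ways across a fork.
import os
from socket import socketpair

parent_end, child_end = socketpair()
pid = os.fork()
if pid == 0:  # child (the task process)
    parent_end.close()
    child_end.sendall(b"request from child")       # write a request...
    print("child got:", child_end.recv(1024))      # ...and read the reply on the same fd
    os._exit(0)
else:  # parent (the supervisor)
    child_end.close()
    print("parent got:", parent_end.recv(1024))    # read the request...
    parent_end.sendall(b"response from parent")    # ...and answer over the same socket
    os.waitpid(pid, 0)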
@@ -479,28 +489,28 @@ class WatchedSubprocess: try: # Run the child entrypoint - _fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target) + _fork_main(child_requests, child_stdout, child_stderr, child_logs.fileno(), target) except BaseException as e: + import traceback + with suppress(BaseException): # We can't use log here, as if we except out of _fork_main something _weird_ went on. - print("Exception in _fork_main, exiting with code 124", e, file=sys.stderr) + print("Exception in _fork_main, exiting with code 124", file=sys.stderr) + traceback.print_exception(type(e), e, e.__traceback__, file=sys.stderr) # It's really super super important we never exit this block. We are in the forked child, and if we # do then _THINGS GET WEIRD_.. (Normally `_fork_main` itself will `_exit()` so we never get here) os._exit(124) - requests_fd = child_comms.fileno() - # Close the remaining parent-end of the sockets we've passed to the child via fork. We still have the # other end of the pair open - cls._close_unused_sockets(child_stdin, child_stdout, child_stderr, child_comms, child_logs) + cls._close_unused_sockets(child_stdout, child_stderr, child_logs) logger = logger or cast("FilteringBoundLogger", structlog.get_logger(logger_name="task").bind()) proc = cls( pid=pid, - stdin=feed_stdin, + stdin=read_requests, process=psutil.Process(pid), - requests_fd=requests_fd, process_log=logger, start_time=time.monotonic(), **constructor_kwargs, @@ -509,7 +519,7 @@ class WatchedSubprocess: proc._register_pipe_readers( stdout=read_stdout, stderr=read_stderr, - requests=read_msgs, + requests=read_requests, logs=read_logs, ) @@ -522,24 +532,26 @@ class WatchedSubprocess: # alternatives are used automatically) -- this is a way of having "event-based" code, but without # needing full async, to read and process output from each socket as it is received. - # Track socket types for debugging - self._fd_to_socket_type = { - stdout.fileno(): "stdout", - stderr.fileno(): "stderr", - requests.fileno(): "requests", - logs.fileno(): "logs", - } + # Track the open sockets, and for debugging what type each one is + self._open_sockets.update( + ( + (stdout, "stdout"), + (stderr, "stderr"), + (logs, "logs"), + (requests, "requests"), + ) + ) target_loggers: tuple[FilteringBoundLogger, ...] 
= (self.process_log,) if self.subprocess_logs_to_stdout: target_loggers += (log,) self.selector.register( - stdout, selectors.EVENT_READ, self._create_socket_handler(target_loggers, channel="stdout") + stdout, selectors.EVENT_READ, self._create_log_forwarder(target_loggers, channel="stdout") ) self.selector.register( stderr, selectors.EVENT_READ, - self._create_socket_handler(target_loggers, channel="stderr", log_level=logging.ERROR), + self._create_log_forwarder(target_loggers, channel="stderr", log_level=logging.ERROR), ) self.selector.register( logs, @@ -551,37 +563,52 @@ class WatchedSubprocess: self.selector.register( requests, selectors.EVENT_READ, - make_buffered_socket_reader(self.handle_requests(log), on_close=self._on_socket_closed), + length_prefixed_frame_reader(self.handle_requests(log), on_close=self._on_socket_closed), ) - def _create_socket_handler(self, loggers, channel, log_level=logging.INFO) -> Callable[[socket], bool]: + def _create_log_forwarder(self, loggers, channel, log_level=logging.INFO) -> Callable[[socket], bool]: """Create a socket handler that forwards logs to a logger.""" return make_buffered_socket_reader( forward_to_log(loggers, chan=channel, level=log_level), on_close=self._on_socket_closed ) - def _on_socket_closed(self): + def _on_socket_closed(self, sock: socket): # We want to keep servicing this process until we've read up to EOF from all the sockets. - self._num_open_sockets -= 1 + self._open_sockets.pop(sock, None) - def send_msg(self, msg: BaseModel, **dump_opts): - """Send the given pydantic message to the subprocess at once by encoding it and adding a line break.""" - b = msg.model_dump_json(**dump_opts).encode() + b"\n" - self.stdin.sendall(b) + def send_msg( + self, msg: BaseModel | None, in_response_to: int, error: ErrorResponse | None = None, **dump_opts + ): + """Send the msg as a length-prefixed response frame.""" + # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing for inspiration + if msg: + frame = _ResponseFrame(id=in_response_to, body=msg.model_dump(**dump_opts)) + else: + err_resp = error.model_dump() if error else None + frame = _ResponseFrame(id=in_response_to, error=err_resp) + buffer = bytearray(256) + + self._frame_encoder.encode_into(frame, buffer, 4) + n = len(buffer) - 4 + if n > 2**32: + raise OverflowError("Cannot send messages larger than 4GiB") + buffer[:4] = n.to_bytes(4, byteorder="big") + + self.stdin.sendall(buffer) - def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, None]: + def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, _RequestFrame, None]: """Handle incoming requests from the task process, respond with the appropriate data.""" while True: - line = yield + request = yield try: - msg = self.decoder.validate_json(line) + msg = self.decoder.validate_python(request.body) except Exception: - log.exception("Unable to decode message", line=line) + log.exception("Unable to decode message", body=request.body) continue try: - self._handle_request(msg, log) + self._handle_request(msg, log, request.id) except ServerResponseError as e: error_details = e.response.json() if e.response else None log.error( @@ -593,27 +620,25 @@ class WatchedSubprocess: # Send error response back to task so that the error appears in the task logs self.send_msg( - ErrorResponse( + msg=None, + error=ErrorResponse( error=ErrorType.API_SERVER_ERROR, detail={ "status_code": e.response.status_code, "message": str(e), "detail": error_details, }, - ) + ), + in_response_to=request.id, 
) - def _handle_request(self, msg, log: FilteringBoundLogger) -> None: + def _handle_request(self, msg, log: FilteringBoundLogger, req_id: int) -> None: raise NotImplementedError() @staticmethod def _close_unused_sockets(*sockets): """Close unused ends of sockets after fork.""" for sock in sockets: - if isinstance(sock, SocketIO): - # If we have the socket IO object, we need to close the underlying socket foricebly here too, - # else we get unclosed socket warnings, and likely leaking FDs too - sock._sock.close() sock.close() def _cleanup_open_sockets(self): @@ -623,20 +648,18 @@ class WatchedSubprocess: # sockets the supervisor would wait forever thinking they are still # active. This cleanup ensures we always release resources and exit. stuck_sockets = [] - for key in list(self.selector.get_map().values()): - socket_type = self._fd_to_socket_type.get(key.fd, f"unknown-{key.fd}") - stuck_sockets.append(f"{socket_type}({key.fd})") + for sock, socket_type in self._open_sockets.items(): + fileno = "unknown" with suppress(Exception): - self.selector.unregister(key.fileobj) - with suppress(Exception): - key.fileobj.close() # type: ignore[union-attr] + fileno = sock.fileno() + sock.close() + stuck_sockets.append(f"{socket_type}(fd={fileno})") if stuck_sockets: log.warning("Force-closed stuck sockets", pid=self.pid, sockets=stuck_sockets) self.selector.close() - self._close_unused_sockets(self.stdin) - self._num_open_sockets = 0 + self.stdin.close() def kill( self, @@ -753,7 +776,9 @@ class WatchedSubprocess: # are removed. if not need_more: self.selector.unregister(key.fileobj) - key.fileobj.close() # type: ignore[union-attr] + sock: socket = key.fileobj # type: ignore[assignment] + sock.close() + self._on_socket_closed(sock) # Check if the subprocess has exited return self._check_subprocess_exit(raise_on_timeout=raise_on_timeout, expect_signal=expect_signal) @@ -860,7 +885,6 @@ class ActivitySubprocess(WatchedSubprocess): ti=ti, dag_rel_path=os.fspath(dag_rel_path), bundle_info=bundle_info, - requests_fd=self._requests_fd, ti_context=ti_context, start_date=start_date, ) @@ -869,7 +893,7 @@ class ActivitySubprocess(WatchedSubprocess): log.debug("Sending", msg=msg) try: - self.send_msg(msg) + self.send_msg(msg, in_response_to=0) except BrokenPipeError: # Debug is fine, the process will have shown _something_ in it's last_chance exception handler log.debug("Couldn't send startup message to Subprocess - it died very early", pid=self.pid) @@ -929,7 +953,7 @@ class ActivitySubprocess(WatchedSubprocess): - Processes events triggered on the monitored file objects, such as data availability or EOF. - Sends heartbeats to ensure the process is alive and checks if the subprocess has exited. """ - while self._exit_code is None or self._num_open_sockets > 0: + while self._exit_code is None or self._open_sockets: last_heartbeat_ago = time.monotonic() - self._last_successful_heartbeat # Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible # so we notice the subprocess finishing as quick as we can. 
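Each request frame carries an `id`, and `send_msg(..., in_response_to=...)` echoes it back in the response frame, so a reply can always be matched to the request that produced it (useful once concurrent callers in the triggerer share one channel). A hypothetical sketch of the frame shapes, mirroring `_RequestFrame`/`_ResponseFrame` but not taken from the patch:

# Sketch, not part of the patch: request/response frames keyed by a shared id.
from typing import Optional

import msgspec


class RequestFrame(msgspec.Struct, array_like=True):
    id: int
    body: dict


class ResponseFrame(msgspec.Struct, array_like=True):
    id: int
    body: Optional[dict] = None
    error: Optional[dict] = None


req = RequestFrame(id=7, body={"type": "GetConnection", "conn_id": "my_db"})
ok = ResponseFrame(id=req.id, body={"conn_id": "my_db", "conn_type": "postgres"})

payload = msgspec.msgpack.encode(ok)
decoded = msgspec.msgpack.Decoder(ResponseFrame).decode(payload)
assert decoded.id == req.id  # the waiting caller can tie this response to its request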
@@ -945,16 +969,11 @@ class ActivitySubprocess(WatchedSubprocess): # This listens for activity (e.g., subprocess output) on registered file objects alive = self._service_subprocess(max_wait_time=max_wait_time) is None - if self._exit_code is not None and self._num_open_sockets > 0: + if self._exit_code is not None and self._open_sockets: if ( self._process_exit_monotonic and time.monotonic() - self._process_exit_monotonic > SOCKET_CLEANUP_TIMEOUT ): - log.debug( - "Forcefully closing remaining sockets", - open_sockets=self._num_open_sockets, - pid=self.pid, - ) self._cleanup_open_sockets() if alive: @@ -1050,7 +1069,7 @@ class ActivitySubprocess(WatchedSubprocess): return SERVER_TERMINATED return TaskInstanceState.FAILED - def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): + def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: int): log.debug("Received message from task runner", msg=msg) resp: BaseModel | None = None dump_opts = {} @@ -1223,10 +1242,21 @@ class ActivitySubprocess(WatchedSubprocess): dump_opts = {"exclude_unset": True} else: log.error("Unhandled request", msg=msg) + self.send_msg( + None, + in_response_to=req_id, + error=ErrorResponse( + error=ErrorType.API_SERVER_ERROR, + detail={"status_code": 400, "message": "Unhandled request"}, + ), + ) return if resp: - self.send_msg(resp, **dump_opts) + self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts) + else: + # Send an empty response frame (which signifies no error) if we dont have anything else to say + self.send_msg(None, in_response_to=req_id) def in_process_api_server(): @@ -1253,7 +1283,7 @@ class InProcessSupervisorComms: log.debug("Sending request", msg=msg) with set_supervisor_comms(None): - self.supervisor._handle_request(msg, log) # type: ignore[arg-type] + self.supervisor._handle_request(msg, log, 0) # type: ignore[arg-type] @attrs.define @@ -1297,7 +1327,6 @@ class InProcessTestSupervisor(ActivitySubprocess): id=what.id, pid=os.getpid(), # Use current process process=psutil.Process(), # Current process - requests_fd=-1, # Not used in in-process mode process_log=logger or structlog.get_logger(logger_name="task").bind(), client=cls._api_client(task.dag), **kwargs, @@ -1362,9 +1391,12 @@ class InProcessTestSupervisor(ActivitySubprocess): client.base_url = "http://in-process.invalid./" # type: ignore[assignment] return client - def send_msg(self, msg: BaseModel, **dump_opts): + def send_msg( + self, msg: BaseModel | None, in_response_to: int, error: ErrorResponse | None = None, **dump_opts + ): """Override to use in-process comms.""" - self.comms.messages.append(msg) + if msg is not None: + self.comms.messages.append(msg) @property def final_state(self): @@ -1420,7 +1452,7 @@ def run_task_in_process(ti: TaskInstance, task) -> TaskRunResult: # to a (sync) generator def make_buffered_socket_reader( gen: Generator[None, bytes | bytearray, None], - on_close: Callable, + on_close: Callable[[socket], None], buffer_size: int = 4096, ) -> Callable[[socket], bool]: buffer = bytearray() # This will hold our accumulated binary data @@ -1440,7 +1472,7 @@ def make_buffered_socket_reader( with suppress(StopIteration): gen.send(buffer) # Tell loop to close this selector - on_close() + on_close(sock) return False buffer.extend(read_buffer[:n_received]) @@ -1451,7 +1483,7 @@ def make_buffered_socket_reader( try: gen.send(line) except StopIteration: - on_close() + on_close(sock) return False buffer = buffer[newline_pos + 1 :] # Update the buffer with remaining data @@ 
-1460,6 +1492,56 @@ def make_buffered_socket_reader( return cb +def length_prefixed_frame_reader( + gen: Generator[None, _RequestFrame, None], on_close: Callable[[socket], None] +): + length_needed: int | None = None + # This will hold our accumulated/partial binary frame if it doesn't come in a single read + buffer: memoryview | None = None + # position in the buffer to store next read + pos = 0 + decoder = msgspec.msgpack.Decoder[_RequestFrame](_RequestFrame) + + # We need to start up the generator to get it to the point it's at waiting on the yield + next(gen) + + def cb(sock: socket): + print("Main: length_prefixed_frame_reader.cb fired") + nonlocal buffer, length_needed, pos + # Read up to `buffer_size` bytes of data from the socket + + if length_needed is None: + # Read the 32bit length of the frame + bytes = sock.recv(4) + if bytes == b"": + on_close(sock) + return False + + length_needed = int.from_bytes(bytes, byteorder="big") + buffer = memoryview(bytearray(length_needed)) + if length_needed and buffer: + n = sock.recv_into(buffer[pos:]) + if n == 0: + # EOF + on_close(sock) + return False + pos += n + + if len(buffer) >= length_needed: + request = decoder.decode(buffer) + buffer = None + pos = 0 + length_needed = None + try: + gen.send(request) + except StopIteration: + on_close(sock) + return False + return True + + return cb + + def process_log_messages_from_subprocess( loggers: tuple[FilteringBoundLogger, ...], ) -> Generator[None, bytes, None]: diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index bbc7394a87a..199544ae835 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -28,16 +28,14 @@ import time from collections.abc import Callable, Iterable, Iterator, Mapping from contextlib import suppress from datetime import datetime, timezone -from io import FileIO from itertools import product from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TextIO, TypeVar +from typing import TYPE_CHECKING, Annotated, Any, Literal -import aiologic import attrs import lazy_object_proxy import structlog -from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter +from pydantic import AwareDatetime, ConfigDict, Field, JsonValue from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock from airflow.dag_processing.bundles.manager import DagBundlesManager @@ -59,6 +57,7 @@ from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType from airflow.sdk.execution_time.callback_runner import create_executable_runner from airflow.sdk.execution_time.comms import ( AssetEventDagRunReferenceResult, + CommsDecoder, DagRunStateResult, DeferTask, DRCount, @@ -412,10 +411,9 @@ class RuntimeTaskInstance(TaskInstance): log.debug("Requesting first reschedule date from supervisor") - SUPERVISOR_COMMS.send_request( - log=log, msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number) + response = SUPERVISOR_COMMS.send( + msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number) ) - response = SUPERVISOR_COMMS.get_message() if TYPE_CHECKING: assert isinstance(response, TaskRescheduleStartDate) @@ -433,12 +431,9 @@ class RuntimeTaskInstance(TaskInstance): states: list[str] | None = None, ) -> int: """Return the number of task instances matching the given criteria.""" - log = structlog.get_logger(logger_name="task") - with 
SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetTICount( + response = SUPERVISOR_COMMS.send( + GetTICount( dag_id=dag_id, map_index=map_index, task_ids=task_ids, @@ -448,7 +443,6 @@ class RuntimeTaskInstance(TaskInstance): states=states, ), ) - response = SUPERVISOR_COMMS.get_message() if TYPE_CHECKING: assert isinstance(response, TICount) @@ -465,12 +459,9 @@ class RuntimeTaskInstance(TaskInstance): run_ids: list[str] | None = None, ) -> dict[str, Any]: """Return the task states matching the given criteria.""" - log = structlog.get_logger(logger_name="task") - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetTaskStates( + response = SUPERVISOR_COMMS.send( + GetTaskStates( dag_id=dag_id, map_index=map_index, task_ids=task_ids, @@ -479,7 +470,6 @@ class RuntimeTaskInstance(TaskInstance): run_ids=run_ids, ), ) - response = SUPERVISOR_COMMS.get_message() if TYPE_CHECKING: assert isinstance(response, TaskStatesResult) @@ -494,19 +484,15 @@ class RuntimeTaskInstance(TaskInstance): states: list[str] | None = None, ) -> int: """Return the number of DAG runs matching the given criteria.""" - log = structlog.get_logger(logger_name="task") - with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetDRCount( + response = SUPERVISOR_COMMS.send( + GetDRCount( dag_id=dag_id, logical_dates=logical_dates, run_ids=run_ids, states=states, ), ) - response = SUPERVISOR_COMMS.get_message() if TYPE_CHECKING: assert isinstance(response, DRCount) @@ -516,10 +502,8 @@ class RuntimeTaskInstance(TaskInstance): @staticmethod def get_dagrun_state(dag_id: str, run_id: str) -> str: """Return the state of the DAG run with the given Run ID.""" - log = structlog.get_logger(logger_name="task") with SUPERVISOR_COMMS.lock: - SUPERVISOR_COMMS.send_request(log=log, msg=GetDagRunState(dag_id=dag_id, run_id=run_id)) - response = SUPERVISOR_COMMS.get_message() + response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id)) if TYPE_CHECKING: assert isinstance(response, DagRunStateResult) @@ -638,62 +622,6 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: ) -SendMsgType = TypeVar("SendMsgType", bound=BaseModel) -ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel) - - -@attrs.define() -class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]): - """Handle communication between the task in this process and the supervisor parent process.""" - - input: TextIO - - request_socket: FileIO = attrs.field(init=False, default=None) - - # We could be "clever" here and set the default to this based type parameters and a custom - # `__class_getitem__`, but that's a lot of code the one subclass we've got currently. So we'll just use a - # "sort of wrong default" - decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False) - - lock: aiologic.Lock = attrs.field(factory=aiologic.Lock, repr=False) - - def get_message(self) -> ReceiveMsgType: - """ - Get a message from the parent. - - This will block until the message has been received. - """ - line = None - - # TODO: Investigate why some empty lines are sent to the processes stdin. - # That was highlighted when working on https://github.com/apache/airflow/issues/48183 - # and is maybe related to deferred/triggerer only context. 
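At call sites this collapses the old two-step exchange into one call; the pattern below (with a hypothetical variable key) is the same one used in the `context.py`, `xcom.py` and `task_runner.py` changes:

# Sketch, not part of the patch.
from airflow.sdk.execution_time.comms import GetVariable
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

# Previously (two steps, and nothing forced a caller to read the reply):
#     SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key="my_key"))
#     msg = SUPERVISOR_COMMS.get_message()
#
# Now one blocking call sends the request, reads the paired response frame,
# and raises if the supervisor reported an error instead of a body:
msg = SUPERVISOR_COMMS.send(GetVariable(key="my_key"))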
- while not line: - line = self.input.readline() - - try: - msg = self.decoder.validate_json(line) - except Exception: - structlog.get_logger(logger_name="CommsDecoder").exception("Unable to decode message", line=line) - raise - - if isinstance(msg, StartupDetails): - # If we read a startup message, pull out the FDs we care about! - if msg.requests_fd > 0: - self.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) - elif isinstance(msg, ErrorResponse) and msg.error == ErrorType.API_SERVER_ERROR: - structlog.get_logger(logger_name="task").error("Error response from the API Server") - raise AirflowRuntimeError(error=msg) - - return msg - - def send_request(self, log: Logger, msg: SendMsgType): - encoded_msg = msg.model_dump_json().encode() + b"\n" - - log.debug("Sending request", json=encoded_msg) - self.request_socket.write(encoded_msg) - - # This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's execution, # to send requests back to the supervisor process. # @@ -713,31 +641,33 @@ SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor] def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: - msg = SUPERVISOR_COMMS.get_message() + # The parent sends us a StartupDetails message un-prompted. After this, ever single message is only sent + # in response to us sending a request. + msg = SUPERVISOR_COMMS._get_response() + + if not isinstance(msg, StartupDetails): + raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") log = structlog.get_logger(logger_name="task") + # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021 + os_type = sys.platform + if os_type == "darwin": + log.debug("Mac OS detected, skipping setproctitle") + else: + from setproctitle import setproctitle + + setproctitle(f"airflow worker -- {msg.ti.id}") + try: get_listener_manager().hook.on_starting(component=TaskRunnerMarker()) except Exception: log.exception("error calling listener") - if isinstance(msg, StartupDetails): - # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021 - os_type = sys.platform - if os_type == "darwin": - log.debug("Mac OS detected, skipping setproctitle") - else: - from setproctitle import setproctitle - - setproctitle(f"airflow worker -- {msg.ti.id}") - - with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id): - ti = parse(msg, log) - ti.log_url = get_log_url_from_ti(ti) - log.debug("DAG file parsed", file=msg.dag_rel_path) - else: - raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") + with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id): + ti = parse(msg, log) + ti.log_url = get_log_url_from_ti(ti) + log.debug("DAG file parsed", file=msg.dag_rel_path) return ti, ti.get_template_context(), log @@ -785,7 +715,7 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv if rendered_fields := _serialize_rendered_fields(ti.task): # so that we do not call the API unnecessarily - SUPERVISOR_COMMS.send_request(log=log, msg=SetRenderedFields(rendered_fields=rendered_fields)) + SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields)) _validate_task_inlets_and_outlets(ti=ti, log=log) @@ -805,8 +735,7 @@ def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) - if not ti.task.inlets and not ti.task.outlets: return - SUPERVISOR_COMMS.send_request(msg=ValidateInletsAndOutlets(ti_id=ti.id), log=log) - inactive_assets_resp = 
SUPERVISOR_COMMS.get_message() + inactive_assets_resp = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id)) if TYPE_CHECKING: assert isinstance(inactive_assets_resp, InactiveAssetsResult) if inactive_assets := inactive_assets_resp.inactive_assets: @@ -902,7 +831,7 @@ def run( except DownstreamTasksSkipped as skip: log.info("Skipping downstream tasks.") tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks] - SUPERVISOR_COMMS.send_request(log=log, msg=SkipDownstreamTasks(tasks=tasks_to_skip)) + SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip)) msg, state = _handle_current_task_success(context, ti) except DagRunTriggerException as drte: msg, state = _handle_trigger_dag_run(drte, context, ti, log) @@ -964,7 +893,7 @@ def run( error = e finally: if msg: - SUPERVISOR_COMMS.send_request(msg=msg, log=log) + SUPERVISOR_COMMS.send(msg=msg) # Return the message to make unit tests easier too ti.state = state @@ -1002,9 +931,8 @@ def _handle_trigger_dag_run( ) -> tuple[ToSupervisor, TaskInstanceState]: """Handle exception from TriggerDagRunOperator.""" log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id) - SUPERVISOR_COMMS.send_request( - log=log, - msg=TriggerDagRun( + comms_msg = SUPERVISOR_COMMS.send( + TriggerDagRun( dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id, logical_date=drte.logical_date, @@ -1013,7 +941,6 @@ def _handle_trigger_dag_run( ), ) - comms_msg = SUPERVISOR_COMMS.get_message() if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS: if drte.skip_when_already_exists: log.info( @@ -1068,10 +995,9 @@ def _handle_trigger_dag_run( ) time.sleep(drte.poke_interval) - SUPERVISOR_COMMS.send_request( - log=log, msg=GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id) + comms_msg = SUPERVISOR_COMMS.send( + GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id) ) - comms_msg = SUPERVISOR_COMMS.get_message() if TYPE_CHECKING: assert isinstance(comms_msg, DagRunStateResult) if comms_msg.state in drte.failed_states: @@ -1260,10 +1186,7 @@ def finalize( if getattr(ti.task, "overwrite_rtif_after_execution", False): log.debug("Overwriting Rendered template fields.") if ti.task.template_fields: - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task)), - ) + SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task))) log.debug("Running finalizers", ti=ti) if state == TaskInstanceState.SUCCESS: @@ -1304,9 +1227,11 @@ def finalize( def main(): - # TODO: add an exception here, it causes an oof of a stack trace! + # TODO: add an exception here, it causes an oof of a stack trace if it happens to early! 
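Because every `send()` waits for the paired response frame, an error reported by the supervisor now surfaces at the call site instead of being silently dropped. Roughly, with a hypothetical connection id:

# Sketch, not part of the patch: an error frame becomes an exception in the client.
from airflow.sdk.exceptions import AirflowRuntimeError
from airflow.sdk.execution_time.comms import GetConnection
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

try:
    conn = SUPERVISOR_COMMS.send(GetConnection(conn_id="missing_conn"))
except AirflowRuntimeError:
    # The supervisor answered with an error frame (e.g. API_SERVER_ERROR); before this
    # change a call site that never read a response could not observe the failure at all.
    raise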
+ log = structlog.get_logger(logger_name="task") + global SUPERVISOR_COMMS - SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](input=sys.stdin) + SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log) try: ti, context, log = startup() @@ -1317,11 +1242,9 @@ def main(): state, msg, error = run(ti, context, log) finalize(ti, state, context, log, error) except KeyboardInterrupt: - log = structlog.get_logger(logger_name="task") log.exception("Ctrl-c hit") exit(2) except Exception: - log = structlog.get_logger(logger_name="task") log.exception("Top level error") exit(1) finally: diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py new file mode 100644 index 00000000000..ee2f956b063 --- /dev/null +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import uuid +from socket import socketpair + +import msgspec +import pytest + +from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails, _ResponseFrame +from airflow.sdk.execution_time.task_runner import CommsDecoder +from airflow.utils import timezone + + +class TestCommsDecoder: + """Test the communication between the subprocess and the "supervisor".""" + + @pytest.mark.usefixtures("disable_capturing") + def test_recv_StartupDetails(self): + r, w = socketpair() + + msg = { + "type": "StartupDetails", + "ti": { + "id": uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab"), + "task_id": "a", + "try_number": 1, + "run_id": "b", + "dag_id": "c", + }, + "ti_context": { + "dag_run": { + "dag_id": "c", + "run_id": "b", + "logical_date": "2024-12-01T01:00:00Z", + "data_interval_start": "2024-12-01T00:00:00Z", + "data_interval_end": "2024-12-01T01:00:00Z", + "start_date": "2024-12-01T01:00:00Z", + "run_after": "2024-12-01T01:00:00Z", + "end_date": None, + "run_type": "manual", + "conf": None, + "consumed_asset_events": [], + }, + "max_tries": 0, + "should_retry": False, + "variables": None, + "connections": None, + }, + "file": "/dev/null", + "start_date": "2024-12-01T01:00:00Z", + "dag_rel_path": "/dev/null", + "bundle_info": {"name": "any-name", "version": "any-version"}, + } + bytes = msgspec.msgpack.encode(_ResponseFrame(0, msg, None)) + w.sendall(len(bytes).to_bytes(4, byteorder="big") + bytes) + + decoder = CommsDecoder(request_socket=r, log=None) + + msg = decoder._get_response() + assert isinstance(msg, StartupDetails) + assert msg.ti.id == uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab") + assert msg.ti.task_id == "a" + assert msg.ti.dag_id == "c" + assert msg.dag_rel_path == "/dev/null" + assert msg.bundle_info == BundleInfo(name="any-name", version="any-version") + assert msg.start_date == timezone.datetime(2024, 
12, 1, 1) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 4f5e4cfc7ac..3fa05e7d757 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -27,13 +27,14 @@ import signal import socket import sys import time -from io import BytesIO from operator import attrgetter +from random import randint from time import sleep from typing import TYPE_CHECKING from unittest.mock import MagicMock, patch import httpx +import msgspec import psutil import pytest from pytest_unordered import unordered @@ -56,6 +57,7 @@ from airflow.sdk.execution_time import task_runner from airflow.sdk.execution_time.comms import ( AssetEventsResult, AssetResult, + CommsDecoder, ConnectionResult, DagRunStateResult, DeferTask, @@ -97,17 +99,16 @@ from airflow.sdk.execution_time.comms import ( XComResult, XComSequenceIndexResult, XComSequenceSliceResult, + _RequestFrame, + _ResponseFrame, ) from airflow.sdk.execution_time.supervisor import ( - BUFFER_SIZE, ActivitySubprocess, InProcessSupervisorComms, InProcessTestSupervisor, - mkpipe, set_supervisor_comms, supervise, ) -from airflow.sdk.execution_time.task_runner import CommsDecoder from airflow.utils import timezone, timezone as tz if TYPE_CHECKING: @@ -136,18 +137,25 @@ def local_dag_bundle_cfg(path, name="my-bundle"): } +@pytest.fixture +def client_with_ti_start(make_ti_context): + client = MagicMock(spec=sdk_client.Client) + client.task_instances.start.return_value = make_ti_context() + return client + + @pytest.mark.usefixtures("disable_capturing") class TestWatchedSubprocess: @pytest.fixture(autouse=True) def disable_log_upload(self, spy_agency): spy_agency.spy_on(ActivitySubprocess._upload_logs, call_original=False) - def test_reading_from_pipes(self, captured_logs, time_machine): + def test_reading_from_pipes(self, captured_logs, time_machine, client_with_ti_start): def subprocess_main(): # This is run in the subprocess! 
- # Ensure we follow the "protocol" and get the startup message before we do anything - sys.stdin.readline() + # Ensure we follow the "protocol" and get the startup message before we do anything else + CommsDecoder()._get_response() import logging import warnings @@ -180,7 +188,7 @@ class TestWatchedSubprocess: run_id="d", try_number=1, ), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -228,12 +236,12 @@ class TestWatchedSubprocess: ] ) - def test_subprocess_sigkilled(self): + def test_subprocess_sigkilled(self, client_with_ti_start): main_pid = os.getpid() def subprocess_main(): # Ensure we follow the "protocol" and get the startup message before we do anything - sys.stdin.readline() + CommsDecoder()._get_response() assert os.getpid() != main_pid os.kill(os.getpid(), signal.SIGKILL) @@ -248,7 +256,7 @@ class TestWatchedSubprocess: run_id="d", try_number=1, ), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -285,7 +293,7 @@ class TestWatchedSubprocess: monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1) def subprocess_main(): - sys.stdin.readline() + CommsDecoder()._get_response() for _ in range(5): print("output", flush=True) @@ -314,7 +322,7 @@ class TestWatchedSubprocess: monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1) def subprocess_main(): - sys.stdin.readline() + CommsDecoder()._get_response() for _ in range(5): print("output", flush=True) @@ -340,7 +348,7 @@ class TestWatchedSubprocess: assert proc.wait() == 0 spy_agency.assert_spy_not_called(heartbeat_spy) - def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker, make_ti_context): + def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker, client_with_ti_start): """Test running a simple DAG in a subprocess and capturing the output.""" instant = tz.datetime(2024, 11, 7, 12, 34, 56, 78901) @@ -355,11 +363,6 @@ class TestWatchedSubprocess: try_number=1, ) - # Create a mock client to assert calls to the client - # We assume the implementation of the client is correct and only need to check the calls - mock_client = mocker.Mock(spec=sdk_client.Client) - mock_client.task_instances.start.return_value = make_ti_context() - bundle_info = BundleInfo(name="my-bundle", version=None) with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)): exit_code = supervise( @@ -368,7 +371,7 @@ class TestWatchedSubprocess: token="", server="", dry_run=True, - client=mock_client, + client=client_with_ti_start, bundle_info=bundle_info, ) assert exit_code == 0, captured_logs @@ -498,7 +501,7 @@ class TestWatchedSubprocess: monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.0) def subprocess_main(): - sys.stdin.readline() + CommsDecoder()._get_response() sleep(5) # Shouldn't get here exit(5) @@ -611,7 +614,6 @@ class TestWatchedSubprocess: stdin=mocker.MagicMock(), client=client, process=mock_process, - requests_fd=-1, ) time_now = tz.datetime(2024, 11, 28, 12, 0, 0) @@ -701,7 +703,6 @@ class TestWatchedSubprocess: stdin=mocker.Mock(), process=mocker.Mock(), client=mocker.Mock(), - requests_fd=-1, ) # Set the terminal state and task end datetime @@ -738,7 +739,7 @@ class TestWatchedSubprocess: ), ), ) - def test_exit_by_signal(self, monkeypatch, signal_to_raise, log_pattern, cap_structlog): + def test_exit_by_signal(self, signal_to_raise, log_pattern, 
cap_structlog, client_with_ti_start): def subprocess_main(): import faulthandler import os @@ -748,7 +749,7 @@ class TestWatchedSubprocess: faulthandler.disable() # Ensure we follow the "protocol" and get the startup message before we do anything - sys.stdin.readline() + CommsDecoder()._get_response() os.kill(os.getpid(), signal_to_raise) @@ -762,7 +763,7 @@ class TestWatchedSubprocess: run_id="d", try_number=1, ), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -791,26 +792,26 @@ class TestWatchedSubprocess: stdin=mocker.MagicMock(), client=mocker.MagicMock(), process=mock_process, - requests_fd=-1, ) proc.selector = mocker.MagicMock() proc.selector.select.return_value = [] proc._exit_code = 0 - proc._num_open_sockets = 1 + # Create a dummy placeholder in the open socket weakref + proc._open_sockets[mocker.MagicMock()] = "test placeholder" proc._process_exit_monotonic = time.monotonic() mocker.patch.object( ActivitySubprocess, "_cleanup_open_sockets", - side_effect=lambda: setattr(proc, "_num_open_sockets", 0), + side_effect=lambda: setattr(proc, "_open_sockets", {}), ) time_machine.shift(2) proc._monitor_subprocess() - assert proc._num_open_sockets == 0 + assert len(proc._open_sockets) == 0 class TestWatchedSubprocessKill: @@ -829,7 +830,6 @@ class TestWatchedSubprocessKill: stdin=mocker.Mock(), client=mocker.Mock(), process=mock_process, - requests_fd=-1, ) # Mock the selector mock_selector = mocker.Mock(spec=selectors.DefaultSelector) @@ -888,7 +888,7 @@ class TestWatchedSubprocessKill: ), ], ) - def test_kill_escalation_path(self, signal_to_send, exit_after, mocker, captured_logs, monkeypatch): + def test_kill_escalation_path(self, signal_to_send, exit_after, captured_logs, client_with_ti_start): def subprocess_main(): import signal @@ -905,7 +905,7 @@ class TestWatchedSubprocessKill: signal.signal(signal.SIGINT, _handler) signal.signal(signal.SIGTERM, _handler) try: - sys.stdin.readline() + CommsDecoder()._get_response() print("Ready") sleep(10) except Exception as e: @@ -919,7 +919,7 @@ class TestWatchedSubprocessKill: dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1), - client=MagicMock(spec=sdk_client.Client), + client=client_with_ti_start, target=subprocess_main, ) @@ -1073,16 +1073,15 @@ class TestWatchedSubprocessKill: class TestHandleRequest: @pytest.fixture def watched_subprocess(self, mocker): - read_end, write_end = mkpipe(remote_read=True) + read_end, write_end = socket.socketpair() subprocess = ActivitySubprocess( process_log=mocker.MagicMock(), id=TI_ID, pid=12345, - stdin=write_end, # this is the writer side + stdin=write_end, client=mocker.Mock(), process=mocker.Mock(), - requests_fd=-1, ) return subprocess, read_end @@ -1091,7 +1090,7 @@ class TestHandleRequest: @pytest.mark.parametrize( [ "message", - "expected_buffer", + "expected_body", "client_attr_path", "method_arg", "method_kwarg", @@ -1101,7 +1100,7 @@ class TestHandleRequest: [ pytest.param( GetConnection(conn_id="test_conn"), - b'{"conn_id":"test_conn","conn_type":"mysql","type":"ConnectionResult"}\n', + {"conn_id": "test_conn", "conn_type": "mysql", "type": "ConnectionResult"}, "connections.get", ("test_conn",), {}, @@ -1111,7 +1110,12 @@ class TestHandleRequest: ), pytest.param( GetConnection(conn_id="test_conn"), - b'{"conn_id":"test_conn","conn_type":"mysql","password":"password","type":"ConnectionResult"}\n', + { + "conn_id": "test_conn", + "conn_type": "mysql", + 
"password": "password", + "type": "ConnectionResult", + }, "connections.get", ("test_conn",), {}, @@ -1121,7 +1125,7 @@ class TestHandleRequest: ), pytest.param( GetConnection(conn_id="test_conn"), - b'{"conn_id":"test_conn","conn_type":"mysql","schema":"mysql","type":"ConnectionResult"}\n', + {"conn_id": "test_conn", "conn_type": "mysql", "schema": "mysql", "type": "ConnectionResult"}, "connections.get", ("test_conn",), {}, @@ -1131,7 +1135,7 @@ class TestHandleRequest: ), pytest.param( GetVariable(key="test_key"), - b'{"key":"test_key","value":"test_value","type":"VariableResult"}\n', + {"key": "test_key", "value": "test_value", "type": "VariableResult"}, "variables.get", ("test_key",), {}, @@ -1141,7 +1145,7 @@ class TestHandleRequest: ), pytest.param( PutVariable(key="test_key", value="test_value", description="test_description"), - b"", + None, "variables.set", ("test_key", "test_value", "test_description"), {}, @@ -1151,7 +1155,7 @@ class TestHandleRequest: ), pytest.param( DeleteVariable(key="test_key"), - b'{"ok":true,"type":"OKResponse"}\n', + {"ok": True, "type": "OKResponse"}, "variables.delete", ("test_key",), {}, @@ -1161,7 +1165,7 @@ class TestHandleRequest: ), pytest.param( DeferTask(next_method="execute_callback", classpath="my-classpath"), - b"", + None, "task_instances.defer", (TI_ID, DeferTask(next_method="execute_callback", classpath="my-classpath")), {}, @@ -1174,7 +1178,7 @@ class TestHandleRequest: reschedule_date=timezone.parse("2024-10-31T12:00:00Z"), end_date=timezone.parse("2024-10-31T12:00:00Z"), ), - b"", + None, "task_instances.reschedule", ( TI_ID, @@ -1190,7 +1194,7 @@ class TestHandleRequest: ), pytest.param( GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), - b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', + {"key": "test_key", "value": "test_value", "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None, False), {}, @@ -1202,7 +1206,7 @@ class TestHandleRequest: GetXCom( dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key", map_index=2 ), - b'{"key":"test_key","value":"test_value","type":"XComResult"}\n', + {"key": "test_key", "value": "test_value", "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", 2, False), {}, @@ -1212,7 +1216,7 @@ class TestHandleRequest: ), pytest.param( GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), - b'{"key":"test_key","value":null,"type":"XComResult"}\n', + {"key": "test_key", "value": None, "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None, False), {}, @@ -1228,7 +1232,7 @@ class TestHandleRequest: key="test_key", include_prior_dates=True, ), - b'{"key":"test_key","value":null,"type":"XComResult"}\n', + {"key": "test_key", "value": None, "type": "XComResult"}, "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", None, True), {}, @@ -1244,7 +1248,7 @@ class TestHandleRequest: key="test_key", value='{"key": "test_key", "value": {"key2": "value2"}}', ), - b"", + None, "xcoms.set", ( "test_dag", @@ -1269,7 +1273,7 @@ class TestHandleRequest: value='{"key": "test_key", "value": {"key2": "value2"}}', map_index=2, ), - b"", + None, "xcoms.set", ( "test_dag", @@ -1295,7 +1299,7 @@ class TestHandleRequest: map_index=2, mapped_length=3, ), - b"", + None, "xcoms.set", ( "test_dag", @@ -1319,15 +1323,9 @@ class TestHandleRequest: key="test_key", map_index=2, ), - b"", + None, "xcoms.delete", - ( - "test_dag", - 
"test_run", - "test_task", - "test_key", - 2, - ), + ("test_dag", "test_run", "test_task", "test_key", 2), {}, OKResponse(ok=True), None, @@ -1337,7 +1335,7 @@ class TestHandleRequest: # if it can handle TaskState message pytest.param( TaskState(state=TaskInstanceState.SKIPPED, end_date=timezone.parse("2024-10-31T12:00:00Z")), - b"", + None, "", (), {}, @@ -1349,7 +1347,7 @@ class TestHandleRequest: RetryTask( end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test retry task" ), - b"", + None, "task_instances.retry", (), { @@ -1363,7 +1361,7 @@ class TestHandleRequest: ), pytest.param( SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}), - b"", + None, "task_instances.set_rtif", (TI_ID, {"field1": "rendered_value1", "field2": "rendered_value2"}), {}, @@ -1373,7 +1371,7 @@ class TestHandleRequest: ), pytest.param( GetAssetByName(name="asset"), - b'{"name":"asset","uri":"s3://bucket/obj","group":"asset","type":"AssetResult"}\n', + {"name": "asset", "uri": "s3://bucket/obj", "group": "asset", "type": "AssetResult"}, "assets.get", [], {"name": "asset"}, @@ -1383,7 +1381,7 @@ class TestHandleRequest: ), pytest.param( GetAssetByUri(uri="s3://bucket/obj"), - b'{"name":"asset","uri":"s3://bucket/obj","group":"asset","type":"AssetResult"}\n', + {"name": "asset", "uri": "s3://bucket/obj", "group": "asset", "type": "AssetResult"}, "assets.get", [], {"uri": "s3://bucket/obj"}, @@ -1393,11 +1391,17 @@ class TestHandleRequest: ), pytest.param( GetAssetEventByAsset(uri="s3://bucket/obj", name="test"), - ( - b'{"asset_events":' - b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"uri": "s3://bucket/obj", "name": "test"}, @@ -1416,11 +1420,17 @@ class TestHandleRequest: ), pytest.param( GetAssetEventByAsset(uri="s3://bucket/obj", name=None), - ( - b'{"asset_events":' - b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"uri": "s3://bucket/obj", "name": None}, @@ -1439,11 +1449,17 @@ class TestHandleRequest: ), pytest.param( GetAssetEventByAsset(uri=None, name="test"), - ( - b'{"asset_events":' - b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"uri": None, "name": "test"}, @@ -1462,11 +1478,17 @@ class TestHandleRequest: ), pytest.param( GetAssetEventByAssetAlias(alias_name="test_alias"), - ( - b'{"asset_events":' - 
b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},' - b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n' - ), + { + "asset_events": [ + { + "id": 1, + "timestamp": timezone.parse("2024-10-31T12:00:00Z"), + "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"}, + "created_dagruns": [], + } + ], + "type": "AssetEventsResult", + }, "asset_events.get", [], {"alias_name": "test_alias"}, @@ -1485,7 +1507,10 @@ class TestHandleRequest: ), pytest.param( ValidateInletsAndOutlets(ti_id=TI_ID), - b'{"inactive_assets":[{"name":"asset_name","uri":"asset_uri","type":"asset"}],"type":"InactiveAssetsResult"}\n', + { + "inactive_assets": [{"name": "asset_name", "uri": "asset_uri", "type": "asset"}], + "type": "InactiveAssetsResult", + }, "task_instances.validate_inlets_and_outlets", (TI_ID,), {}, @@ -1499,7 +1524,7 @@ class TestHandleRequest: SucceedTask( end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test success task" ), - b"", + None, "task_instances.succeed", (), { @@ -1515,11 +1540,13 @@ class TestHandleRequest: ), pytest.param( GetPrevSuccessfulDagRun(ti_id=TI_ID), - ( - b'{"data_interval_start":"2025-01-10T12:00:00Z","data_interval_end":"2025-01-10T14:00:00Z",' - b'"start_date":"2025-01-10T12:00:00Z","end_date":"2025-01-10T14:00:00Z",' - b'"type":"PrevSuccessfulDagRunResult"}\n' - ), + { + "data_interval_start": timezone.parse("2025-01-10T12:00:00Z"), + "data_interval_end": timezone.parse("2025-01-10T14:00:00Z"), + "start_date": timezone.parse("2025-01-10T12:00:00Z"), + "end_date": timezone.parse("2025-01-10T14:00:00Z"), + "type": "PrevSuccessfulDagRunResult", + }, "task_instances.get_previous_successful_dagrun", (TI_ID,), {}, @@ -1540,7 +1567,7 @@ class TestHandleRequest: logical_date=timezone.datetime(2025, 1, 1), reset_dag_run=True, ), - b'{"ok":true,"type":"OKResponse"}\n', + {"ok": True, "type": "OKResponse"}, "dag_runs.trigger", ("test_dag", "test_run", {"key": "value"}, timezone.datetime(2025, 1, 1), True), {}, @@ -1549,8 +1576,9 @@ class TestHandleRequest: id="dag_run_trigger", ), pytest.param( + # TODO: This should be raise an exception, not returning an ErrorResponse. 
Fix this before PR TriggerDagRun(dag_id="test_dag", run_id="test_run"), - b'{"error":"DAGRUN_ALREADY_EXISTS","detail":null,"type":"ErrorResponse"}\n', + {"error": "DAGRUN_ALREADY_EXISTS", "detail": None, "type": "ErrorResponse"}, "dag_runs.trigger", ("test_dag", "test_run", None, None, False), {}, @@ -1560,7 +1588,7 @@ class TestHandleRequest: ), pytest.param( GetDagRunState(dag_id="test_dag", run_id="test_run"), - b'{"state":"running","type":"DagRunStateResult"}\n', + {"state": "running", "type": "DagRunStateResult"}, "dag_runs.get_state", ("test_dag", "test_run"), {}, @@ -1570,7 +1598,7 @@ class TestHandleRequest: ), pytest.param( GetTaskRescheduleStartDate(ti_id=TI_ID), - b'{"start_date":"2024-10-31T12:00:00Z","type":"TaskRescheduleStartDate"}\n', + {"start_date": timezone.parse("2024-10-31T12:00:00Z"), "type": "TaskRescheduleStartDate"}, "task_instances.get_reschedule_start_date", (TI_ID, 1), {}, @@ -1580,7 +1608,7 @@ class TestHandleRequest: ), pytest.param( GetTICount(dag_id="test_dag", task_ids=["task1", "task2"]), - b'{"count":2,"type":"TICount"}\n', + {"count": 2, "type": "TICount"}, "task_instances.get_count", (), { @@ -1598,7 +1626,7 @@ class TestHandleRequest: ), pytest.param( GetDRCount(dag_id="test_dag", states=["success", "failed"]), - b'{"count":2,"type":"DRCount"}\n', + {"count": 2, "type": "DRCount"}, "dag_runs.get_count", (), { @@ -1613,7 +1641,7 @@ class TestHandleRequest: ), pytest.param( GetTaskStates(dag_id="test_dag", task_group_id="test_group"), - b'{"task_states":{"run_id":{"task1":"success","task2":"failed"}},"type":"TaskStatesResult"}\n', + { + "task_states": {"run_id": {"task1": "success", "task2": "failed"}}, + "type": "TaskStatesResult", + }, "task_instances.get_task_states", (), { @@ -1636,7 +1667,7 @@ class TestHandleRequest: task_id="test_task", offset=0, ), - b'{"root":"test_value","type":"XComSequenceIndexResult"}\n', + {"root": "test_value", "type": "XComSequenceIndexResult"}, "xcoms.get_sequence_item", ("test_dag", "test_run", "test_task", "test_key", 0), {}, @@ -1645,6 +1676,7 @@ class TestHandleRequest: id="get_xcom_seq_item", ), pytest.param( + # TODO: This should raise an exception, not return an ErrorResponse. 
Fix this before PR GetXComSequenceItem( key="test_key", dag_id="test_dag", @@ -1652,7 +1684,7 @@ class TestHandleRequest: task_id="test_task", offset=2, ), - b'{"error":"XCOM_NOT_FOUND","detail":null,"type":"ErrorResponse"}\n', + {"error": "XCOM_NOT_FOUND", "detail": None, "type": "ErrorResponse"}, "xcoms.get_sequence_item", ("test_dag", "test_run", "test_task", "test_key", 2), {}, @@ -1670,7 +1702,7 @@ class TestHandleRequest: stop=None, step=None, ), - b'{"root":["foo","bar"],"type":"XComSequenceSliceResult"}\n', + {"root": ["foo", "bar"], "type": "XComSequenceSliceResult"}, "xcoms.get_sequence_slice", ("test_dag", "test_run", "test_task", "test_key", None, None, None), {}, @@ -1687,7 +1719,7 @@ class TestHandleRequest: mocker, time_machine, message, - expected_buffer, + expected_body, client_attr_path, method_arg, method_kwarg, @@ -1715,8 +1747,9 @@ class TestHandleRequest: generator = watched_subprocess.handle_requests(log=mocker.Mock()) # Initialize the generator next(generator) - msg = message.model_dump_json().encode() + b"\n" - generator.send(msg) + + req_frame = _RequestFrame(id=randint(1, 2**32 - 1), body=message.model_dump()) + generator.send(req_frame) if mask_secret_args: mock_mask_secret.assert_called_with(*mask_secret_args) @@ -1729,33 +1762,23 @@ class TestHandleRequest: # Read response from the read end of the socket read_socket.settimeout(0.1) - val = b"" - try: - while not val.endswith(b"\n"): - chunk = read_socket.recv(BUFFER_SIZE) - if not chunk: - break - val += chunk - except (BlockingIOError, TimeoutError, socket.timeout): - # no response written, valid for some message types like setters and TI operations. - pass + frame_len = int.from_bytes(read_socket.recv(4), "big") + bytes = read_socket.recv(frame_len) + frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(bytes) + + assert frame.id == req_frame.id # Verify the response was added to the buffer - assert val == expected_buffer + assert frame.body == expected_body # Verify the response is correctly decoded # This is important because the subprocess/task runner will read the response # and deserialize it to the correct message type - # Only decode the buffer if it contains data. An empty buffer implies no response was written. - if not val and (mock_response and not isinstance(mock_response, OKResponse)): - pytest.fail("Expected a response, but got an empty buffer.") - - if val: + if frame.body is not None: # Using BytesIO to simulate a readable stream for CommsDecoder. 
- input_stream = BytesIO(val) - decoder = CommsDecoder(input=input_stream) - assert decoder.get_message() == mock_response + decoder = CommsDecoder(request_socket=None).body_decoder + assert decoder.validate_python(frame.body) == mock_response def test_handle_requests_api_server_error(self, watched_subprocess, mocker): """Test that API server errors are properly handled and sent back to the task.""" @@ -1777,28 +1800,32 @@ class TestHandleRequest: next(generator) - msg = SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")).model_dump_json().encode() + b"\n" - generator.send(msg) + msg = SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")) + req_frame = _RequestFrame(id=randint(1, 2**32 - 1), body=msg) + generator.send(req_frame) - # Read response directly from the reader socket + # Read response from the read end of the socket read_socket.settimeout(0.1) - val = b"" - try: - while not val.endswith(b"\n"): - val += read_socket.recv(4096) - except (BlockingIOError, TimeoutError): - pass - - assert val == ( - b'{"error":"API_SERVER_ERROR","detail":{"status_code":500,"message":"API Server Error",' - b'"detail":{"detail":"Internal Server Error"}},"type":"ErrorResponse"}\n' - ) + frame_len = int.from_bytes(read_socket.recv(4), "big") + bytes = read_socket.recv(frame_len) + frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(bytes) + + assert frame.id == req_frame.id + + assert frame.error == { + "error": "API_SERVER_ERROR", + "detail": { + "status_code": 500, + "message": "API Server Error", + "detail": {"detail": "Internal Server Error"}, + }, + "type": "ErrorResponse", + } # Verify the error can be decoded correctly - input_stream = BytesIO(val) - decoder = CommsDecoder(input=input_stream) + comms = CommsDecoder(request_socket=None) with pytest.raises(AirflowRuntimeError) as exc_info: - decoder.get_message() + comms._from_frame(frame) assert exc_info.value.error.error == ErrorType.API_SERVER_ERROR assert exc_info.value.error.detail == { @@ -1871,7 +1898,6 @@ class TestInProcessTestSupervisor: supervisor = MinimalSupervisor( id="test", pid=123, - requests_fd=-1, process=MagicMock(), process_log=MagicMock(), client=MagicMock(), diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index a994e995bac..ab91dbcf14a 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -22,11 +22,9 @@ import functools import json import os import textwrap -import uuid from collections.abc import Iterable from datetime import datetime, timedelta from pathlib import Path -from socket import socketpair from typing import TYPE_CHECKING from unittest import mock from unittest.mock import patch @@ -101,7 +99,6 @@ from airflow.sdk.execution_time.context import ( VariableAccessor, ) from airflow.sdk.execution_time.task_runner import ( - CommsDecoder, RuntimeTaskInstance, TaskRunnerMarker, _push_xcom_if_needed, @@ -137,47 +134,6 @@ class CustomOperator(BaseOperator): print(f"Hello World {task_id}!") -class TestCommsDecoder: - """Test the communication between the subprocess and the "supervisor".""" - - @pytest.mark.usefixtures("disable_capturing") - def test_recv_StartupDetails(self): - r, w = socketpair() - # Create a valid FD for the decoder to open - _, w2 = socketpair() - - w.makefile("wb").write( - b'{"type":"StartupDetails", "ti": {' - b'"id": "4d828a62-a417-4936-a7a6-2b3fabacecab", "task_id": "a", "try_number": 1, "run_id": "b", ' - 
b'"dag_id": "c"}, "ti_context":{"dag_run":{"dag_id":"c","run_id":"b",' - b'"logical_date":"2024-12-01T01:00:00Z",' - b'"data_interval_start":"2024-12-01T00:00:00Z","data_interval_end":"2024-12-01T01:00:00Z",' - b'"start_date":"2024-12-01T01:00:00Z","run_after":"2024-12-01T01:00:00Z","end_date":null,' - b'"run_type":"manual","conf":null,"consumed_asset_events":[]},' - b'"max_tries":0,"should_retry":false,"variables":null,"connections":null},"file": "/dev/null",' - b'"start_date":"2024-12-01T01:00:00Z", "dag_rel_path": "/dev/null", "bundle_info": {"name": ' - b'"any-name", "version": "any-version"}, "requests_fd": ' - + str(w2.fileno()).encode("ascii") - + b"}\n" - ) - - decoder = CommsDecoder(input=r.makefile("r")) - - msg = decoder.get_message() - assert isinstance(msg, StartupDetails) - assert msg.ti.id == uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab") - assert msg.ti.task_id == "a" - assert msg.ti.dag_id == "c" - assert msg.dag_rel_path == "/dev/null" - assert msg.bundle_info == BundleInfo(name="any-name", version="any-version") - assert msg.start_date == timezone.datetime(2024, 12, 1, 1) - - # Since this was a StartupDetails message, the decoder should open the other socket - assert decoder.request_socket is not None - assert decoder.request_socket.writable() - assert decoder.request_socket.fileno() == w2.fileno() - - def test_parse(test_dags_dir: Path, make_ti_context): """Test that checks parsing of a basic dag with an un-mocked parse.""" what = StartupDetails( @@ -190,7 +146,6 @@ def test_parse(test_dags_dir: Path, make_ti_context): ), dag_rel_path="super_basic.py", bundle_info=BundleInfo(name="my-bundle", version=None), - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -246,7 +201,6 @@ def test_parse_not_found(test_dags_dir: Path, make_ti_context, dag_id, task_id, ), dag_rel_path="super_basic.py", bundle_info=BundleInfo(name="my-bundle", version=None), - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -300,7 +254,6 @@ def test_parse_module_in_bundle_root(tmp_path: Path, make_ti_context): ), dag_rel_path="path_test.py", bundle_info=BundleInfo(name="my-bundle", version=None), - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -571,7 +524,6 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm ), bundle_info=FAKE_BUNDLE, dag_rel_path="", - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -687,7 +639,6 @@ def test_startup_and_run_dag_with_rtif( ), dag_rel_path="", bundle_info=FAKE_BUNDLE, - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), ) @@ -832,7 +783,6 @@ def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1), dag_rel_path="dag_parsing_context.py", bundle_info=BundleInfo(name="my-bundle", version=None), - requests_fd=0, ti_context=make_ti_context(dag_id=dag_id, run_id="c"), start_date=timezone.utcnow(), ) @@ -2190,7 +2140,6 @@ class TestTaskRunnerCallsListeners: ), dag_rel_path="", bundle_info=FAKE_BUNDLE, - requests_fd=0, ti_context=make_ti_context(), start_date=timezone.utcnow(), )