kaxil commented on code in PR #51699:
URL: https://github.com/apache/airflow/pull/51699#discussion_r2150190057
##########
airflow-core/src/airflow/dag_processing/processor.py:
##########
@@ -301,18 +296,24 @@ def _handle_request(self, msg: ToManager, log:
FilteringBoundLogger) -> None: #
resp = self.client.variables.delete(msg.key)
else:
log.error("Unhandled request", msg=msg)
+ self.send_msg(
+ None,
+ in_response_to=req_id,
Review Comment:
Should we name this argument `request_id` instead of `in_response_to`?
##########
airflow-core/src/airflow/dag_processing/processor.py:
##########
@@ -102,18 +101,16 @@ class DagFileParsingResult(BaseModel):
def _parse_file_entrypoint():
import structlog
- from airflow.sdk.execution_time import task_runner
+ from airflow.sdk.execution_time import comms, task_runner
# Parse DAG file, send JSON back up!
- comms_decoder = task_runner.CommsDecoder[ToDagProcessor, ToManager](
- input=sys.stdin,
- decoder=TypeAdapter[ToDagProcessor](ToDagProcessor),
+ comms_decoder = comms.CommsDecoder[ToDagProcessor, ToManager](
+ body_decoder=TypeAdapter[ToDagProcessor](ToDagProcessor),
)
- msg = comms_decoder.get_message()
+ msg = comms_decoder._get_response()
Review Comment:
Why not `msg = comms_decoder.send()`?
##########
airflow-core/src/airflow/jobs/triggerer_job_runner.py:
##########
@@ -691,14 +690,58 @@ class TriggerDetails(TypedDict):
events: int
[email protected](kw_only=True)
+class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]):
+ _async_writer: asyncio.StreamWriter = attrs.field(alias="async_writer")
+ _async_reader: asyncio.StreamReader = attrs.field(alias="async_reader")
+
+ body_decoder: TypeAdapter[ToTriggerRunner] = attrs.field(
+ factory=lambda: TypeAdapter(ToTriggerRunner), repr=False
+ )
+
+ _lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False)
+
+ def _read_frame(self):
+ from asgiref.sync import async_to_sync
+
+ return async_to_sync(self._aread_frame)()
+
+ def send(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None:
+ from asgiref.sync import async_to_sync
+
+ return async_to_sync(self.asend)(msg)
+
+ async def _aread_frame(self):
+ len_bytes = await self._async_reader.readexactly(4)
+ len = int.from_bytes(len_bytes, byteorder="big")
+
+ buffer = await self._async_reader.readexactly(len)
+ return self.resp_decoder.decode(buffer)
+
+ async def _aget_response(self, expect_id: int) -> ToTriggerRunner | None:
+ frame = await self._aread_frame()
+ if frame.id != expect_id:
+ # Given the lock we take out in `asend`, this _shouldn't_ be
possible, but I'd rather fail with
+ # this explicit error return the wrong type of message back to a
Trigger
+ raise RuntimeError(f"Response read out of order! Got {frame.id=},
{expect_id=}")
+ return self._from_frame(frame)
+
+ async def asend(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None:
+ frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
+ bytes = frame.as_bytes()
+
+ async with self._lock:
+ self._async_writer.write(bytes)
Review Comment:
I think we also need `await self._async_writer.drain()` after this line, based
on
https://docs.python.org/3/library/asyncio-stream.html#asyncio.StreamWriter.write
##########
airflow-core/src/airflow/dag_processing/processor.py:
##########
@@ -266,20 +263,18 @@ def _on_child_started(
msg = DagFileParseRequest(
file=os.fspath(path),
bundle_path=bundle_path,
- requests_fd=self._requests_fd,
callback_requests=callbacks,
)
- self.send_msg(msg)
+ self.send_msg(msg, in_response_to=0)
Review Comment:
Is `in_response_to` auto-incrementing, or can we provide any request ID (like
`0`) here?
##########
devel-common/src/tests_common/pytest_plugin.py:
##########
@@ -1956,17 +1956,20 @@ def override_caplog(request):
@pytest.fixture
-def mock_supervisor_comms():
+def mock_supervisor_comms(monkeypatch):
# for back-compat
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
if not AIRFLOW_V_3_0_PLUS:
yield None
return
- with mock.patch(
- "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
- ) as supervisor_comms:
- yield supervisor_comms
+
+ import airflow.sdk.execution_time.task_runner
+ from airflow.sdk.execution_time.comms import CommsDecoder
+
+ comms = mock.create_autospec(CommsDecoder)
+ monkeypatch.setattr(airflow.sdk.execution_time.task_runner,
"SUPERVISOR_COMMS", comms, raising=False)
Review Comment:
```suggestion
from airflow.sdk.execution_time import task_runner
from airflow.sdk.execution_time.comms import CommsDecoder
comms = mock.create_autospec(CommsDecoder)
monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms,
raising=False)
```
##########
task-sdk/src/airflow/sdk/execution_time/comms.py:
##########
@@ -80,20 +90,152 @@
)
from airflow.sdk.exceptions import ErrorType
+if TYPE_CHECKING:
+ from structlog.typing import FilteringBoundLogger as Logger
+
+SendMsgType = TypeVar("SendMsgType", bound=BaseModel)
+ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel)
+
+
+def _msgpack_enc_hook(obj: Any) -> Any:
+ import pendulum
+
+ if isinstance(obj, pendulum.DateTime):
+ # convert the complex to a tuple of real, imag
+ return datetime(
+ obj.year, obj.month, obj.day, obj.hour, obj.minute, obj.second,
obj.microsecond, tzinfo=obj.tzinfo
+ )
+ if isinstance(obj, Path):
+ return str(obj)
+ if isinstance(obj, BaseModel):
+ return obj.model_dump(exclude_unset=True)
+
+ # Raise a NotImplementedError for other types
+ raise NotImplementedError(f"Objects of type {type(obj)} are not supported")
+
+
+def _new_encoder() -> msgspec.msgpack.Encoder:
+ return msgspec.msgpack.Encoder(enc_hook=_msgpack_enc_hook)
+
+
+class _RequestFrame(msgspec.Struct, array_like=True, frozen=True,
omit_defaults=True):
+ id: int
+ """
+ The request id, set by the sender.
+
+ This is used to allow "pipeling" of requests and to be able to tie
response to requests, which is
+ particularly useful in the Triggerer where multiple async tasks can send a
requests concurrently.
+ """
+ body: dict[str, Any] | None
+
+ req_encoder: ClassVar[msgspec.msgpack.Encoder] = _new_encoder()
+
+ def as_bytes(self) -> bytearray:
+ # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing
for inspiration
+ buffer = bytearray(256)
+
+ self.req_encoder.encode_into(self, buffer, 4)
+
+ n = len(buffer) - 4
+ if n > 2**32:
Review Comment:
Shouldn't this be `>=`?
##########
task-sdk/src/airflow/sdk/execution_time/comms.py:
##########
@@ -80,20 +90,152 @@
)
from airflow.sdk.exceptions import ErrorType
+if TYPE_CHECKING:
+ from structlog.typing import FilteringBoundLogger as Logger
+
+SendMsgType = TypeVar("SendMsgType", bound=BaseModel)
+ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel)
+
+
+def _msgpack_enc_hook(obj: Any) -> Any:
+ import pendulum
+
+ if isinstance(obj, pendulum.DateTime):
+ # convert the complex to a tuple of real, imag
+ return datetime(
+ obj.year, obj.month, obj.day, obj.hour, obj.minute, obj.second,
obj.microsecond, tzinfo=obj.tzinfo
+ )
+ if isinstance(obj, Path):
+ return str(obj)
+ if isinstance(obj, BaseModel):
+ return obj.model_dump(exclude_unset=True)
+
+ # Raise a NotImplementedError for other types
+ raise NotImplementedError(f"Objects of type {type(obj)} are not supported")
+
+
+def _new_encoder() -> msgspec.msgpack.Encoder:
+ return msgspec.msgpack.Encoder(enc_hook=_msgpack_enc_hook)
+
+
+class _RequestFrame(msgspec.Struct, array_like=True, frozen=True,
omit_defaults=True):
+ id: int
+ """
+ The request id, set by the sender.
+
+ This is used to allow "pipeling" of requests and to be able to tie
response to requests, which is
+ particularly useful in the Triggerer where multiple async tasks can send a
requests concurrently.
+ """
+ body: dict[str, Any] | None
+
+ req_encoder: ClassVar[msgspec.msgpack.Encoder] = _new_encoder()
+
+ def as_bytes(self) -> bytearray:
+ # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing
for inspiration
+ buffer = bytearray(256)
+
+ self.req_encoder.encode_into(self, buffer, 4)
+
+ n = len(buffer) - 4
+ if n > 2**32:
+ raise OverflowError("Cannot send messages larger than 4GiB")
+ buffer[:4] = n.to_bytes(4, byteorder="big")
+
+ return buffer
+
+
+class _ResponseFrame(_RequestFrame, msgspec.Struct, array_like=True,
frozen=True, omit_defaults=True):
+ id: int
+ """
+ The id of the request this is a response to
+ """
+ body: dict[str, Any] | None = None
+ error: dict[str, Any] | None = None
+
+
[email protected]()
+class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
+ """Handle communication between the task in this process and the
supervisor parent process."""
+
+ log: Logger = attrs.field(repr=False, factory=structlog.get_logger)
+ request_socket: socket = attrs.field(factory=lambda: socket(fileno=0))
Review Comment:
This socket now carries both requests and responses. I don't have a good suggestion for
the name (`comms_socket`, maybe?), so it's probably fine to keep it 🤷
##########
task-sdk/src/airflow/sdk/execution_time/comms.py:
##########
@@ -80,20 +90,152 @@
)
from airflow.sdk.exceptions import ErrorType
+if TYPE_CHECKING:
+ from structlog.typing import FilteringBoundLogger as Logger
+
+SendMsgType = TypeVar("SendMsgType", bound=BaseModel)
+ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel)
+
+
+def _msgpack_enc_hook(obj: Any) -> Any:
+ import pendulum
+
+ if isinstance(obj, pendulum.DateTime):
+ # convert the complex to a tuple of real, imag
+ return datetime(
+ obj.year, obj.month, obj.day, obj.hour, obj.minute, obj.second,
obj.microsecond, tzinfo=obj.tzinfo
+ )
+ if isinstance(obj, Path):
+ return str(obj)
+ if isinstance(obj, BaseModel):
+ return obj.model_dump(exclude_unset=True)
+
+ # Raise a NotImplementedError for other types
+ raise NotImplementedError(f"Objects of type {type(obj)} are not supported")
+
+
+def _new_encoder() -> msgspec.msgpack.Encoder:
+ return msgspec.msgpack.Encoder(enc_hook=_msgpack_enc_hook)
+
+
+class _RequestFrame(msgspec.Struct, array_like=True, frozen=True,
omit_defaults=True):
+ id: int
+ """
+ The request id, set by the sender.
+
+ This is used to allow "pipeling" of requests and to be able to tie
response to requests, which is
+ particularly useful in the Triggerer where multiple async tasks can send a
requests concurrently.
+ """
+ body: dict[str, Any] | None
+
+ req_encoder: ClassVar[msgspec.msgpack.Encoder] = _new_encoder()
+
+ def as_bytes(self) -> bytearray:
+ # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing
for inspiration
+ buffer = bytearray(256)
+
+ self.req_encoder.encode_into(self, buffer, 4)
+
+ n = len(buffer) - 4
+ if n > 2**32:
+ raise OverflowError("Cannot send messages larger than 4GiB")
+ buffer[:4] = n.to_bytes(4, byteorder="big")
+
+ return buffer
+
+
+class _ResponseFrame(_RequestFrame, msgspec.Struct, array_like=True,
frozen=True, omit_defaults=True):
Review Comment:
```suggestion
class _ResponseFrame(_RequestFrame):
```
Should be ok?
##########
task-sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -1441,7 +1445,6 @@ def cb(sock: socket):
with suppress(StopIteration):
gen.send(buffer)
# Tell loop to close this selector
Review Comment:
```suggestion
```
##########
devel-common/src/tests_common/pytest_plugin.py:
##########
@@ -1956,17 +1956,20 @@ def override_caplog(request):
@pytest.fixture
-def mock_supervisor_comms():
+def mock_supervisor_comms(monkeypatch):
# for back-compat
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
if not AIRFLOW_V_3_0_PLUS:
yield None
return
- with mock.patch(
- "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
- ) as supervisor_comms:
- yield supervisor_comms
+
+ import airflow.sdk.execution_time.task_runner
+ from airflow.sdk.execution_time.comms import CommsDecoder
+
+ comms = mock.create_autospec(CommsDecoder)
+ monkeypatch.setattr(airflow.sdk.execution_time.task_runner,
"SUPERVISOR_COMMS", comms, raising=False)
Review Comment:
nit -- feel free to ignore
##########
airflow-core/src/airflow/dag_processing/processor.py:
##########
@@ -266,20 +263,18 @@ def _on_child_started(
msg = DagFileParseRequest(
file=os.fspath(path),
bundle_path=bundle_path,
- requests_fd=self._requests_fd,
callback_requests=callbacks,
)
- self.send_msg(msg)
+ self.send_msg(msg, in_response_to=0)
Review Comment:
I am reviewing files in GitHub's diff order, so if I find the answer to this
later, I will add a comment.
##########
airflow-core/src/airflow/jobs/triggerer_job_runner.py:
##########
@@ -691,14 +690,58 @@ class TriggerDetails(TypedDict):
events: int
[email protected](kw_only=True)
+class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]):
+ _async_writer: asyncio.StreamWriter = attrs.field(alias="async_writer")
+ _async_reader: asyncio.StreamReader = attrs.field(alias="async_reader")
+
+ body_decoder: TypeAdapter[ToTriggerRunner] = attrs.field(
+ factory=lambda: TypeAdapter(ToTriggerRunner), repr=False
+ )
+
+ _lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False)
+
+ def _read_frame(self):
+ from asgiref.sync import async_to_sync
+
+ return async_to_sync(self._aread_frame)()
+
+ def send(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None:
+ from asgiref.sync import async_to_sync
+
+ return async_to_sync(self.asend)(msg)
+
+ async def _aread_frame(self):
+ len_bytes = await self._async_reader.readexactly(4)
+ len = int.from_bytes(len_bytes, byteorder="big")
+
+ buffer = await self._async_reader.readexactly(len)
Review Comment:
Should we set an upper limit for the size?
##########
airflow-core/tests/unit/dag_processing/test_manager.py:
##########
@@ -577,44 +575,40 @@ def test_kill_timed_out_processors_no_kill(self):
)
],
"/opt/airflow/dags/dag_callback_dag.py",
- b"{"
- b'"file":"/opt/airflow/dags/dag_callback_dag.py",'
- b'"bundle_path":"/opt/airflow/dags",'
- b'"requests_fd":123,"callback_requests":'
- b"["
- b"{"
- b'"filepath":"dag_callback_dag.py",'
- b'"bundle_name":"testing",'
- b'"bundle_version":null,'
- b'"msg":null,'
- b'"dag_id":"dag_id",'
- b'"run_id":"run_id",'
- b'"is_failure_callback":false,'
- b'"type":"DagCallbackRequest"'
- b"}"
- b"],"
- b'"type":"DagFileParseRequest"'
- b"}\n",
+ {
+ "file": "/opt/airflow/dags/dag_callback_dag.py",
+ "bundle_path": "/opt/airflow/dags",
+ "callback_requests": [
+ {
+ "filepath": "dag_callback_dag.py",
+ "bundle_name": "testing",
+ "bundle_version": None,
+ "msg": None,
+ "dag_id": "dag_id",
+ "run_id": "run_id",
+ "is_failure_callback": False,
+ "type": "DagCallbackRequest",
+ }
+ ],
+ "type": "DagFileParseRequest",
+ },
),
],
)
- def test_serialize_callback_requests(self, callbacks, path,
expected_buffer):
+ def test_serialize_callback_requests(self, callbacks, path, expected_body):
+ from airflow.sdk.execution_time.comms import _ResponseFrame
+
processor, read_socket = self.mock_processor()
processor._on_child_started(callbacks, path,
bundle_path=Path("/opt/airflow/dags"))
read_socket.settimeout(0.1)
- val = b""
- try:
- while not val.endswith(b"\n"):
- chunk = read_socket.recv(4096)
- if not chunk:
- break
- val += chunk
- except (BlockingIOError, TimeoutError):
- # no response written, valid for some message types.
- pass
-
- assert val == expected_buffer
+ # Read response from the read end of the socket
+ read_socket.settimeout(0.1)
Review Comment:
It is already set on L604 (the earlier `read_socket.settimeout(0.1)` call), so this one is redundant.
##########
task-sdk/src/airflow/sdk/execution_time/comms.py:
##########
@@ -80,20 +90,152 @@
)
from airflow.sdk.exceptions import ErrorType
+if TYPE_CHECKING:
+ from structlog.typing import FilteringBoundLogger as Logger
+
+SendMsgType = TypeVar("SendMsgType", bound=BaseModel)
+ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel)
+
+
+def _msgpack_enc_hook(obj: Any) -> Any:
+ import pendulum
+
+ if isinstance(obj, pendulum.DateTime):
+ # convert the complex to a tuple of real, imag
+ return datetime(
+ obj.year, obj.month, obj.day, obj.hour, obj.minute, obj.second,
obj.microsecond, tzinfo=obj.tzinfo
+ )
Review Comment:
Is this because we were previously relying purely on Pydantic for
serialization/deserialization?
Shouldn't everything be a `BaseModel`, though? When would I need (for
example) to add a new type here?
##########
airflow-core/src/airflow/jobs/triggerer_job_runner.py:
##########
@@ -691,14 +690,58 @@ class TriggerDetails(TypedDict):
events: int
[email protected](kw_only=True)
+class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner, ToTriggerSupervisor]):
+ _async_writer: asyncio.StreamWriter = attrs.field(alias="async_writer")
+ _async_reader: asyncio.StreamReader = attrs.field(alias="async_reader")
+
+ body_decoder: TypeAdapter[ToTriggerRunner] = attrs.field(
+ factory=lambda: TypeAdapter(ToTriggerRunner), repr=False
+ )
+
+ _lock: asyncio.Lock = attrs.field(factory=asyncio.Lock, repr=False)
+
+ def _read_frame(self):
+ from asgiref.sync import async_to_sync
+
+ return async_to_sync(self._aread_frame)()
+
+ def send(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None:
+ from asgiref.sync import async_to_sync
+
+ return async_to_sync(self.asend)(msg)
+
+ async def _aread_frame(self):
+ len_bytes = await self._async_reader.readexactly(4)
+ len = int.from_bytes(len_bytes, byteorder="big")
+
+ buffer = await self._async_reader.readexactly(len)
+ return self.resp_decoder.decode(buffer)
Review Comment:
Can this hang in any scenario?
##########
task-sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -346,10 +331,12 @@ def _fork_main(
# Store original stderr for last-chance exception handling
last_chance_stderr = _get_last_chance_stderr()
+ # os.environ["_AIRFLOW_SUPERVISOR_FD"] = str(requests.fileno())
Review Comment:
? Is this commented-out line meant to stay in?
##########
task-sdk/src/airflow/sdk/execution_time/comms.py:
##########
@@ -80,20 +90,152 @@
)
from airflow.sdk.exceptions import ErrorType
+if TYPE_CHECKING:
+ from structlog.typing import FilteringBoundLogger as Logger
+
+SendMsgType = TypeVar("SendMsgType", bound=BaseModel)
+ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel)
+
+
+def _msgpack_enc_hook(obj: Any) -> Any:
+ import pendulum
+
+ if isinstance(obj, pendulum.DateTime):
+ # convert the complex to a tuple of real, imag
+ return datetime(
+ obj.year, obj.month, obj.day, obj.hour, obj.minute, obj.second,
obj.microsecond, tzinfo=obj.tzinfo
+ )
+ if isinstance(obj, Path):
+ return str(obj)
+ if isinstance(obj, BaseModel):
+ return obj.model_dump(exclude_unset=True)
+
+ # Raise a NotImplementedError for other types
+ raise NotImplementedError(f"Objects of type {type(obj)} are not supported")
+
+
+def _new_encoder() -> msgspec.msgpack.Encoder:
+ return msgspec.msgpack.Encoder(enc_hook=_msgpack_enc_hook)
+
+
+class _RequestFrame(msgspec.Struct, array_like=True, frozen=True,
omit_defaults=True):
+ id: int
+ """
+ The request id, set by the sender.
+
+ This is used to allow "pipeling" of requests and to be able to tie
response to requests, which is
+ particularly useful in the Triggerer where multiple async tasks can send a
requests concurrently.
+ """
+ body: dict[str, Any] | None
+
+ req_encoder: ClassVar[msgspec.msgpack.Encoder] = _new_encoder()
+
+ def as_bytes(self) -> bytearray:
+ # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing
for inspiration
+ buffer = bytearray(256)
+
+ self.req_encoder.encode_into(self, buffer, 4)
+
+ n = len(buffer) - 4
+ if n > 2**32:
+ raise OverflowError("Cannot send messages larger than 4GiB")
+ buffer[:4] = n.to_bytes(4, byteorder="big")
+
+ return buffer
+
+
+class _ResponseFrame(_RequestFrame, msgspec.Struct, array_like=True,
frozen=True, omit_defaults=True):
+ id: int
+ """
+ The id of the request this is a response to
+ """
+ body: dict[str, Any] | None = None
+ error: dict[str, Any] | None = None
+
+
[email protected]()
+class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
+ """Handle communication between the task in this process and the
supervisor parent process."""
+
+ log: Logger = attrs.field(repr=False, factory=structlog.get_logger)
+ request_socket: socket = attrs.field(factory=lambda: socket(fileno=0))
+
+ resp_decoder: msgspec.msgpack.Decoder[_ResponseFrame] = attrs.field(
+ factory=lambda: msgspec.msgpack.Decoder(_ResponseFrame), repr=False
+ )
+
+ id_counter: Iterator[int] = attrs.field(factory=itertools.count)
+
+ # We could be "clever" here and set the default to this based type
parameters and a custom
+ # `__class_getitem__`, but that's a lot of code the one subclass we've got
currently. So we'll just use a
+ # "sort of wrong default"
+ body_decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda:
TypeAdapter(ToTask), repr=False)
+
+ err_decoder: TypeAdapter[ErrorResponse] = attrs.field(factory=lambda:
TypeAdapter(ToTask), repr=False)
+
+ def send(self, msg: SendMsgType) -> ReceiveMsgType | None:
+ """Send a request to the parent and block until the response is
received."""
+ frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
+ bytes = frame.as_bytes()
+
+ self.request_socket.sendall(bytes)
+
+ return self._get_response()
+
+ def _read_frame(self):
+ """
+ Get a message from the parent.
+
+ This will block until the message has been received.
+ """
+ if self.request_socket:
+ self.request_socket.setblocking(True)
+ len_bytes = self.request_socket.recv(4)
+
+ if len_bytes == b"":
+ raise EOFError("Request socket closed before length")
+
+ len = int.from_bytes(len_bytes, byteorder="big")
+
+ buffer = bytearray(len)
+ nread = self.request_socket.recv_into(buffer)
+ if nread != len:
+ raise RuntimeError(
+ f"unable to read full response in child. (We read {nread}, but
expected {len})"
+ )
+ if nread == 0:
+ raise EOFError("Request socket closed before response was
complete")
Review Comment:
Probably worth logging the request id here too?
##########
task-sdk/src/airflow/sdk/execution_time/task_runner.py:
##########
@@ -725,31 +649,33 @@ def send_request(self, log: Logger, msg: SendMsgType):
def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
- msg = SUPERVISOR_COMMS.get_message()
+ # The parent sends us a StartupDetails message un-prompted. After this,
ever single message is only sent
Review Comment:
```suggestion
# The parent sends us a StartupDetails message un-prompted. After this,
every single message is only sent
```
##########
airflow-core/src/airflow/dag_processing/processor.py:
##########
@@ -266,20 +263,18 @@ def _on_child_started(
msg = DagFileParseRequest(
file=os.fspath(path),
bundle_path=bundle_path,
- requests_fd=self._requests_fd,
callback_requests=callbacks,
)
- self.send_msg(msg)
+ self.send_msg(msg, in_response_to=0)
Review Comment:
What happens if those requests are sent with the same IDs?
##########
task-sdk/tests/task_sdk/execution_time/test_supervisor.py:
##########
@@ -1729,33 +1763,23 @@ def test_handle_requests(
# Read response from the read end of the socket
read_socket.settimeout(0.1)
- val = b""
- try:
- while not val.endswith(b"\n"):
- chunk = read_socket.recv(BUFFER_SIZE)
- if not chunk:
- break
- val += chunk
- except (BlockingIOError, TimeoutError, socket.timeout):
- # no response written, valid for some message types like setters
and TI operations.
- pass
+ frame_len = int.from_bytes(read_socket.recv(4), "big")
+ bytes = read_socket.recv(frame_len)
+ frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(bytes)
+
+ assert frame.id == req_frame.id
# Verify the response was added to the buffer
- assert val == expected_buffer
+ assert frame.body == expected_body
# Verify the response is correctly decoded
# This is important because the subprocess/task runner will read the
response
# and deserialize it to the correct message type
- # Only decode the buffer if it contains data. An empty buffer implies
no response was written.
- if not val and (mock_response and not isinstance(mock_response,
OKResponse)):
- pytest.fail("Expected a response, but got an empty buffer.")
-
- if val:
+ if frame.body is not None:
# Using BytesIO to simulate a readable stream for CommsDecoder.
Review Comment:
```suggestion
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]