This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch rework-tasksdk-supervisor-comms-protocol
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 38d74f3232a0784c1ee852e25affbfbf67a2c704
Author: Ash Berlin-Taylor <a...@apache.org>
AuthorDate: Tue Jun 10 11:55:46 2025 +0100

    Switch the Supervisor/task process from line-based to length-prefixed
    
    The existing JSON Lines based approach had three major drawbacks:
    
    1. In the case of really large lines (in the region of 10 or 20MB) the Python
       line buffering could _sometimes_ result in a partial read.
    2. The JSON-based approach didn't have the ability to add any metadata (such
       as errors).
    3. Not every message type/call-site waited for a response, which meant those
       client functions could never be told about an error.
    
    This changes the communications protocol in a few ways.
    
    First, at the Python level, the separate send and receive methods on the
    client/task side have been removed and replaced with a single `send()` that
    sends the request, reads the response, and raises an error if one is
    returned.
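
    For example, a connection fetch in `context.py` goes from the old two-step
    pattern to a single call (sketch based on the call sites updated in this
    commit):

        # Before: two calls, and a response that not every call site read
        SUPERVISOR_COMMS.send_request(log=log, msg=GetConnection(conn_id=conn_id))
        msg = SUPERVISOR_COMMS.get_message()

        # After: one call that sends the request, blocks for the reply, and
        # raises if the supervisor returned an error frame
        msg = SUPERVISOR_COMMS.send(GetConnection(conn_id=conn_id))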
    
    Secondly, the JSON Lines approach has been changed from a line-based protocol
    to a binary "frame" one. The protocol (which is the same whichever side is
    sending) is length-prefixed, i.e. we first send the length of the data as a
    4-byte big-endian integer, followed by the data itself. This should remove the
    possibility of JSON parse errors due to reading incomplete lines.
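
    A minimal sketch of the framing, independent of the Airflow classes (the
    helpers `write_frame`/`read_exactly`/`read_frame` are illustrative, not part
    of this commit; the real implementation encodes frames with msgspec's
    `encode_into` to reuse a buffer):

        import socket
        import msgspec

        def write_frame(sock: socket.socket, payload: dict) -> None:
            body = msgspec.msgpack.encode(payload)
            # 4-byte big-endian length prefix, then the msgpack body
            sock.sendall(len(body).to_bytes(4, byteorder="big") + body)

        def read_exactly(sock: socket.socket, n: int) -> bytes:
            # recv can return short reads, so loop until n bytes have arrived
            buf = bytearray(n)
            view, pos = memoryview(buf), 0
            while pos < n:
                nread = sock.recv_into(view[pos:])
                if nread == 0:
                    raise EOFError("socket closed mid-frame")
                pos += nread
            return bytes(buf)

        def read_frame(sock: socket.socket) -> dict:
            length = int.from_bytes(read_exactly(sock, 4), byteorder="big")
            return msgspec.msgpack.decode(read_exactly(sock, length))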
    
    Finally, the last change made in this PR is to remove the "extra" requests
    socket/channel. Upon closer examination of this comms path I realised that
    this socket is unnecessary: since we are in 100% control of the client side we
    can make use of the bi-directional nature of `socketpair` and save file
    handles.
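
    A toy demonstration of why the extra channel is redundant: one `socketpair`
    carries traffic in both directions, so the child can read requests and write
    responses on the same FD:

        import socket

        parent, child = socket.socketpair()  # a single bi-directional channel

        parent.sendall(b"ping")          # supervisor -> task
        assert child.recv(4) == b"ping"
        child.sendall(b"pong")           # task -> supervisor, same socket
        assert parent.recv(4) == b"pong"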
---
 .../src/airflow/dag_processing/processor.py        |  16 +-
 .../src/airflow/jobs/triggerer_job_runner.py       |   6 +-
 .../tests/unit/dag_processing/test_manager.py      |  11 +-
 .../tests/unit/dag_processing/test_processor.py    |  14 +-
 airflow-core/tests/unit/jobs/test_triggerer_job.py |   1 -
 devel-common/src/tests_common/pytest_plugin.py     |   9 +-
 task-sdk/src/airflow/sdk/bases/xcom.py             |  27 +-
 .../airflow/sdk/definitions/asset/decorators.py    |  12 +-
 task-sdk/src/airflow/sdk/execution_time/comms.py   | 165 ++++++++++-
 task-sdk/src/airflow/sdk/execution_time/context.py |  32 +--
 .../airflow/sdk/execution_time/lazy_sequence.py    |  18 +-
 .../src/airflow/sdk/execution_time/supervisor.py   | 248 ++++++++++------
 .../src/airflow/sdk/execution_time/task_runner.py  | 165 +++--------
 .../tests/task_sdk/execution_time/test_comms.py    |  83 ++++++
 .../task_sdk/execution_time/test_supervisor.py     | 312 +++++++++++----------
 .../task_sdk/execution_time/test_task_runner.py    |  51 ----
 16 files changed, 675 insertions(+), 495 deletions(-)

diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py
index 5f69082c758..1d1ad25f00b 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -68,7 +68,6 @@ class DagFileParseRequest(BaseModel):
     bundle_path: Path
     """Passing bundle path around lets us figure out relative file path."""
 
-    requests_fd: int
     callback_requests: list[CallbackRequest] = Field(default_factory=list)
     type: Literal["DagFileParseRequest"] = "DagFileParseRequest"
 
@@ -102,18 +101,16 @@ ToDagProcessor = Annotated[
 def _parse_file_entrypoint():
     import structlog
 
-    from airflow.sdk.execution_time import task_runner
+    from airflow.sdk.execution_time import comms, task_runner
 
     # Parse DAG file, send JSON back up!
-    comms_decoder = task_runner.CommsDecoder[ToDagProcessor, ToManager](
-        input=sys.stdin,
-        decoder=TypeAdapter[ToDagProcessor](ToDagProcessor),
+    comms_decoder = comms.CommsDecoder[ToDagProcessor, ToManager](
+        body_decoder=TypeAdapter[ToDagProcessor](ToDagProcessor),
     )
 
-    msg = comms_decoder.get_message()
+    msg = comms_decoder._get_response()
     if not isinstance(msg, DagFileParseRequest):
        raise RuntimeError(f"Required first message to be a DagFileParseRequest, it was {msg}")
-    comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0)
 
     task_runner.SUPERVISOR_COMMS = comms_decoder
     log = structlog.get_logger(logger_name="task")
@@ -125,7 +122,7 @@ def _parse_file_entrypoint():
 
     result = _parse_file(msg, log)
     if result is not None:
-        comms_decoder.send_request(log, result)
+        comms_decoder.send(result)
 
 
 def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult | None:
@@ -266,10 +263,9 @@ class DagFileProcessorProcess(WatchedSubprocess):
         msg = DagFileParseRequest(
             file=os.fspath(path),
             bundle_path=bundle_path,
-            requests_fd=self._requests_fd,
             callback_requests=callbacks,
         )
-        self.send_msg(msg)
+        self.send_msg(msg, in_response_to=0)
 
     def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None:  # type: ignore[override]
         from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index 5df6a032615..0a516abb1a6 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -181,7 +181,6 @@ class messages:
     class StartTriggerer(BaseModel):
         """Tell the async trigger runner process to start, and where to send status update messages."""
 
-        requests_fd: int
         type: Literal["StartTriggerer"] = "StartTriggerer"
 
     class TriggerStateChanges(BaseModel):
@@ -342,8 +341,8 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
     ):
         proc = super().start(id=job.id, job=job, target=cls.run_in_process, logger=logger, **kwargs)
 
-        msg = messages.StartTriggerer(requests_fd=proc._requests_fd)
-        proc.send_msg(msg)
+        msg = messages.StartTriggerer()
+        proc.send_msg(msg, in_response_to=0)
         return proc
 
     @functools.cached_property
@@ -819,7 +818,6 @@ class TriggerRunner:
         if not isinstance(msg, messages.StartTriggerer):
            raise RuntimeError(f"Required first message to be a messages.StartTriggerer, it was {msg}")
 
-        comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0)
         writer_transport, writer_protocol = await loop.connect_write_pipe(
             lambda: asyncio.streams.FlowControlMixin(loop=loop),
             comms_decoder.request_socket,
diff --git a/airflow-core/tests/unit/dag_processing/test_manager.py b/airflow-core/tests/unit/dag_processing/test_manager.py
index c9e974b8cdb..799d2ad51c8 100644
--- a/airflow-core/tests/unit/dag_processing/test_manager.py
+++ b/airflow-core/tests/unit/dag_processing/test_manager.py
@@ -30,7 +30,7 @@ from collections import deque
 from datetime import datetime, timedelta
 from logging.config import dictConfig
 from pathlib import Path
-from socket import socket
+from socket import socket, socketpair
 from unittest import mock
 from unittest.mock import MagicMock
 
@@ -54,7 +54,6 @@ from airflow.models.dag_version import DagVersion
 from airflow.models.dagbundle import DagBundleModel
 from airflow.models.dagcode import DagCode
 from airflow.models.serialized_dag import SerializedDagModel
-from airflow.sdk.execution_time.supervisor import mkpipe
 from airflow.utils import timezone
 from airflow.utils.net import get_hostname
 from airflow.utils.session import create_session
@@ -138,20 +137,19 @@ class TestDagFileProcessorManager:
         logger_filehandle = MagicMock()
         proc.create_time.return_value = time.time()
         proc.wait.return_value = 0
-        read_end, write_end = mkpipe(remote_read=True)
+        read_end, write_end = socketpair()
         ret = DagFileProcessorProcess(
             process_log=MagicMock(),
             id=uuid7(),
             pid=1234,
             process=proc,
             stdin=write_end,
-            requests_fd=123,
             logger_filehandle=logger_filehandle,
             client=MagicMock(),
         )
         if start_time:
             ret.start_time = start_time
-        ret._num_open_sockets = 0
+        ret._open_sockets.clear()
         return ret, read_end
 
     @pytest.fixture
@@ -560,7 +558,6 @@ class TestDagFileProcessorManager:
                 b"{"
                 b'"file":"/opt/airflow/dags/test_dag.py",'
                 b'"bundle_path":"/opt/airflow/dags",'
-                b'"requests_fd":123,'
                 b'"callback_requests":[],'
                 b'"type":"DagFileParseRequest"'
                 b"}\n",
@@ -580,7 +577,7 @@ class TestDagFileProcessorManager:
                 b"{"
                 b'"file":"/opt/airflow/dags/dag_callback_dag.py",'
                 b'"bundle_path":"/opt/airflow/dags",'
-                b'"requests_fd":123,"callback_requests":'
+                b'"callback_requests":'
                 b"["
                 b"{"
                 b'"filepath":"dag_callback_dag.py",'
diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py
index 28ce7a8c23f..e673848a915 100644
--- a/airflow-core/tests/unit/dag_processing/test_processor.py
+++ b/airflow-core/tests/unit/dag_processing/test_processor.py
@@ -87,7 +87,6 @@ class TestDagFileProcessor:
             DagFileParseRequest(
                 file=file_path,
                 bundle_path=TEST_DAG_FOLDER,
-                requests_fd=1,
                 callback_requests=callback_requests or [],
             ),
             log=structlog.get_logger(),
@@ -395,13 +394,10 @@ def disable_capturing():
 @pytest.mark.usefixtures("disable_capturing")
 def test_parse_file_entrypoint_parses_dag_callbacks(spy_agency):
     r, w = socketpair()
-    # Create a valid FD for the decoder to open
-    _, w2 = socketpair()
 
     w.makefile("wb").write(
-        b'{"file":"/files/dags/wait.py","bundle_path":"/files/dags","requests_fd":'
-        + str(w2.fileno()).encode("ascii")
-        + b',"callback_requests": [{"filepath": "wait.py", "bundle_name": "testing", "bundle_version": null, '
+        b'{"file":"/files/dags/wait.py","bundle_path":"/files/dags",'
+        b'"callback_requests": [{"filepath": "wait.py", "bundle_name": "testing", "bundle_version": null, '
         b'"msg": "task_failure", "dag_id": "wait_to_fail", "run_id": '
         b'"manual__2024-12-30T21:02:55.203691+00:00", '
         b'"is_failure_callback": true, "type": "DagCallbackRequest"}], "type": "DagFileParseRequest"}\n'
@@ -455,7 +451,7 @@ def test_parse_file_with_dag_callbacks(spy_agency):
         )
     ]
     _parse_file(
-        DagFileParseRequest(file="A", bundle_path="no matter", requests_fd=1, callback_requests=requests),
+        DagFileParseRequest(file="A", bundle_path="no matter", callback_requests=requests),
         log=structlog.get_logger(),
     )
 
@@ -489,8 +485,6 @@ def test_parse_file_with_task_callbacks(spy_agency):
             bundle_version=None,
         )
     ]
-    _parse_file(
-        DagFileParseRequest(file="A", requests_fd=1, callback_requests=requests), log=structlog.get_logger()
-    )
+    _parse_file(DagFileParseRequest(file="A", callback_requests=requests), log=structlog.get_logger())
 
     assert called is True
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py
index d3c9ae6f27d..dd30db7cc1a 100644
--- a/airflow-core/tests/unit/jobs/test_triggerer_job.py
+++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -172,7 +172,6 @@ def supervisor_builder(mocker, session):
             pid=process.pid,
             stdin=mocker.Mock(),
             process=process,
-            requests_fd=-1,
             capacity=10,
         )
         # Mock the selector
diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index b5ae53625e2..73d6c82bcaf 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -1963,8 +1963,13 @@ def mock_supervisor_comms():
     if not AIRFLOW_V_3_0_PLUS:
         yield None
         return
+
+    from airflow.sdk.execution_time.comms import CommsDecoder
+
     with mock.patch(
-        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
+        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS",
+        create=True,
+        spec=CommsDecoder,
     ) as supervisor_comms:
         yield supervisor_comms
 
@@ -1991,7 +1996,6 @@ def mocked_parse(spy_agency):
                        id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1
                     ),
                     file="",
-                    requests_fd=0,
                 ),
                 "example_dag_id",
                 CustomOperator(task_id="hello"),
@@ -2198,7 +2202,6 @@ def create_runtime_ti(mocked_parse):
             ),
             dag_rel_path="",
             bundle_info=BundleInfo(name="anything", version="any"),
-            requests_fd=0,
             ti_context=ti_context,
             start_date=start_date,  # type: ignore
         )
diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py
index 0c330652956..aa20c7d3b31 100644
--- a/task-sdk/src/airflow/sdk/bases/xcom.py
+++ b/task-sdk/src/airflow/sdk/bases/xcom.py
@@ -70,9 +70,8 @@ class BaseXCom:
             map_index=map_index,
         )
 
-        SUPERVISOR_COMMS.send_request(
-            log=log,
-            msg=SetXCom(
+        SUPERVISOR_COMMS.send(
+            SetXCom(
                 key=key,
                 value=value,
                 dag_id=dag_id,
@@ -107,9 +106,8 @@ class BaseXCom:
         """
         from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
 
-        SUPERVISOR_COMMS.send_request(
-            log=log,
-            msg=SetXCom(
+        SUPERVISOR_COMMS.send(
+            SetXCom(
                 key=key,
                 value=value,
                 dag_id=dag_id,
@@ -186,9 +184,8 @@ class BaseXCom:
         # back so that two triggers don't end up interleaving requests and create a possible
         # race condition where the wrong trigger reads the response.
         with SUPERVISOR_COMMS.lock:
-            SUPERVISOR_COMMS.send_request(
-                log=log,
-                msg=GetXCom(
+            msg = SUPERVISOR_COMMS.send(
+                GetXCom(
                     key=key,
                     dag_id=dag_id,
                     task_id=task_id,
@@ -197,7 +194,6 @@ class BaseXCom:
                 ),
             )
 
-            msg = SUPERVISOR_COMMS.get_message()
         if not isinstance(msg, XComResult):
            raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}")
 
@@ -246,9 +242,8 @@ class BaseXCom:
         # back so that two triggers don't end up interleaving requests and create a possible
         # race condition where the wrong trigger reads the response.
         with SUPERVISOR_COMMS.lock:
-            SUPERVISOR_COMMS.send_request(
-                log=log,
-                msg=GetXCom(
+            msg = SUPERVISOR_COMMS.send(
+                GetXCom(
                     key=key,
                     dag_id=dag_id,
                     task_id=task_id,
@@ -257,7 +252,6 @@ class BaseXCom:
                     include_prior_dates=include_prior_dates,
                 ),
             )
-            msg = SUPERVISOR_COMMS.get_message()
 
         if not isinstance(msg, XComResult):
            raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}")
@@ -322,9 +316,8 @@ class BaseXCom:
             map_index=map_index,
         )
         cls.purge(xcom_result)  # type: ignore[call-arg]
-        SUPERVISOR_COMMS.send_request(
-            log=log,
-            msg=DeleteXCom(
+        SUPERVISOR_COMMS.send(
+            DeleteXCom(
                 key=key,
                 dag_id=dag_id,
                 task_id=task_id,
diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py
index d899dd37183..bfe5ac2fc96 100644
--- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py
+++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py
@@ -69,18 +69,14 @@ class _AssetMainOperator(PythonOperator):
         )
 
     def _iter_kwargs(self, context: Mapping[str, Any]) -> Iterator[tuple[str, Any]]:
-        import structlog
-
         from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetByName
         from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
 
-        log = structlog.get_logger(logger_name=self.__class__.__qualname__)
-
         def _fetch_asset(name: str) -> Asset:
-            SUPERVISOR_COMMS.send_request(log, GetAssetByName(name=name))
-            if isinstance(msg := SUPERVISOR_COMMS.get_message(), ErrorResponse):
-                raise AirflowRuntimeError(msg)
-            return Asset(**msg.model_dump(exclude={"type"}))
+            resp = SUPERVISOR_COMMS.send(GetAssetByName(name=name))
+            if isinstance(resp, ErrorResponse):
+                raise AirflowRuntimeError(resp)
+            return Asset(**resp.model_dump(exclude={"type"}))
 
         value: Any
         for key, param in inspect.signature(self.python_callable).parameters.items():
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index d0622cf6ffc..4041323540a 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -43,15 +43,20 @@ Execution API server is because:
 
 from __future__ import annotations
 
+import itertools
 from collections.abc import Iterator
 from datetime import datetime
 from functools import cached_property
-from typing import Annotated, Any, Literal, Union
+from socket import socket
+from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TypeVar, Union
 from uuid import UUID
 
+import aiologic
 import attrs
+import msgspec
+import structlog
 from fastapi import Body
-from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, field_serializer
+from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, field_serializer
 
 from airflow.sdk.api.datamodels._generated import (
     AssetEventDagRunReference,
@@ -80,6 +85,156 @@ from airflow.sdk.api.datamodels._generated import (
 )
 from airflow.sdk.exceptions import ErrorType
 
+if TYPE_CHECKING:
+    from structlog.typing import FilteringBoundLogger as Logger
+
+SendMsgType = TypeVar("SendMsgType", bound=BaseModel)
+ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel)
+
+
+class _RequestFrame(msgspec.Struct, array_like=True, frozen=True, omit_defaults=True):
+    id: int
+    """
+    The request id, set by the sender.
+
+    This is used to allow "pipelining" of requests and to be able to tie responses to requests, which is
+    particularly useful in the Triggerer where multiple async tasks can send requests concurrently.
+    """
+    body: dict[str, Any]
+
+
+class _ResponseFrame(msgspec.Struct, array_like=True, frozen=True, omit_defaults=True):
+    id: int
+    """
+    The id of the request this is a response to
+    """
+    body: dict[str, Any] | None = None
+    error: dict[str, Any] | None = None
+
+
+def _msgpack_enc_hook(obj: Any) -> Any:
+    import pendulum
+
+    if isinstance(obj, pendulum.DateTime):
+        # convert the pendulum DateTime into a plain stdlib datetime
+        return datetime(
+            obj.year, obj.month, obj.day, obj.hour, obj.minute, obj.second, obj.microsecond, tzinfo=obj.tzinfo
+        )
+    if isinstance(obj, BaseModel):
+        return obj.model_dump(exclude_unset=True)
+
+    # Raise a NotImplementedError for other types
+    raise NotImplementedError(f"Objects of type {type(obj)} are not supported")
+
+
+def _new_encoder() -> msgspec.msgpack.Encoder:
+    return msgspec.msgpack.Encoder(enc_hook=_msgpack_enc_hook)
+
+
+@attrs.define()
+class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
+    """Handle communication between the task in this process and the supervisor parent process."""
+
+    log: Logger = attrs.field(repr=False, factory=structlog.get_logger)
+    request_socket: socket = attrs.field(factory=lambda: socket(fileno=0))
+
+    # Do we still need the context lock? Keeps the change small for now
+    lock: aiologic.Lock = attrs.field(factory=aiologic.Lock, repr=False)
+
+    resp_decoder: msgspec.msgpack.Decoder[_ResponseFrame] = attrs.field(
+        factory=lambda: msgspec.msgpack.Decoder(_ResponseFrame), repr=False
+    )
+    req_encoder: msgspec.msgpack.Encoder = attrs.field(factory=_new_encoder, repr=False)
+
+    id_counter: Iterator[int] = attrs.field(factory=itertools.count)
+
+    # We could be "clever" here and set the default based on the type parameters and a custom
+    # `__class_getitem__`, but that's a lot of code for the one subclass we've got currently. So we'll
+    # just use a "sort of wrong" default
+    body_decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: TypeAdapter(ToTask), repr=False)
+
+    err_decoder: TypeAdapter[ErrorResponse] = attrs.field(factory=lambda: TypeAdapter(ErrorResponse), repr=False)
+
+    def send(self, msg: SendMsgType) -> ReceiveMsgType:
+        """Send a request to the parent and block until the response is received."""
+        frame_bytes = self._encode(msg)
+
+        # Use sendall, not send: a single send() call may write only part of a large frame
+        self.request_socket.sendall(frame_bytes)
+
+        return self._get_response()
+
+    def _encode(self, msg: SendMsgType) -> bytearray:
+        # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing for inspiration
+        buffer = bytearray(256)
+
+        frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
+        self.req_encoder.encode_into(frame, buffer, 4)
+
+        n = len(buffer) - 4
+        if n >= 2**32:
+            raise OverflowError("Cannot send messages larger than 4GiB")
+        buffer[:4] = n.to_bytes(4, byteorder="big")
+
+        return buffer
+
+    def _read_frame(self):
+        """
+        Read one response frame from the parent.
+
+        This will block until a complete frame has been received.
+        """
+        self.request_socket.setblocking(True)
+        len_bytes = self.request_socket.recv(4)
+
+        if len_bytes == b"":
+            raise EOFError("Request socket closed before length")
+
+        length = int.from_bytes(len_bytes, byteorder="big")
+
+        buffer = bytearray(length)
+        view = memoryview(buffer)
+        nread = 0
+        # recv_into on a stream socket may return less than we asked for, so keep
+        # reading until the whole frame has arrived
+        while nread < length:
+            n = self.request_socket.recv_into(view[nread:])
+            if n == 0:
+                raise EOFError("Request socket closed before response was complete")
+            nread += n
+
+        return self.resp_decoder.decode(buffer)
+
+    def _from_frame(self, frame):
+        from airflow.sdk.exceptions import AirflowRuntimeError
+
+        if frame.error is not None:
+            err = self.err_decoder.validate_python(frame.error)
+            raise AirflowRuntimeError(error=err)
+
+        if frame.body is None:
+            return None
+
+        try:
+            return self.body_decoder.validate_python(frame.body)
+        except Exception:
+            self.log.exception("Unable to decode message")
+            raise
+
+    def _get_response(self) -> ReceiveMsgType:
+        frame = self._read_frame()
+        return self._from_frame(frame)
+
 
 class StartupDetails(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -87,13 +242,7 @@ class StartupDetails(BaseModel):
     ti: TaskInstance
     dag_rel_path: str
     bundle_info: BundleInfo
-    requests_fd: int
     start_date: datetime
-    """
-    The channel for the task to send requests over.
-
-    Responses will come back on stdin
-    """
     ti_context: TIRunContext
     type: Literal["StartupDetails"] = "StartupDetails"
 
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py
index 438687bdeb8..3ef99fec248 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -155,8 +155,7 @@ def _get_connection(conn_id: str) -> Connection:
     # back so that two triggers don't end up interleaving requests and create a possible
     # race condition where the wrong trigger reads the response.
     with SUPERVISOR_COMMS.lock:
-        SUPERVISOR_COMMS.send_request(log=log, msg=GetConnection(conn_id=conn_id))
-        msg = SUPERVISOR_COMMS.get_message()
+        msg = SUPERVISOR_COMMS.send(GetConnection(conn_id=conn_id))
 
     if isinstance(msg, ErrorResponse):
         raise AirflowRuntimeError(msg)
@@ -208,8 +207,7 @@ def _get_variable(key: str, deserialize_json: bool) -> Any:
     # back so that two triggers don't end up interleaving requests and create a possible
     # race condition where the wrong trigger reads the response.
     with SUPERVISOR_COMMS.lock:
-        SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key=key))
-        msg = SUPERVISOR_COMMS.get_message()
+        msg = SUPERVISOR_COMMS.send(GetVariable(key=key))
 
     if isinstance(msg, ErrorResponse):
         raise AirflowRuntimeError(msg)
@@ -263,7 +261,7 @@ def _set_variable(key: str, value: Any, description: str | None = None, serializ
     # primarily added for triggers but it doesn't make sense to have it in some places
     # and not in the rest. A lot of this will be simplified by https://github.com/apache/airflow/issues/46426
     with SUPERVISOR_COMMS.lock:
-        SUPERVISOR_COMMS.send_request(log=log, msg=PutVariable(key=key, value=value, description=description))
+        SUPERVISOR_COMMS.send(PutVariable(key=key, value=value, description=description))
 
 
 def _delete_variable(key: str) -> None:
@@ -279,8 +277,7 @@ def _delete_variable(key: str) -> None:
     # primarily added for triggers but it doesn't make sense to have it in some places
     # and not in the rest. A lot of this will be simplified by https://github.com/apache/airflow/issues/46426
     with SUPERVISOR_COMMS.lock:
-        SUPERVISOR_COMMS.send_request(log=log, msg=DeleteVariable(key=key))
-        msg = SUPERVISOR_COMMS.get_message()
+        msg = SUPERVISOR_COMMS.send(DeleteVariable(key=key))
     if TYPE_CHECKING:
         assert isinstance(msg, OKResponse)
 
@@ -387,13 +384,13 @@ class _AssetRefResolutionMixin:
         from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
 
         if name:
-            SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByName(name=name))
+            msg = GetAssetByName(name=name)
         elif uri:
-            SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetByUri(uri=uri))
+            msg = GetAssetByUri(uri=uri)
         else:
             raise ValueError("Either name or uri must be provided")
 
-        msg = SUPERVISOR_COMMS.get_message()
+        msg = SUPERVISOR_COMMS.send(msg)
         if isinstance(msg, ErrorResponse):
             raise AirflowRuntimeError(msg)
 
@@ -545,24 +542,26 @@ class InletEventsAccessors(
 
         if isinstance(obj, Asset):
             asset = self._assets[AssetUniqueKey.from_asset(obj)]
-            SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAsset(name=asset.name, uri=asset.uri))
+            msg = GetAssetEventByAsset(name=asset.name, uri=asset.uri)
         elif isinstance(obj, AssetNameRef):
             try:
                 asset = next(a for k, a in self._assets.items() if k.name == obj.name)
             except StopIteration:
                 raise KeyError(obj) from None
-            SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAsset(name=asset.name, uri=None))
+            msg = GetAssetEventByAsset(name=asset.name, uri=None)
         elif isinstance(obj, AssetUriRef):
             try:
                 asset = next(a for k, a in self._assets.items() if k.uri == obj.uri)
             except StopIteration:
                 raise KeyError(obj) from None
-            SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAsset(name=None, uri=asset.uri))
+            msg = GetAssetEventByAsset(name=None, uri=asset.uri)
         elif isinstance(obj, AssetAlias):
             asset_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)]
-            SUPERVISOR_COMMS.send_request(log=log, msg=GetAssetEventByAssetAlias(alias_name=asset_alias.name))
+            msg = GetAssetEventByAssetAlias(alias_name=asset_alias.name)
+        else:
+            raise TypeError(f"`key` is of unknown type ({type(key).__name__})")
 
-        msg = SUPERVISOR_COMMS.get_message()
+        msg = SUPERVISOR_COMMS.send(msg)
         if isinstance(msg, ErrorResponse):
             raise AirflowRuntimeError(msg)
 
@@ -626,8 +625,7 @@ def get_previous_dagrun_success(ti_id: UUID) -> PrevSuccessfulDagRunResponse:
     )
     from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
 
-    SUPERVISOR_COMMS.send_request(log=log, msg=GetPrevSuccessfulDagRun(ti_id=ti_id))
-    msg = SUPERVISOR_COMMS.get_message()
+    msg = SUPERVISOR_COMMS.send(GetPrevSuccessfulDagRun(ti_id=ti_id))
 
     if TYPE_CHECKING:
         assert isinstance(msg, PrevSuccessfulDagRunResult)
diff --git a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py
index 9cf9acfac81..536555a3e39 100644
--- a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py
+++ b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py
@@ -91,16 +91,14 @@ class LazyXComSequence(Sequence[T]):
 
             task = self._xcom_arg.operator
 
-            SUPERVISOR_COMMS.send_request(
-                log=log,
-                msg=GetXComCount(
+            msg = SUPERVISOR_COMMS.send(
+                GetXComCount(
                     key=self._xcom_arg.key,
                     dag_id=task.dag_id,
                     run_id=self._ti.run_id,
                     task_id=task.task_id,
                 ),
             )
-            msg = SUPERVISOR_COMMS.get_message()
             if isinstance(msg, ErrorResponse):
                 raise RuntimeError(msg)
             if not isinstance(msg, XComCountResponse):
@@ -129,9 +127,8 @@ class LazyXComSequence(Sequence[T]):
             start, stop, step = _coerce_slice(key)
             with SUPERVISOR_COMMS.lock:
                 source = (xcom_arg := self._xcom_arg).operator
-                SUPERVISOR_COMMS.send_request(
-                    log=log,
-                    msg=GetXComSequenceSlice(
+                msg = SUPERVISOR_COMMS.send(
+                    GetXComSequenceSlice(
                         key=xcom_arg.key,
                         dag_id=source.dag_id,
                         task_id=source.task_id,
@@ -141,7 +138,6 @@ class LazyXComSequence(Sequence[T]):
                         step=step,
                     ),
                 )
-                msg = SUPERVISOR_COMMS.get_message()
                 if not isinstance(msg, XComSequenceSliceResult):
                     raise TypeError(f"Got unexpected response to GetXComSequenceSlice: {msg!r}")
             return [XCom.deserialize_value(_XComWrapper(value)) for value in msg.root]
@@ -153,9 +149,8 @@ class LazyXComSequence(Sequence[T]):
 
         with SUPERVISOR_COMMS.lock:
             source = (xcom_arg := self._xcom_arg).operator
-            SUPERVISOR_COMMS.send_request(
-                log=log,
-                msg=GetXComSequenceItem(
+            msg = SUPERVISOR_COMMS.send(
+                GetXComSequenceItem(
                     key=xcom_arg.key,
                     dag_id=source.dag_id,
                     task_id=source.task_id,
@@ -163,7 +158,6 @@ class LazyXComSequence(Sequence[T]):
                     offset=key,
                 ),
             )
-            msg = SUPERVISOR_COMMS.get_message()
         if isinstance(msg, ErrorResponse):
             raise IndexError(key)
         if not isinstance(msg, XComSequenceIndexResult):
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 7e6b89043d2..a3c9bcbb3df 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -27,12 +27,13 @@ import selectors
 import signal
 import sys
 import time
+import weakref
 from collections import deque
 from collections.abc import Generator
 from contextlib import contextmanager, suppress
 from datetime import datetime, timezone
 from http import HTTPStatus
-from socket import SO_SNDBUF, SOL_SOCKET, SocketIO, socket, socketpair
+from socket import SO_SNDBUF, SOL_SOCKET, socket, socketpair
 from typing import (
     TYPE_CHECKING,
     Callable,
@@ -63,6 +64,7 @@ from airflow.sdk.api.datamodels._generated import (
     XComSequenceIndexResponse,
 )
 from airflow.sdk.exceptions import ErrorType
+from airflow.sdk.execution_time import comms
 from airflow.sdk.execution_time.comms import (
     AssetEventsResult,
     AssetResult,
@@ -108,6 +110,8 @@ from airflow.sdk.execution_time.comms import (
     XComResult,
     XComSequenceIndexResult,
     XComSequenceSliceResult,
+    _RequestFrame,
+    _ResponseFrame,
 )
 from airflow.sdk.execution_time.secrets_masker import mask_secret
 
@@ -192,6 +196,8 @@ def mkpipe(
             local.setsockopt(SO_SNDBUF, SOL_SOCKET, BUFFER_SIZE)
         # set blocking to True so that send or sendall waits till all data is sent
         local.setblocking(True)
+    else:
+        local.setblocking(False)
 
     return remote, local
 
@@ -223,14 +229,13 @@ def _configure_logs_over_json_channel(log_fd: int):
 def _reopen_std_io_handles(child_stdin, child_stdout, child_stderr):
     # Ensure that sys.stdout et al (and the underlying filehandles for C libraries etc) are connected to the
     # pipes from the supervisor
-
     for handle_name, fd, sock, mode in (
-        ("stdin", 0, child_stdin, "r"),
+        # Yes, we want to re-open stdin in write mode! This is cause it is a 
bi-directional socket, so we can
+        # read and write to it.
+        ("stdin", 0, child_stdin, "w"),
         ("stdout", 1, child_stdout, "w"),
         ("stderr", 2, child_stderr, "w"),
     ):
-        handle = getattr(sys, handle_name)
-        handle.close()
         os.dup2(sock.fileno(), fd)
         del sock
 
@@ -318,7 +323,7 @@ def block_orm_access():
 
 
 def _fork_main(
-    child_stdin: socket,
+    requests: socket,
     child_stdout: socket,
     child_stderr: socket,
     log_fd: int,
@@ -345,10 +350,12 @@ def _fork_main(
     # Store original stderr for last-chance exception handling
     last_chance_stderr = _get_last_chance_stderr()
 
     _reset_signals()
     if log_fd:
         _configure_logs_over_json_channel(log_fd)
-    _reopen_std_io_handles(child_stdin, child_stdout, child_stderr)
+    _reopen_std_io_handles(requests, child_stdout, child_stderr)
 
     def exit(n: int) -> NoReturn:
         with suppress(ValueError, OSError):
@@ -359,11 +366,11 @@ def _fork_main(
             last_chance_stderr.flush()
 
         # Explicitly close the child-end of our supervisor sockets so
-        # the parent sees EOF on both "requests" and "logs" channels.
+        # the parent sees EOF on the "logs" channel.
         with suppress(OSError):
             os.close(log_fd)
         with suppress(OSError):
-            os.close(child_stdin.fileno())
+            os.close(requests.fileno())
         os._exit(n)
 
     if hasattr(atexit, "_clear"):
@@ -431,16 +438,18 @@ class WatchedSubprocess:
     """The decoder to use for incoming messages from the child process."""
 
     _process: psutil.Process = attrs.field(repr=False)
-    _requests_fd: int
     """File descriptor for request handling."""
 
-    _num_open_sockets: int = 4
     _exit_code: int | None = attrs.field(default=None, init=False)
     _process_exit_monotonic: float | None = attrs.field(default=None, init=False)
-    _fd_to_socket_type: dict[int, str] = attrs.field(factory=dict, init=False)
+    _open_sockets: weakref.WeakKeyDictionary[socket, str] = attrs.field(
+        factory=weakref.WeakKeyDictionary, init=False
+    )
 
     selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector, repr=False)
 
+    _frame_encoder: msgspec.msgpack.Encoder = attrs.field(factory=comms._new_encoder, repr=False)
+
     process_log: FilteringBoundLogger = attrs.field(repr=False)
 
     subprocess_logs_to_stdout: bool = False
@@ -459,18 +468,19 @@ class WatchedSubprocess:
     ) -> Self:
         """Fork and start a new subprocess with the specified target function."""
         # Create socketpairs/"pipes" to connect to the stdin and out from the subprocess
-        child_stdin, feed_stdin = mkpipe(remote_read=True)
-        child_stdout, read_stdout = mkpipe()
-        child_stderr, read_stderr = mkpipe()
+        child_stdout, read_stdout = socketpair()
+        child_stderr, read_stderr = socketpair()
+
+        # Place for child to send requests/read responses, and the server side to read/respond
+        child_requests, read_requests = socketpair()
 
-        # Open these socketpair before forking off the child, so that it is open when we fork.
-        child_comms, read_msgs = mkpipe()
-        child_logs, read_logs = mkpipe()
+        # Open the socketpair before forking off the child, so that it is open when we fork.
+        child_logs, read_logs = socketpair()
 
         pid = os.fork()
         if pid == 0:
             # Close and delete of the parent end of the sockets.
-            cls._close_unused_sockets(feed_stdin, read_stdout, read_stderr, read_msgs, read_logs)
+            cls._close_unused_sockets(read_requests, read_stdout, read_stderr, read_logs)
 
             # Python GC should delete these for us, but lets make double sure that we don't keep anything
             # around in the forked processes, especially things that might involve open files or sockets!
@@ -479,28 +489,28 @@ class WatchedSubprocess:
 
             try:
                 # Run the child entrypoint
-                _fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target)
+                _fork_main(child_requests, child_stdout, child_stderr, child_logs.fileno(), target)
             except BaseException as e:
+                import traceback
+
                 with suppress(BaseException):
                     # We can't use log here, as if we except out of _fork_main something _weird_ went on.
-                    print("Exception in _fork_main, exiting with code 124", e, file=sys.stderr)
+                    print("Exception in _fork_main, exiting with code 124", file=sys.stderr)
+                    traceback.print_exception(type(e), e, e.__traceback__, file=sys.stderr)
 
             # It's really super super important we never exit this block. We are in the forked child, and if we
             # do then _THINGS GET WEIRD_.. (Normally `_fork_main` itself will `_exit()` so we never get here)
             os._exit(124)
 
-        requests_fd = child_comms.fileno()
-
         # Close the remaining parent-end of the sockets we've passed to the child via fork. We still have the
         # other end of the pair open
-        cls._close_unused_sockets(child_stdin, child_stdout, child_stderr, child_comms, child_logs)
+        cls._close_unused_sockets(child_stdout, child_stderr, child_logs)
 
         logger = logger or cast("FilteringBoundLogger", structlog.get_logger(logger_name="task").bind())
         proc = cls(
             pid=pid,
-            stdin=feed_stdin,
+            stdin=read_requests,
             process=psutil.Process(pid),
-            requests_fd=requests_fd,
             process_log=logger,
             start_time=time.monotonic(),
             **constructor_kwargs,
@@ -509,7 +519,7 @@ class WatchedSubprocess:
         proc._register_pipe_readers(
             stdout=read_stdout,
             stderr=read_stderr,
-            requests=read_msgs,
+            requests=read_requests,
             logs=read_logs,
         )
 
@@ -522,24 +532,26 @@ class WatchedSubprocess:
         # alternatives are used automatically) -- this is a way of having "event-based" code, but without
         # needing full async, to read and process output from each socket as it is received.
 
-        # Track socket types for debugging
-        self._fd_to_socket_type = {
-            stdout.fileno(): "stdout",
-            stderr.fileno(): "stderr",
-            requests.fileno(): "requests",
-            logs.fileno(): "logs",
-        }
+        # Track the open sockets, and for debugging what type each one is
+        self._open_sockets.update(
+            (
+                (stdout, "stdout"),
+                (stderr, "stderr"),
+                (logs, "logs"),
+                (requests, "requests"),
+            )
+        )
 
         target_loggers: tuple[FilteringBoundLogger, ...] = (self.process_log,)
         if self.subprocess_logs_to_stdout:
             target_loggers += (log,)
         self.selector.register(
-            stdout, selectors.EVENT_READ, self._create_socket_handler(target_loggers, channel="stdout")
+            stdout, selectors.EVENT_READ, self._create_log_forwarder(target_loggers, channel="stdout")
         )
         self.selector.register(
             stderr,
             selectors.EVENT_READ,
-            self._create_socket_handler(target_loggers, channel="stderr", log_level=logging.ERROR),
+            self._create_log_forwarder(target_loggers, channel="stderr", log_level=logging.ERROR),
         )
         self.selector.register(
             logs,
@@ -551,37 +563,52 @@ class WatchedSubprocess:
         self.selector.register(
             requests,
             selectors.EVENT_READ,
-            make_buffered_socket_reader(self.handle_requests(log), on_close=self._on_socket_closed),
+            length_prefixed_frame_reader(self.handle_requests(log), on_close=self._on_socket_closed),
         )
 
-    def _create_socket_handler(self, loggers, channel, log_level=logging.INFO) -> Callable[[socket], bool]:
+    def _create_log_forwarder(self, loggers, channel, log_level=logging.INFO) -> Callable[[socket], bool]:
         """Create a socket handler that forwards logs to a logger."""
         return make_buffered_socket_reader(
             forward_to_log(loggers, chan=channel, level=log_level), on_close=self._on_socket_closed
         )
 
-    def _on_socket_closed(self):
+    def _on_socket_closed(self, sock: socket):
         # We want to keep servicing this process until we've read up to EOF from all the sockets.
-        self._num_open_sockets -= 1
+        self._open_sockets.pop(sock, None)
 
-    def send_msg(self, msg: BaseModel, **dump_opts):
-        """Send the given pydantic message to the subprocess at once by encoding it and adding a line break."""
-        b = msg.model_dump_json(**dump_opts).encode() + b"\n"
-        self.stdin.sendall(b)
+    def send_msg(
+        self, msg: BaseModel | None, in_response_to: int, error: ErrorResponse | None = None, **dump_opts
+    ):
+        """Send the msg as a length-prefixed response frame."""
+        # https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing for inspiration
+        if msg:
+            frame = _ResponseFrame(id=in_response_to, body=msg.model_dump(**dump_opts))
+        else:
+            err_resp = error.model_dump() if error else None
+            frame = _ResponseFrame(id=in_response_to, error=err_resp)
+        buffer = bytearray(256)
+
+        self._frame_encoder.encode_into(frame, buffer, 4)
+        n = len(buffer) - 4
+        if n >= 2**32:
+            raise OverflowError("Cannot send messages larger than 4GiB")
+        buffer[:4] = n.to_bytes(4, byteorder="big")
+
+        self.stdin.sendall(buffer)
 
-    def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, None]:
+    def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, _RequestFrame, None]:
         """Handle incoming requests from the task process, respond with the appropriate data."""
         while True:
-            line = yield
+            request = yield
 
             try:
-                msg = self.decoder.validate_json(line)
+                msg = self.decoder.validate_python(request.body)
             except Exception:
-                log.exception("Unable to decode message", line=line)
+                log.exception("Unable to decode message", body=request.body)
                 continue
 
             try:
-                self._handle_request(msg, log)
+                self._handle_request(msg, log, request.id)
             except ServerResponseError as e:
                 error_details = e.response.json() if e.response else None
                 log.error(
@@ -593,27 +620,25 @@ class WatchedSubprocess:
 
                 # Send error response back to task so that the error appears in the task logs
                 self.send_msg(
-                    ErrorResponse(
+                    msg=None,
+                    error=ErrorResponse(
                         error=ErrorType.API_SERVER_ERROR,
                         detail={
                             "status_code": e.response.status_code,
                             "message": str(e),
                             "detail": error_details,
                         },
-                    )
+                    ),
+                    in_response_to=request.id,
                 )
 
-    def _handle_request(self, msg, log: FilteringBoundLogger) -> None:
+    def _handle_request(self, msg, log: FilteringBoundLogger, req_id: int) -> None:
         raise NotImplementedError()
 
     @staticmethod
     def _close_unused_sockets(*sockets):
         """Close unused ends of sockets after fork."""
         for sock in sockets:
-            if isinstance(sock, SocketIO):
-                # If we have the socket IO object, we need to close the underlying socket forcibly here too,
-                # else we get unclosed socket warnings, and likely leaking FDs too
-                sock._sock.close()
             sock.close()
 
     def _cleanup_open_sockets(self):
@@ -623,20 +648,18 @@ class WatchedSubprocess:
         # sockets the supervisor would wait forever thinking they are still
         # active. This cleanup ensures we always release resources and exit.
         stuck_sockets = []
-        for key in list(self.selector.get_map().values()):
-            socket_type = self._fd_to_socket_type.get(key.fd, f"unknown-{key.fd}")
-            stuck_sockets.append(f"{socket_type}({key.fd})")
+        for sock, socket_type in self._open_sockets.items():
+            fileno = "unknown"
             with suppress(Exception):
-                self.selector.unregister(key.fileobj)
-            with suppress(Exception):
-                key.fileobj.close()  # type: ignore[union-attr]
+                fileno = sock.fileno()
+                sock.close()
+            stuck_sockets.append(f"{socket_type}(fd={fileno})")
 
         if stuck_sockets:
             log.warning("Force-closed stuck sockets", pid=self.pid, sockets=stuck_sockets)
 
         self.selector.close()
-        self._close_unused_sockets(self.stdin)
-        self._num_open_sockets = 0
+        self.stdin.close()
 
     def kill(
         self,
@@ -753,7 +776,9 @@ class WatchedSubprocess:
             # are removed.
             if not need_more:
                 self.selector.unregister(key.fileobj)
-                key.fileobj.close()  # type: ignore[union-attr]
+                sock: socket = key.fileobj  # type: ignore[assignment]
+                sock.close()
+                self._on_socket_closed(sock)
 
         # Check if the subprocess has exited
         return self._check_subprocess_exit(raise_on_timeout=raise_on_timeout, expect_signal=expect_signal)
@@ -860,7 +885,6 @@ class ActivitySubprocess(WatchedSubprocess):
             ti=ti,
             dag_rel_path=os.fspath(dag_rel_path),
             bundle_info=bundle_info,
-            requests_fd=self._requests_fd,
             ti_context=ti_context,
             start_date=start_date,
         )
@@ -869,7 +893,7 @@ class ActivitySubprocess(WatchedSubprocess):
         log.debug("Sending", msg=msg)
 
         try:
-            self.send_msg(msg)
+            self.send_msg(msg, in_response_to=0)
         except BrokenPipeError:
             # Debug is fine, the process will have shown _something_ in its last_chance exception handler
             log.debug("Couldn't send startup message to Subprocess - it died very early", pid=self.pid)
@@ -929,7 +953,7 @@ class ActivitySubprocess(WatchedSubprocess):
         - Processes events triggered on the monitored file objects, such as data availability or EOF.
         - Sends heartbeats to ensure the process is alive and checks if the subprocess has exited.
         """
-        while self._exit_code is None or self._num_open_sockets > 0:
+        while self._exit_code is None or self._open_sockets:
             last_heartbeat_ago = time.monotonic() - self._last_successful_heartbeat
             # Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible
             # so we notice the subprocess finishing as quick as we can.
@@ -945,16 +969,11 @@ class ActivitySubprocess(WatchedSubprocess):
             # This listens for activity (e.g., subprocess output) on registered file objects
             alive = self._service_subprocess(max_wait_time=max_wait_time) is None
 
-            if self._exit_code is not None and self._num_open_sockets > 0:
+            if self._exit_code is not None and self._open_sockets:
                 if (
                     self._process_exit_monotonic
                     and time.monotonic() - self._process_exit_monotonic > SOCKET_CLEANUP_TIMEOUT
                 ):
-                    log.debug(
-                        "Forcefully closing remaining sockets",
-                        open_sockets=self._num_open_sockets,
-                        pid=self.pid,
-                    )
                     self._cleanup_open_sockets()
 
             if alive:
@@ -1050,7 +1069,7 @@ class ActivitySubprocess(WatchedSubprocess):
             return SERVER_TERMINATED
         return TaskInstanceState.FAILED
 
-    def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
+    def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: int):
         log.debug("Received message from task runner", msg=msg)
         resp: BaseModel | None = None
         dump_opts = {}
@@ -1223,10 +1242,21 @@ class ActivitySubprocess(WatchedSubprocess):
             dump_opts = {"exclude_unset": True}
         else:
             log.error("Unhandled request", msg=msg)
+            self.send_msg(
+                None,
+                in_response_to=req_id,
+                error=ErrorResponse(
+                    error=ErrorType.API_SERVER_ERROR,
+                    detail={"status_code": 400, "message": "Unhandled request"},
+                ),
+            )
             return
 
         if resp:
-            self.send_msg(resp, **dump_opts)
+            self.send_msg(resp, in_response_to=req_id, error=None, **dump_opts)
+        else:
+            # Send an empty response frame (which signifies no error) if we don't have anything else to say
+            self.send_msg(None, in_response_to=req_id)
 
 
 def in_process_api_server():
@@ -1253,7 +1283,7 @@ class InProcessSupervisorComms:
         log.debug("Sending request", msg=msg)
 
         with set_supervisor_comms(None):
-            self.supervisor._handle_request(msg, log)  # type: ignore[arg-type]
+            self.supervisor._handle_request(msg, log, 0)  # type: ignore[arg-type]
 
 
 @attrs.define
@@ -1297,7 +1327,6 @@ class InProcessTestSupervisor(ActivitySubprocess):
             id=what.id,
             pid=os.getpid(),  # Use current process
             process=psutil.Process(),  # Current process
-            requests_fd=-1,  # Not used in in-process mode
             process_log=logger or structlog.get_logger(logger_name="task").bind(),
             client=cls._api_client(task.dag),
             **kwargs,
@@ -1362,9 +1391,12 @@ class InProcessTestSupervisor(ActivitySubprocess):
         client.base_url = "http://in-process.invalid./"  # type: ignore[assignment]
         return client
 
-    def send_msg(self, msg: BaseModel, **dump_opts):
+    def send_msg(
+        self, msg: BaseModel | None, in_response_to: int, error: ErrorResponse | None = None, **dump_opts
+    ):
         """Override to use in-process comms."""
-        self.comms.messages.append(msg)
+        if msg is not None:
+            self.comms.messages.append(msg)
 
     @property
     def final_state(self):
@@ -1420,7 +1452,7 @@ def run_task_in_process(ti: TaskInstance, task) -> TaskRunResult:
 # to a (sync) generator
 def make_buffered_socket_reader(
     gen: Generator[None, bytes | bytearray, None],
-    on_close: Callable,
+    on_close: Callable[[socket], None],
     buffer_size: int = 4096,
 ) -> Callable[[socket], bool]:
     buffer = bytearray()  # This will hold our accumulated binary data
@@ -1440,7 +1472,7 @@ def make_buffered_socket_reader(
                 with suppress(StopIteration):
                     gen.send(buffer)
             # Tell loop to close this selector
-            on_close()
+            on_close(sock)
             return False
 
         buffer.extend(read_buffer[:n_received])
@@ -1451,7 +1483,7 @@ def make_buffered_socket_reader(
             try:
                 gen.send(line)
             except StopIteration:
-                on_close()
+                on_close(sock)
                 return False
             buffer = buffer[newline_pos + 1 :]  # Update the buffer with remaining data
 
@@ -1460,6 +1492,56 @@ def make_buffered_socket_reader(
     return cb
 
 
+def length_prefixed_frame_reader(
+    gen: Generator[None, _RequestFrame, None], on_close: Callable[[socket], 
None]
+):
+    length_needed: int | None = None
+    # This will hold our accumulated/partial binary frame if it doesn't come 
in a single read
+    buffer: memoryview | None = None
+    # position in the buffer to store next read
+    pos = 0
+    decoder = msgspec.msgpack.Decoder[_RequestFrame](_RequestFrame)
+
+    # We need to start up the generator to get it to the point it's at waiting 
on the yield
+    next(gen)
+
+    def cb(sock: socket):
+        print("Main: length_prefixed_frame_reader.cb fired")
+        nonlocal buffer, length_needed, pos
+        # Read up to `buffer_size` bytes of data from the socket
+
+        if length_needed is None:
+            # Read the 32bit length of the frame
+            bytes = sock.recv(4)
+            if bytes == b"":
+                on_close(sock)
+                return False
+
+            length_needed = int.from_bytes(bytes, byteorder="big")
+            buffer = memoryview(bytearray(length_needed))
+        if length_needed and buffer:
+            n = sock.recv_into(buffer[pos:])
+            if n == 0:
+                # EOF
+                on_close(sock)
+                return False
+            pos += n
+
+            if pos >= length_needed:
+                request = decoder.decode(buffer)
+                buffer = None
+                pos = 0
+                length_needed = None
+                try:
+                    gen.send(request)
+                except StopIteration:
+                    on_close(sock)
+                    return False
+        return True
+
+    return cb
+
+
 def process_log_messages_from_subprocess(
     loggers: tuple[FilteringBoundLogger, ...],
 ) -> Generator[None, bytes, None]:
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index bbc7394a87a..199544ae835 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -28,16 +28,14 @@ import time
 from collections.abc import Callable, Iterable, Iterator, Mapping
 from contextlib import suppress
 from datetime import datetime, timezone
-from io import FileIO
 from itertools import product
 from pathlib import Path
-from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TextIO, TypeVar
+from typing import TYPE_CHECKING, Annotated, Any, Literal
 
-import aiologic
 import attrs
 import lazy_object_proxy
 import structlog
-from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, TypeAdapter
+from pydantic import AwareDatetime, ConfigDict, Field, JsonValue
 
 from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock
 from airflow.dag_processing.bundles.manager import DagBundlesManager
@@ -59,6 +57,7 @@ from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
 from airflow.sdk.execution_time.callback_runner import create_executable_runner
 from airflow.sdk.execution_time.comms import (
     AssetEventDagRunReferenceResult,
+    CommsDecoder,
     DagRunStateResult,
     DeferTask,
     DRCount,
@@ -412,10 +411,9 @@ class RuntimeTaskInstance(TaskInstance):
 
         log.debug("Requesting first reschedule date from supervisor")
 
-        SUPERVISOR_COMMS.send_request(
-            log=log, msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number)
+        response = SUPERVISOR_COMMS.send(
+            msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number)
         )
-        response = SUPERVISOR_COMMS.get_message()
 
         if TYPE_CHECKING:
             assert isinstance(response, TaskRescheduleStartDate)
@@ -433,12 +431,9 @@ class RuntimeTaskInstance(TaskInstance):
         states: list[str] | None = None,
     ) -> int:
         """Return the number of task instances matching the given criteria."""
-        log = structlog.get_logger(logger_name="task")
-
         with SUPERVISOR_COMMS.lock:
-            SUPERVISOR_COMMS.send_request(
-                log=log,
-                msg=GetTICount(
+            response = SUPERVISOR_COMMS.send(
+                GetTICount(
                     dag_id=dag_id,
                     map_index=map_index,
                     task_ids=task_ids,
@@ -448,7 +443,6 @@ class RuntimeTaskInstance(TaskInstance):
                     states=states,
                 ),
             )
-            response = SUPERVISOR_COMMS.get_message()
 
         if TYPE_CHECKING:
             assert isinstance(response, TICount)
@@ -465,12 +459,9 @@ class RuntimeTaskInstance(TaskInstance):
         run_ids: list[str] | None = None,
     ) -> dict[str, Any]:
         """Return the task states matching the given criteria."""
-        log = structlog.get_logger(logger_name="task")
-
         with SUPERVISOR_COMMS.lock:
-            SUPERVISOR_COMMS.send_request(
-                log=log,
-                msg=GetTaskStates(
+            response = SUPERVISOR_COMMS.send(
+                GetTaskStates(
                     dag_id=dag_id,
                     map_index=map_index,
                     task_ids=task_ids,
@@ -479,7 +470,6 @@ class RuntimeTaskInstance(TaskInstance):
                     run_ids=run_ids,
                 ),
             )
-            response = SUPERVISOR_COMMS.get_message()
 
         if TYPE_CHECKING:
             assert isinstance(response, TaskStatesResult)
@@ -494,19 +484,15 @@ class RuntimeTaskInstance(TaskInstance):
         states: list[str] | None = None,
     ) -> int:
         """Return the number of DAG runs matching the given criteria."""
-        log = structlog.get_logger(logger_name="task")
-
         with SUPERVISOR_COMMS.lock:
-            SUPERVISOR_COMMS.send_request(
-                log=log,
-                msg=GetDRCount(
+            response = SUPERVISOR_COMMS.send(
+                GetDRCount(
                     dag_id=dag_id,
                     logical_dates=logical_dates,
                     run_ids=run_ids,
                     states=states,
                 ),
             )
-            response = SUPERVISOR_COMMS.get_message()
 
         if TYPE_CHECKING:
             assert isinstance(response, DRCount)
@@ -516,10 +502,8 @@ class RuntimeTaskInstance(TaskInstance):
     @staticmethod
     def get_dagrun_state(dag_id: str, run_id: str) -> str:
         """Return the state of the DAG run with the given Run ID."""
-        log = structlog.get_logger(logger_name="task")
         with SUPERVISOR_COMMS.lock:
-            SUPERVISOR_COMMS.send_request(log=log, msg=GetDagRunState(dag_id=dag_id, run_id=run_id))
-            response = SUPERVISOR_COMMS.get_message()
+            response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id))
 
         if TYPE_CHECKING:
             assert isinstance(response, DagRunStateResult)
@@ -638,62 +622,6 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
     )
 
 
-SendMsgType = TypeVar("SendMsgType", bound=BaseModel)
-ReceiveMsgType = TypeVar("ReceiveMsgType", bound=BaseModel)
-
-
-@attrs.define()
-class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
-    """Handle communication between the task in this process and the 
supervisor parent process."""
-
-    input: TextIO
-
-    request_socket: FileIO = attrs.field(init=False, default=None)
-
-    # We could be "clever" here and set the default to this based type 
parameters and a custom
-    # `__class_getitem__`, but that's a lot of code the one subclass we've got 
currently. So we'll just use a
-    # "sort of wrong default"
-    decoder: TypeAdapter[ReceiveMsgType] = attrs.field(factory=lambda: 
TypeAdapter(ToTask), repr=False)
-
-    lock: aiologic.Lock = attrs.field(factory=aiologic.Lock, repr=False)
-
-    def get_message(self) -> ReceiveMsgType:
-        """
-        Get a message from the parent.
-
-        This will block until the message has been received.
-        """
-        line = None
-
-        # TODO: Investigate why some empty lines are sent to the processes stdin.
-        #   That was highlighted when working on https://github.com/apache/airflow/issues/48183
-        #   and is maybe related to deferred/triggerer only context.
-        while not line:
-            line = self.input.readline()
-
-        try:
-            msg = self.decoder.validate_json(line)
-        except Exception:
-            structlog.get_logger(logger_name="CommsDecoder").exception("Unable 
to decode message", line=line)
-            raise
-
-        if isinstance(msg, StartupDetails):
-            # If we read a startup message, pull out the FDs we care about!
-            if msg.requests_fd > 0:
-                self.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0)
-        elif isinstance(msg, ErrorResponse) and msg.error == ErrorType.API_SERVER_ERROR:
-            structlog.get_logger(logger_name="task").error("Error response 
from the API Server")
-            raise AirflowRuntimeError(error=msg)
-
-        return msg
-
-    def send_request(self, log: Logger, msg: SendMsgType):
-        encoded_msg = msg.model_dump_json().encode() + b"\n"
-
-        log.debug("Sending request", json=encoded_msg)
-        self.request_socket.write(encoded_msg)
-
-
 # This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's execution,
 # to send requests back to the supervisor process.
 #
@@ -713,31 +641,33 @@ SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor]
 
 
 def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
-    msg = SUPERVISOR_COMMS.get_message()
+    # The parent sends us a StartupDetails message unprompted. After this, every single message is only sent
+    # in response to us sending a request.
+    msg = SUPERVISOR_COMMS._get_response()
+
+    if not isinstance(msg, StartupDetails):
+        raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")
 
     log = structlog.get_logger(logger_name="task")
 
+    # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021
+    os_type = sys.platform
+    if os_type == "darwin":
+        log.debug("Mac OS detected, skipping setproctitle")
+    else:
+        from setproctitle import setproctitle
+
+        setproctitle(f"airflow worker -- {msg.ti.id}")
+
     try:
         get_listener_manager().hook.on_starting(component=TaskRunnerMarker())
     except Exception:
         log.exception("error calling listener")
 
-    if isinstance(msg, StartupDetails):
-        # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021
-        os_type = sys.platform
-        if os_type == "darwin":
-            log.debug("Mac OS detected, skipping setproctitle")
-        else:
-            from setproctitle import setproctitle
-
-            setproctitle(f"airflow worker -- {msg.ti.id}")
-
-        with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id):
-            ti = parse(msg, log)
-            ti.log_url = get_log_url_from_ti(ti)
-        log.debug("DAG file parsed", file=msg.dag_rel_path)
-    else:
-        raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")
+    with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id):
+        ti = parse(msg, log)
+        ti.log_url = get_log_url_from_ti(ti)
+    log.debug("DAG file parsed", file=msg.dag_rel_path)
 
     return ti, ti.get_template_context(), log
 
@@ -785,7 +715,7 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv
 
     if rendered_fields := _serialize_rendered_fields(ti.task):
         # so that we do not call the API unnecessarily
-        SUPERVISOR_COMMS.send_request(log=log, msg=SetRenderedFields(rendered_fields=rendered_fields))
+        SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields))
 
     _validate_task_inlets_and_outlets(ti=ti, log=log)
 
@@ -805,8 +735,7 @@ def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) -
     if not ti.task.inlets and not ti.task.outlets:
         return
 
-    SUPERVISOR_COMMS.send_request(msg=ValidateInletsAndOutlets(ti_id=ti.id), log=log)
-    inactive_assets_resp = SUPERVISOR_COMMS.get_message()
+    inactive_assets_resp = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id))
     if TYPE_CHECKING:
         assert isinstance(inactive_assets_resp, InactiveAssetsResult)
     if inactive_assets := inactive_assets_resp.inactive_assets:
@@ -902,7 +831,7 @@ def run(
     except DownstreamTasksSkipped as skip:
         log.info("Skipping downstream tasks.")
         tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks]
-        SUPERVISOR_COMMS.send_request(log=log, msg=SkipDownstreamTasks(tasks=tasks_to_skip))
+        SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip))
         msg, state = _handle_current_task_success(context, ti)
     except DagRunTriggerException as drte:
         msg, state = _handle_trigger_dag_run(drte, context, ti, log)
@@ -964,7 +893,7 @@ def run(
         error = e
     finally:
         if msg:
-            SUPERVISOR_COMMS.send_request(msg=msg, log=log)
+            SUPERVISOR_COMMS.send(msg=msg)
 
     # Return the message to make unit tests easier too
     ti.state = state
@@ -1002,9 +931,8 @@ def _handle_trigger_dag_run(
 ) -> tuple[ToSupervisor, TaskInstanceState]:
     """Handle exception from TriggerDagRunOperator."""
     log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id)
-    SUPERVISOR_COMMS.send_request(
-        log=log,
-        msg=TriggerDagRun(
+    comms_msg = SUPERVISOR_COMMS.send(
+        TriggerDagRun(
             dag_id=drte.trigger_dag_id,
             run_id=drte.dag_run_id,
             logical_date=drte.logical_date,
@@ -1013,7 +941,6 @@ def _handle_trigger_dag_run(
         ),
     )
 
-    comms_msg = SUPERVISOR_COMMS.get_message()
     if isinstance(comms_msg, ErrorResponse) and comms_msg.error == ErrorType.DAGRUN_ALREADY_EXISTS:
         if drte.skip_when_already_exists:
             log.info(
@@ -1068,10 +995,9 @@ def _handle_trigger_dag_run(
             )
             time.sleep(drte.poke_interval)
 
-            SUPERVISOR_COMMS.send_request(
-                log=log, msg=GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id)
+            comms_msg = SUPERVISOR_COMMS.send(
+                GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id)
             )
-            comms_msg = SUPERVISOR_COMMS.get_message()
             if TYPE_CHECKING:
                 assert isinstance(comms_msg, DagRunStateResult)
             if comms_msg.state in drte.failed_states:
@@ -1260,10 +1186,7 @@ def finalize(
     if getattr(ti.task, "overwrite_rtif_after_execution", False):
         log.debug("Overwriting Rendered template fields.")
         if ti.task.template_fields:
-            SUPERVISOR_COMMS.send_request(
-                log=log,
-                msg=SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task)),
-            )
+            SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task)))
 
     log.debug("Running finalizers", ti=ti)
     if state == TaskInstanceState.SUCCESS:
@@ -1304,9 +1227,11 @@ def finalize(
 
 
 def main():
-    # TODO: add an exception here, it causes an oof of a stack trace!
+    # TODO: add an exception here, it causes an oof of a stack trace if it happens too early!
+    log = structlog.get_logger(logger_name="task")
+
     global SUPERVISOR_COMMS
-    SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](input=sys.stdin)
+    SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log)
 
     try:
         ti, context, log = startup()
@@ -1317,11 +1242,9 @@ def main():
             state, msg, error = run(ti, context, log)
             finalize(ti, state, context, log, error)
     except KeyboardInterrupt:
-        log = structlog.get_logger(logger_name="task")
         log.exception("Ctrl-c hit")
         exit(2)
     except Exception:
-        log = structlog.get_logger(logger_name="task")
         log.exception("Top level error")
         exit(1)
     finally:
diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py
new file mode 100644
index 00000000000..ee2f956b063
--- /dev/null
+++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import uuid
+from socket import socketpair
+
+import msgspec
+import pytest
+
+from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails, _ResponseFrame
+from airflow.sdk.execution_time.task_runner import CommsDecoder
+from airflow.utils import timezone
+
+
+class TestCommsDecoder:
+    """Test the communication between the subprocess and the "supervisor"."""
+
+    @pytest.mark.usefixtures("disable_capturing")
+    def test_recv_StartupDetails(self):
+        r, w = socketpair()
+
+        msg = {
+            "type": "StartupDetails",
+            "ti": {
+                "id": uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab"),
+                "task_id": "a",
+                "try_number": 1,
+                "run_id": "b",
+                "dag_id": "c",
+            },
+            "ti_context": {
+                "dag_run": {
+                    "dag_id": "c",
+                    "run_id": "b",
+                    "logical_date": "2024-12-01T01:00:00Z",
+                    "data_interval_start": "2024-12-01T00:00:00Z",
+                    "data_interval_end": "2024-12-01T01:00:00Z",
+                    "start_date": "2024-12-01T01:00:00Z",
+                    "run_after": "2024-12-01T01:00:00Z",
+                    "end_date": None,
+                    "run_type": "manual",
+                    "conf": None,
+                    "consumed_asset_events": [],
+                },
+                "max_tries": 0,
+                "should_retry": False,
+                "variables": None,
+                "connections": None,
+            },
+            "file": "/dev/null",
+            "start_date": "2024-12-01T01:00:00Z",
+            "dag_rel_path": "/dev/null",
+            "bundle_info": {"name": "any-name", "version": "any-version"},
+        }
+        raw = msgspec.msgpack.encode(_ResponseFrame(0, msg, None))
+        w.sendall(len(raw).to_bytes(4, byteorder="big") + raw)
+
+        decoder = CommsDecoder(request_socket=r, log=None)
+
+        msg = decoder._get_response()
+        assert isinstance(msg, StartupDetails)
+        assert msg.ti.id == uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab")
+        assert msg.ti.task_id == "a"
+        assert msg.ti.dag_id == "c"
+        assert msg.dag_rel_path == "/dev/null"
+        assert msg.bundle_info == BundleInfo(name="any-name", 
version="any-version")
+        assert msg.start_date == timezone.datetime(2024, 12, 1, 1)
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 4f5e4cfc7ac..3fa05e7d757 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -27,13 +27,14 @@ import signal
 import socket
 import sys
 import time
-from io import BytesIO
 from operator import attrgetter
+from random import randint
 from time import sleep
 from typing import TYPE_CHECKING
 from unittest.mock import MagicMock, patch
 
 import httpx
+import msgspec
 import psutil
 import pytest
 from pytest_unordered import unordered
@@ -56,6 +57,7 @@ from airflow.sdk.execution_time import task_runner
 from airflow.sdk.execution_time.comms import (
     AssetEventsResult,
     AssetResult,
+    CommsDecoder,
     ConnectionResult,
     DagRunStateResult,
     DeferTask,
@@ -97,17 +99,16 @@ from airflow.sdk.execution_time.comms import (
     XComResult,
     XComSequenceIndexResult,
     XComSequenceSliceResult,
+    _RequestFrame,
+    _ResponseFrame,
 )
 from airflow.sdk.execution_time.supervisor import (
-    BUFFER_SIZE,
     ActivitySubprocess,
     InProcessSupervisorComms,
     InProcessTestSupervisor,
-    mkpipe,
     set_supervisor_comms,
     supervise,
 )
-from airflow.sdk.execution_time.task_runner import CommsDecoder
 from airflow.utils import timezone, timezone as tz
 
 if TYPE_CHECKING:
@@ -136,18 +137,25 @@ def local_dag_bundle_cfg(path, name="my-bundle"):
     }
 
 
+@pytest.fixture
+def client_with_ti_start(make_ti_context):
+    client = MagicMock(spec=sdk_client.Client)
+    client.task_instances.start.return_value = make_ti_context()
+    return client
+
+
 @pytest.mark.usefixtures("disable_capturing")
 class TestWatchedSubprocess:
     @pytest.fixture(autouse=True)
     def disable_log_upload(self, spy_agency):
         spy_agency.spy_on(ActivitySubprocess._upload_logs, call_original=False)
 
-    def test_reading_from_pipes(self, captured_logs, time_machine):
+    def test_reading_from_pipes(self, captured_logs, time_machine, client_with_ti_start):
         def subprocess_main():
             # This is run in the subprocess!
 
-            # Ensure we follow the "protocol" and get the startup message 
before we do anything
-            sys.stdin.readline()
+            # Ensure we follow the "protocol" and get the startup message 
before we do anything else
+            CommsDecoder()._get_response()
 
             import logging
             import warnings
@@ -180,7 +188,7 @@ class TestWatchedSubprocess:
                 run_id="d",
                 try_number=1,
             ),
-            client=MagicMock(spec=sdk_client.Client),
+            client=client_with_ti_start,
             target=subprocess_main,
         )
 
@@ -228,12 +236,12 @@ class TestWatchedSubprocess:
             ]
         )
 
-    def test_subprocess_sigkilled(self):
+    def test_subprocess_sigkilled(self, client_with_ti_start):
         main_pid = os.getpid()
 
         def subprocess_main():
             # Ensure we follow the "protocol" and get the startup message 
before we do anything
-            sys.stdin.readline()
+            CommsDecoder()._get_response()
 
             assert os.getpid() != main_pid
             os.kill(os.getpid(), signal.SIGKILL)
@@ -248,7 +256,7 @@ class TestWatchedSubprocess:
                 run_id="d",
                 try_number=1,
             ),
-            client=MagicMock(spec=sdk_client.Client),
+            client=client_with_ti_start,
             target=subprocess_main,
         )
 
@@ -285,7 +293,7 @@ class TestWatchedSubprocess:
         monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1)
 
         def subprocess_main():
-            sys.stdin.readline()
+            CommsDecoder()._get_response()
 
             for _ in range(5):
                 print("output", flush=True)
@@ -314,7 +322,7 @@ class TestWatchedSubprocess:
         monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.1)
 
         def subprocess_main():
-            sys.stdin.readline()
+            CommsDecoder()._get_response()
 
             for _ in range(5):
                 print("output", flush=True)
@@ -340,7 +348,7 @@ class TestWatchedSubprocess:
         assert proc.wait() == 0
         spy_agency.assert_spy_not_called(heartbeat_spy)
 
-    def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker, make_ti_context):
+    def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker, client_with_ti_start):
         """Test running a simple DAG in a subprocess and capturing the 
output."""
 
         instant = tz.datetime(2024, 11, 7, 12, 34, 56, 78901)
@@ -355,11 +363,6 @@ class TestWatchedSubprocess:
             try_number=1,
         )
 
-        # Create a mock client to assert calls to the client
-        # We assume the implementation of the client is correct and only need to check the calls
-        mock_client = mocker.Mock(spec=sdk_client.Client)
-        mock_client.task_instances.start.return_value = make_ti_context()
-
         bundle_info = BundleInfo(name="my-bundle", version=None)
         with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)):
             exit_code = supervise(
@@ -368,7 +371,7 @@ class TestWatchedSubprocess:
                 token="",
                 server="",
                 dry_run=True,
-                client=mock_client,
+                client=client_with_ti_start,
                 bundle_info=bundle_info,
             )
             assert exit_code == 0, captured_logs
@@ -498,7 +501,7 @@ class TestWatchedSubprocess:
         monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "MIN_HEARTBEAT_INTERVAL", 0.0)
 
         def subprocess_main():
-            sys.stdin.readline()
+            CommsDecoder()._get_response()
             sleep(5)
             # Shouldn't get here
             exit(5)
@@ -611,7 +614,6 @@ class TestWatchedSubprocess:
             stdin=mocker.MagicMock(),
             client=client,
             process=mock_process,
-            requests_fd=-1,
         )
 
         time_now = tz.datetime(2024, 11, 28, 12, 0, 0)
@@ -701,7 +703,6 @@ class TestWatchedSubprocess:
             stdin=mocker.Mock(),
             process=mocker.Mock(),
             client=mocker.Mock(),
-            requests_fd=-1,
         )
 
         # Set the terminal state and task end datetime
@@ -738,7 +739,7 @@ class TestWatchedSubprocess:
             ),
         ),
     )
-    def test_exit_by_signal(self, monkeypatch, signal_to_raise, log_pattern, cap_structlog):
+    def test_exit_by_signal(self, signal_to_raise, log_pattern, cap_structlog, client_with_ti_start):
         def subprocess_main():
             import faulthandler
             import os
@@ -748,7 +749,7 @@ class TestWatchedSubprocess:
                 faulthandler.disable()
 
             # Ensure we follow the "protocol" and get the startup message 
before we do anything
-            sys.stdin.readline()
+            CommsDecoder()._get_response()
 
             os.kill(os.getpid(), signal_to_raise)
 
@@ -762,7 +763,7 @@ class TestWatchedSubprocess:
                 run_id="d",
                 try_number=1,
             ),
-            client=MagicMock(spec=sdk_client.Client),
+            client=client_with_ti_start,
             target=subprocess_main,
         )
 
@@ -791,26 +792,26 @@ class TestWatchedSubprocess:
             stdin=mocker.MagicMock(),
             client=mocker.MagicMock(),
             process=mock_process,
-            requests_fd=-1,
         )
 
         proc.selector = mocker.MagicMock()
         proc.selector.select.return_value = []
 
         proc._exit_code = 0
-        proc._num_open_sockets = 1
+        # Create a dummy placeholder in the open socket weakref
+        proc._open_sockets[mocker.MagicMock()] = "test placeholder"
         proc._process_exit_monotonic = time.monotonic()
 
         mocker.patch.object(
             ActivitySubprocess,
             "_cleanup_open_sockets",
-            side_effect=lambda: setattr(proc, "_num_open_sockets", 0),
+            side_effect=lambda: setattr(proc, "_open_sockets", {}),
         )
 
         time_machine.shift(2)
 
         proc._monitor_subprocess()
-        assert proc._num_open_sockets == 0
+        assert len(proc._open_sockets) == 0
 
 
 class TestWatchedSubprocessKill:
@@ -829,7 +830,6 @@ class TestWatchedSubprocessKill:
             stdin=mocker.Mock(),
             client=mocker.Mock(),
             process=mock_process,
-            requests_fd=-1,
         )
         # Mock the selector
         mock_selector = mocker.Mock(spec=selectors.DefaultSelector)
@@ -888,7 +888,7 @@ class TestWatchedSubprocessKill:
             ),
         ],
     )
-    def test_kill_escalation_path(self, signal_to_send, exit_after, mocker, captured_logs, monkeypatch):
+    def test_kill_escalation_path(self, signal_to_send, exit_after, captured_logs, client_with_ti_start):
         def subprocess_main():
             import signal
 
@@ -905,7 +905,7 @@ class TestWatchedSubprocessKill:
             signal.signal(signal.SIGINT, _handler)
             signal.signal(signal.SIGTERM, _handler)
             try:
-                sys.stdin.readline()
+                CommsDecoder()._get_response()
                 print("Ready")
                 sleep(10)
             except Exception as e:
@@ -919,7 +919,7 @@ class TestWatchedSubprocessKill:
             dag_rel_path=os.devnull,
             bundle_info=FAKE_BUNDLE,
             what=TaskInstance(id=ti_id, task_id="b", dag_id="c", run_id="d", 
try_number=1),
-            client=MagicMock(spec=sdk_client.Client),
+            client=client_with_ti_start,
             target=subprocess_main,
         )
 
@@ -1073,16 +1073,15 @@ class TestWatchedSubprocessKill:
 class TestHandleRequest:
     @pytest.fixture
     def watched_subprocess(self, mocker):
-        read_end, write_end = mkpipe(remote_read=True)
+        read_end, write_end = socket.socketpair()
 
         subprocess = ActivitySubprocess(
             process_log=mocker.MagicMock(),
             id=TI_ID,
             pid=12345,
-            stdin=write_end,  # this is the writer side
+            stdin=write_end,
             client=mocker.Mock(),
             process=mocker.Mock(),
-            requests_fd=-1,
         )
 
         return subprocess, read_end
@@ -1091,7 +1090,7 @@ class TestHandleRequest:
     @pytest.mark.parametrize(
         [
             "message",
-            "expected_buffer",
+            "expected_body",
             "client_attr_path",
             "method_arg",
             "method_kwarg",
@@ -1101,7 +1100,7 @@ class TestHandleRequest:
         [
             pytest.param(
                 GetConnection(conn_id="test_conn"),
-                b'{"conn_id":"test_conn","conn_type":"mysql","type":"ConnectionResult"}\n',
+                {"conn_id": "test_conn", "conn_type": "mysql", "type": "ConnectionResult"},
                 "connections.get",
                 ("test_conn",),
                 {},
@@ -1111,7 +1110,12 @@ class TestHandleRequest:
             ),
             pytest.param(
                 GetConnection(conn_id="test_conn"),
-                b'{"conn_id":"test_conn","conn_type":"mysql","password":"password","type":"ConnectionResult"}\n',
+                {
+                    "conn_id": "test_conn",
+                    "conn_type": "mysql",
+                    "password": "password",
+                    "type": "ConnectionResult",
+                },
                 "connections.get",
                 ("test_conn",),
                 {},
@@ -1121,7 +1125,7 @@ class TestHandleRequest:
             ),
             pytest.param(
                 GetConnection(conn_id="test_conn"),
-                b'{"conn_id":"test_conn","conn_type":"mysql","schema":"mysql","type":"ConnectionResult"}\n',
+                {"conn_id": "test_conn", "conn_type": "mysql", "schema": "mysql", "type": "ConnectionResult"},
                 "connections.get",
                 ("test_conn",),
                 {},
@@ -1131,7 +1135,7 @@ class TestHandleRequest:
             ),
             pytest.param(
                 GetVariable(key="test_key"),
-                b'{"key":"test_key","value":"test_value","type":"VariableResult"}\n',
+                {"key": "test_key", "value": "test_value", "type": "VariableResult"},
                 "variables.get",
                 ("test_key",),
                 {},
@@ -1141,7 +1145,7 @@ class TestHandleRequest:
             ),
             pytest.param(
                 PutVariable(key="test_key", value="test_value", 
description="test_description"),
-                b"",
+                None,
                 "variables.set",
                 ("test_key", "test_value", "test_description"),
                 {},
@@ -1151,7 +1155,7 @@ class TestHandleRequest:
             ),
             pytest.param(
                 DeleteVariable(key="test_key"),
-                b'{"ok":true,"type":"OKResponse"}\n',
+                {"ok": True, "type": "OKResponse"},
                 "variables.delete",
                 ("test_key",),
                 {},
@@ -1161,7 +1165,7 @@ class TestHandleRequest:
             ),
             pytest.param(
                 DeferTask(next_method="execute_callback", 
classpath="my-classpath"),
-                b"",
+                None,
                 "task_instances.defer",
                 (TI_ID, DeferTask(next_method="execute_callback", 
classpath="my-classpath")),
                 {},
@@ -1174,7 +1178,7 @@ class TestHandleRequest:
                     reschedule_date=timezone.parse("2024-10-31T12:00:00Z"),
                     end_date=timezone.parse("2024-10-31T12:00:00Z"),
                 ),
-                b"",
+                None,
                 "task_instances.reschedule",
                 (
                     TI_ID,
@@ -1190,7 +1194,7 @@ class TestHandleRequest:
             ),
             pytest.param(
                 GetXCom(dag_id="test_dag", run_id="test_run", 
task_id="test_task", key="test_key"),
-                b'{"key":"test_key","value":"test_value","type":"XComResult"}\n',
+                {"key": "test_key", "value": "test_value", "type": "XComResult"},
                 "xcoms.get",
                 ("test_dag", "test_run", "test_task", "test_key", None, False),
                 {},
@@ -1202,7 +1206,7 @@ class TestHandleRequest:
                 GetXCom(
                     dag_id="test_dag", run_id="test_run", task_id="test_task", 
key="test_key", map_index=2
                 ),
-                b'{"key":"test_key","value":"test_value","type":"XComResult"}\n',
+                {"key": "test_key", "value": "test_value", "type": "XComResult"},
                 "xcoms.get",
                 ("test_dag", "test_run", "test_task", "test_key", 2, False),
                 {},
@@ -1212,7 +1216,7 @@ class TestHandleRequest:
             ),
             pytest.param(
                 GetXCom(dag_id="test_dag", run_id="test_run", 
task_id="test_task", key="test_key"),
-                b'{"key":"test_key","value":null,"type":"XComResult"}\n',
+                {"key": "test_key", "value": None, "type": "XComResult"},
                 "xcoms.get",
                 ("test_dag", "test_run", "test_task", "test_key", None, False),
                 {},
@@ -1228,7 +1232,7 @@ class TestHandleRequest:
                     key="test_key",
                     include_prior_dates=True,
                 ),
-                b'{"key":"test_key","value":null,"type":"XComResult"}\n',
+                {"key": "test_key", "value": None, "type": "XComResult"},
                 "xcoms.get",
                 ("test_dag", "test_run", "test_task", "test_key", None, True),
                 {},
@@ -1244,7 +1248,7 @@ class TestHandleRequest:
                     key="test_key",
                     value='{"key": "test_key", "value": {"key2": "value2"}}',
                 ),
-                b"",
+                None,
                 "xcoms.set",
                 (
                     "test_dag",
@@ -1269,7 +1273,7 @@ class TestHandleRequest:
                     value='{"key": "test_key", "value": {"key2": "value2"}}',
                     map_index=2,
                 ),
-                b"",
+                None,
                 "xcoms.set",
                 (
                     "test_dag",
@@ -1295,7 +1299,7 @@ class TestHandleRequest:
                     map_index=2,
                     mapped_length=3,
                 ),
-                b"",
+                None,
                 "xcoms.set",
                 (
                     "test_dag",
@@ -1319,15 +1323,9 @@ class TestHandleRequest:
                     key="test_key",
                     map_index=2,
                 ),
-                b"",
+                None,
                 "xcoms.delete",
-                (
-                    "test_dag",
-                    "test_run",
-                    "test_task",
-                    "test_key",
-                    2,
-                ),
+                ("test_dag", "test_run", "test_task", "test_key", 2),
                 {},
                 OKResponse(ok=True),
                 None,
@@ -1337,7 +1335,7 @@ class TestHandleRequest:
             # if it can handle TaskState message
             pytest.param(
                 TaskState(state=TaskInstanceState.SKIPPED, end_date=timezone.parse("2024-10-31T12:00:00Z")),
-                b"",
+                None,
                 "",
                 (),
                 {},
@@ -1349,7 +1347,7 @@ class TestHandleRequest:
                 RetryTask(
                     end_date=timezone.parse("2024-10-31T12:00:00Z"), 
rendered_map_index="test retry task"
                 ),
-                b"",
+                None,
                 "task_instances.retry",
                 (),
                 {
@@ -1363,7 +1361,7 @@ class TestHandleRequest:
             ),
             pytest.param(
                 SetRenderedFields(rendered_fields={"field1": 
"rendered_value1", "field2": "rendered_value2"}),
-                b"",
+                None,
                 "task_instances.set_rtif",
                 (TI_ID, {"field1": "rendered_value1", "field2": 
"rendered_value2"}),
                 {},
@@ -1373,7 +1371,7 @@ class TestHandleRequest:
             ),
             pytest.param(
                 GetAssetByName(name="asset"),
-                b'{"name":"asset","uri":"s3://bucket/obj","group":"asset","type":"AssetResult"}\n',
+                {"name": "asset", "uri": "s3://bucket/obj", "group": "asset", "type": "AssetResult"},
                 "assets.get",
                 [],
                 {"name": "asset"},
@@ -1383,7 +1381,7 @@ class TestHandleRequest:
             ),
             pytest.param(
                 GetAssetByUri(uri="s3://bucket/obj"),
-                b'{"name":"asset","uri":"s3://bucket/obj","group":"asset","type":"AssetResult"}\n',
+                {"name": "asset", "uri": "s3://bucket/obj", "group": "asset", "type": "AssetResult"},
                 "assets.get",
                 [],
                 {"uri": "s3://bucket/obj"},
@@ -1393,11 +1391,17 @@ class TestHandleRequest:
             ),
             pytest.param(
                 GetAssetEventByAsset(uri="s3://bucket/obj", name="test"),
-                (
-                    b'{"asset_events":'
-                    b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},'
-                    b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n'
-                ),
+                {
+                    "asset_events": [
+                        {
+                            "id": 1,
+                            "timestamp": timezone.parse("2024-10-31T12:00:00Z"),
+                            "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"},
+                            "created_dagruns": [],
+                        }
+                    ],
+                    "type": "AssetEventsResult",
+                },
                 "asset_events.get",
                 [],
                 {"uri": "s3://bucket/obj", "name": "test"},
@@ -1416,11 +1420,17 @@ class TestHandleRequest:
             ),
             pytest.param(
                 GetAssetEventByAsset(uri="s3://bucket/obj", name=None),
-                (
-                    b'{"asset_events":'
-                    b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},'
-                    b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n'
-                ),
+                {
+                    "asset_events": [
+                        {
+                            "id": 1,
+                            "timestamp": timezone.parse("2024-10-31T12:00:00Z"),
+                            "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"},
+                            "created_dagruns": [],
+                        }
+                    ],
+                    "type": "AssetEventsResult",
+                },
                 "asset_events.get",
                 [],
                 {"uri": "s3://bucket/obj", "name": None},
@@ -1439,11 +1449,17 @@ class TestHandleRequest:
             ),
             pytest.param(
                 GetAssetEventByAsset(uri=None, name="test"),
-                (
-                    b'{"asset_events":'
-                    b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},'
-                    b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n'
-                ),
+                {
+                    "asset_events": [
+                        {
+                            "id": 1,
+                            "timestamp": timezone.parse("2024-10-31T12:00:00Z"),
+                            "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"},
+                            "created_dagruns": [],
+                        }
+                    ],
+                    "type": "AssetEventsResult",
+                },
                 "asset_events.get",
                 [],
                 {"uri": None, "name": "test"},
@@ -1462,11 +1478,17 @@ class TestHandleRequest:
             ),
             pytest.param(
                 GetAssetEventByAssetAlias(alias_name="test_alias"),
-                (
-                    b'{"asset_events":'
-                    b'[{"id":1,"timestamp":"2024-10-31T12:00:00Z","asset":{"name":"asset","uri":"s3://bucket/obj","group":"asset"},'
-                    b'"created_dagruns":[]}],"type":"AssetEventsResult"}\n'
-                ),
+                {
+                    "asset_events": [
+                        {
+                            "id": 1,
+                            "timestamp": timezone.parse("2024-10-31T12:00:00Z"),
+                            "asset": {"name": "asset", "uri": "s3://bucket/obj", "group": "asset"},
+                            "created_dagruns": [],
+                        }
+                    ],
+                    "type": "AssetEventsResult",
+                },
                 "asset_events.get",
                 [],
                 {"alias_name": "test_alias"},
@@ -1485,7 +1507,10 @@ class TestHandleRequest:
             ),
             pytest.param(
                 ValidateInletsAndOutlets(ti_id=TI_ID),
-                b'{"inactive_assets":[{"name":"asset_name","uri":"asset_uri","type":"asset"}],"type":"InactiveAssetsResult"}\n',
+                {
+                    "inactive_assets": [{"name": "asset_name", "uri": "asset_uri", "type": "asset"}],
+                    "type": "InactiveAssetsResult",
+                },
                 "task_instances.validate_inlets_and_outlets",
                 (TI_ID,),
                 {},
@@ -1499,7 +1524,7 @@ class TestHandleRequest:
                 SucceedTask(
                     end_date=timezone.parse("2024-10-31T12:00:00Z"), 
rendered_map_index="test success task"
                 ),
-                b"",
+                None,
                 "task_instances.succeed",
                 (),
                 {
@@ -1515,11 +1540,13 @@ class TestHandleRequest:
             ),
             pytest.param(
                 GetPrevSuccessfulDagRun(ti_id=TI_ID),
-                (
-                    b'{"data_interval_start":"2025-01-10T12:00:00Z","data_interval_end":"2025-01-10T14:00:00Z",'
-                    b'"start_date":"2025-01-10T12:00:00Z","end_date":"2025-01-10T14:00:00Z",'
-                    b'"type":"PrevSuccessfulDagRunResult"}\n'
-                ),
+                {
+                    "data_interval_start": timezone.parse("2025-01-10T12:00:00Z"),
+                    "data_interval_end": timezone.parse("2025-01-10T14:00:00Z"),
+                    "start_date": timezone.parse("2025-01-10T12:00:00Z"),
+                    "end_date": timezone.parse("2025-01-10T14:00:00Z"),
+                    "type": "PrevSuccessfulDagRunResult",
+                },
                 "task_instances.get_previous_successful_dagrun",
                 (TI_ID,),
                 {},
@@ -1540,7 +1567,7 @@ class TestHandleRequest:
                     logical_date=timezone.datetime(2025, 1, 1),
                     reset_dag_run=True,
                 ),
-                b'{"ok":true,"type":"OKResponse"}\n',
+                {"ok": True, "type": "OKResponse"},
                 "dag_runs.trigger",
                 ("test_dag", "test_run", {"key": "value"}, 
timezone.datetime(2025, 1, 1), True),
                 {},
@@ -1549,8 +1576,9 @@ class TestHandleRequest:
                 id="dag_run_trigger",
             ),
             pytest.param(
+                # TODO: This should raise an exception, not return an ErrorResponse. Fix this before PR
                 TriggerDagRun(dag_id="test_dag", run_id="test_run"),
-                b'{"error":"DAGRUN_ALREADY_EXISTS","detail":null,"type":"ErrorResponse"}\n',
+                {"error": "DAGRUN_ALREADY_EXISTS", "detail": None, "type": "ErrorResponse"},
                 "dag_runs.trigger",
                 ("test_dag", "test_run", None, None, False),
                 {},
@@ -1560,7 +1588,7 @@ class TestHandleRequest:
             ),
             pytest.param(
                 GetDagRunState(dag_id="test_dag", run_id="test_run"),
-                b'{"state":"running","type":"DagRunStateResult"}\n',
+                {"state": "running", "type": "DagRunStateResult"},
                 "dag_runs.get_state",
                 ("test_dag", "test_run"),
                 {},
@@ -1570,7 +1598,7 @@ class TestHandleRequest:
             ),
             pytest.param(
                 GetTaskRescheduleStartDate(ti_id=TI_ID),
-                b'{"start_date":"2024-10-31T12:00:00Z","type":"TaskRescheduleStartDate"}\n',
+                {"start_date": timezone.parse("2024-10-31T12:00:00Z"), "type": "TaskRescheduleStartDate"},
                 "task_instances.get_reschedule_start_date",
                 (TI_ID, 1),
                 {},
@@ -1580,7 +1608,7 @@ class TestHandleRequest:
             ),
             pytest.param(
                 GetTICount(dag_id="test_dag", task_ids=["task1", "task2"]),
-                b'{"count":2,"type":"TICount"}\n',
+                {"count": 2, "type": "TICount"},
                 "task_instances.get_count",
                 (),
                 {
@@ -1598,7 +1626,7 @@ class TestHandleRequest:
             ),
             pytest.param(
                 GetDRCount(dag_id="test_dag", states=["success", "failed"]),
-                b'{"count":2,"type":"DRCount"}\n',
+                {"count": 2, "type": "DRCount"},
                 "dag_runs.get_count",
                 (),
                 {
@@ -1613,7 +1641,10 @@ class TestHandleRequest:
             ),
             pytest.param(
                 GetTaskStates(dag_id="test_dag", task_group_id="test_group"),
-                b'{"task_states":{"run_id":{"task1":"success","task2":"failed"}},"type":"TaskStatesResult"}\n',
+                {
+                    "task_states": {"run_id": {"task1": "success", "task2": "failed"}},
+                    "type": "TaskStatesResult",
+                },
                 "task_instances.get_task_states",
                 (),
                 {
@@ -1636,7 +1667,7 @@ class TestHandleRequest:
                     task_id="test_task",
                     offset=0,
                 ),
-                b'{"root":"test_value","type":"XComSequenceIndexResult"}\n',
+                {"root": "test_value", "type": "XComSequenceIndexResult"},
                 "xcoms.get_sequence_item",
                 ("test_dag", "test_run", "test_task", "test_key", 0),
                 {},
@@ -1645,6 +1676,7 @@ class TestHandleRequest:
                 id="get_xcom_seq_item",
             ),
             pytest.param(
+                # TODO: This should raise an exception, not return an ErrorResponse. Fix this before PR
                 GetXComSequenceItem(
                     key="test_key",
                     dag_id="test_dag",
@@ -1652,7 +1684,7 @@ class TestHandleRequest:
                     task_id="test_task",
                     offset=2,
                 ),
-                b'{"error":"XCOM_NOT_FOUND","detail":null,"type":"ErrorResponse"}\n',
+                {"error": "XCOM_NOT_FOUND", "detail": None, "type": "ErrorResponse"},
                 "xcoms.get_sequence_item",
                 ("test_dag", "test_run", "test_task", "test_key", 2),
                 {},
@@ -1670,7 +1702,7 @@ class TestHandleRequest:
                     stop=None,
                     step=None,
                 ),
-                b'{"root":["foo","bar"],"type":"XComSequenceSliceResult"}\n',
+                {"root": ["foo", "bar"], "type": "XComSequenceSliceResult"},
                 "xcoms.get_sequence_slice",
                 ("test_dag", "test_run", "test_task", "test_key", None, None, 
None),
                 {},
@@ -1687,7 +1719,7 @@ class TestHandleRequest:
         mocker,
         time_machine,
         message,
-        expected_buffer,
+        expected_body,
         client_attr_path,
         method_arg,
         method_kwarg,
@@ -1715,8 +1747,9 @@ class TestHandleRequest:
         generator = watched_subprocess.handle_requests(log=mocker.Mock())
         # Initialize the generator
         next(generator)
-        msg = message.model_dump_json().encode() + b"\n"
-        generator.send(msg)
+
+        req_frame = _RequestFrame(id=randint(1, 2**32 - 1), body=message.model_dump())
+        generator.send(req_frame)
 
         if mask_secret_args:
             mock_mask_secret.assert_called_with(*mask_secret_args)
@@ -1729,33 +1762,23 @@ class TestHandleRequest:
 
         # Read response from the read end of the socket
         read_socket.settimeout(0.1)
-        val = b""
-        try:
-            while not val.endswith(b"\n"):
-                chunk = read_socket.recv(BUFFER_SIZE)
-                if not chunk:
-                    break
-                val += chunk
-        except (BlockingIOError, TimeoutError, socket.timeout):
-            # no response written, valid for some message types like setters and TI operations.
-            pass
+        frame_len = int.from_bytes(read_socket.recv(4), "big")
+        raw = read_socket.recv(frame_len)
+        frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(raw)
+
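+        # The response frame echoes the request's id so replies can be paired with requests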
+        assert frame.id == req_frame.id
 
         # Verify the response was added to the buffer
-        assert val == expected_buffer
+        assert frame.body == expected_body
 
         # Verify the response is correctly decoded
         # This is important because the subprocess/task runner will read the response
         # and deserialize it to the correct message type
 
-        # Only decode the buffer if it contains data. An empty buffer implies no response was written.
-        if not val and (mock_response and not isinstance(mock_response, OKResponse)):
-            pytest.fail("Expected a response, but got an empty buffer.")
-
-        if val:
+        if frame.body is not None:
-            # Using BytesIO to simulate a readable stream for CommsDecoder.
-            input_stream = BytesIO(val)
-            decoder = CommsDecoder(input=input_stream)
-            assert decoder.get_message() == mock_response
+            # Validate the body with the same decoder the task runner uses.
+            decoder = CommsDecoder(request_socket=None).body_decoder
+            assert decoder.validate_python(frame.body) == mock_response
 
     def test_handle_requests_api_server_error(self, watched_subprocess, mocker):
         """Test that API server errors are properly handled and sent back to the task."""
@@ -1777,28 +1800,32 @@ class TestHandleRequest:
 
         next(generator)
 
-        msg = SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")).model_dump_json().encode() + b"\n"
-        generator.send(msg)
+        msg = SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z"))
+        req_frame = _RequestFrame(id=randint(1, 2**32 - 1), body=msg)
+        generator.send(req_frame)
 
-        # Read response directly from the reader socket
+        # Read response from the read end of the socket
         read_socket.settimeout(0.1)
-        val = b""
-        try:
-            while not val.endswith(b"\n"):
-                val += read_socket.recv(4096)
-        except (BlockingIOError, TimeoutError):
-            pass
-
-        assert val == (
-            b'{"error":"API_SERVER_ERROR","detail":{"status_code":500,"message":"API Server Error",'
-            b'"detail":{"detail":"Internal Server Error"}},"type":"ErrorResponse"}\n'
-        )
+        frame_len = int.from_bytes(read_socket.recv(4), "big")
+        raw = read_socket.recv(frame_len)
+        frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(raw)
+
+        assert frame.id == req_frame.id
+
+        assert frame.error == {
+            "error": "API_SERVER_ERROR",
+            "detail": {
+                "status_code": 500,
+                "message": "API Server Error",
+                "detail": {"detail": "Internal Server Error"},
+            },
+            "type": "ErrorResponse",
+        }
 
         # Verify the error can be decoded correctly
-        input_stream = BytesIO(val)
-        decoder = CommsDecoder(input=input_stream)
+        comms = CommsDecoder(request_socket=None)
         with pytest.raises(AirflowRuntimeError) as exc_info:
-            decoder.get_message()
+            comms._from_frame(frame)
 
         assert exc_info.value.error.error == ErrorType.API_SERVER_ERROR
         assert exc_info.value.error.detail == {
@@ -1871,7 +1898,6 @@ class TestInProcessTestSupervisor:
         supervisor = MinimalSupervisor(
             id="test",
             pid=123,
-            requests_fd=-1,
             process=MagicMock(),
             process_log=MagicMock(),
             client=MagicMock(),
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index a994e995bac..ab91dbcf14a 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -22,11 +22,9 @@ import functools
 import json
 import os
 import textwrap
-import uuid
 from collections.abc import Iterable
 from datetime import datetime, timedelta
 from pathlib import Path
-from socket import socketpair
 from typing import TYPE_CHECKING
 from unittest import mock
 from unittest.mock import patch
@@ -101,7 +99,6 @@ from airflow.sdk.execution_time.context import (
     VariableAccessor,
 )
 from airflow.sdk.execution_time.task_runner import (
-    CommsDecoder,
     RuntimeTaskInstance,
     TaskRunnerMarker,
     _push_xcom_if_needed,
@@ -137,47 +134,6 @@ class CustomOperator(BaseOperator):
         print(f"Hello World {task_id}!")
 
 
-class TestCommsDecoder:
-    """Test the communication between the subprocess and the "supervisor"."""
-
-    @pytest.mark.usefixtures("disable_capturing")
-    def test_recv_StartupDetails(self):
-        r, w = socketpair()
-        # Create a valid FD for the decoder to open
-        _, w2 = socketpair()
-
-        w.makefile("wb").write(
-            b'{"type":"StartupDetails", "ti": {'
-            b'"id": "4d828a62-a417-4936-a7a6-2b3fabacecab", "task_id": "a", 
"try_number": 1, "run_id": "b", '
-            b'"dag_id": "c"}, 
"ti_context":{"dag_run":{"dag_id":"c","run_id":"b",'
-            b'"logical_date":"2024-12-01T01:00:00Z",'
-            
b'"data_interval_start":"2024-12-01T00:00:00Z","data_interval_end":"2024-12-01T01:00:00Z",'
-            
b'"start_date":"2024-12-01T01:00:00Z","run_after":"2024-12-01T01:00:00Z","end_date":null,'
-            b'"run_type":"manual","conf":null,"consumed_asset_events":[]},'
-            
b'"max_tries":0,"should_retry":false,"variables":null,"connections":null},"file":
 "/dev/null",'
-            b'"start_date":"2024-12-01T01:00:00Z", "dag_rel_path": 
"/dev/null", "bundle_info": {"name": '
-            b'"any-name", "version": "any-version"}, "requests_fd": '
-            + str(w2.fileno()).encode("ascii")
-            + b"}\n"
-        )
-
-        decoder = CommsDecoder(input=r.makefile("r"))
-
-        msg = decoder.get_message()
-        assert isinstance(msg, StartupDetails)
-        assert msg.ti.id == uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab")
-        assert msg.ti.task_id == "a"
-        assert msg.ti.dag_id == "c"
-        assert msg.dag_rel_path == "/dev/null"
-        assert msg.bundle_info == BundleInfo(name="any-name", 
version="any-version")
-        assert msg.start_date == timezone.datetime(2024, 12, 1, 1)
-
-        # Since this was a StartupDetails message, the decoder should open the other socket
-        assert decoder.request_socket is not None
-        assert decoder.request_socket.writable()
-        assert decoder.request_socket.fileno() == w2.fileno()
-
-
 def test_parse(test_dags_dir: Path, make_ti_context):
     """Test that checks parsing of a basic dag with an un-mocked parse."""
     what = StartupDetails(
@@ -190,7 +146,6 @@ def test_parse(test_dags_dir: Path, make_ti_context):
         ),
         dag_rel_path="super_basic.py",
         bundle_info=BundleInfo(name="my-bundle", version=None),
-        requests_fd=0,
         ti_context=make_ti_context(),
         start_date=timezone.utcnow(),
     )
@@ -246,7 +201,6 @@ def test_parse_not_found(test_dags_dir: Path, make_ti_context, dag_id, task_id,
         ),
         dag_rel_path="super_basic.py",
         bundle_info=BundleInfo(name="my-bundle", version=None),
-        requests_fd=0,
         ti_context=make_ti_context(),
         start_date=timezone.utcnow(),
     )
@@ -300,7 +254,6 @@ def test_parse_module_in_bundle_root(tmp_path: Path, make_ti_context):
         ),
         dag_rel_path="path_test.py",
         bundle_info=BundleInfo(name="my-bundle", version=None),
-        requests_fd=0,
         ti_context=make_ti_context(),
         start_date=timezone.utcnow(),
     )
@@ -571,7 +524,6 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm
         ),
         bundle_info=FAKE_BUNDLE,
         dag_rel_path="",
-        requests_fd=0,
         ti_context=make_ti_context(),
         start_date=timezone.utcnow(),
     )
@@ -687,7 +639,6 @@ def test_startup_and_run_dag_with_rtif(
         ),
         dag_rel_path="",
         bundle_info=FAKE_BUNDLE,
-        requests_fd=0,
         ti_context=make_ti_context(),
         start_date=timezone.utcnow(),
     )
@@ -832,7 +783,6 @@ def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch
         ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1),
         dag_rel_path="dag_parsing_context.py",
         bundle_info=BundleInfo(name="my-bundle", version=None),
-        requests_fd=0,
         ti_context=make_ti_context(dag_id=dag_id, run_id="c"),
         start_date=timezone.utcnow(),
     )
@@ -2190,7 +2140,6 @@ class TestTaskRunnerCallsListeners:
             ),
             dag_rel_path="",
             bundle_info=FAKE_BUNDLE,
-            requests_fd=0,
             ti_context=make_ti_context(),
             start_date=timezone.utcnow(),
         )

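For anyone who wants to poke at the new framing by hand, here is a minimal standalone
sketch of the sending side (not part of the patch; `write_frame` is an illustrative
name, and it only assumes the `msgspec` package is installed):

    import socket

    import msgspec


    def write_frame(sock: socket.socket, body: dict) -> None:
        # 4-byte big-endian length prefix, followed by the msgpack-encoded payload
        data = msgspec.msgpack.encode(body)
        sock.sendall(len(data).to_bytes(4, byteorder="big") + data)


    parent, child = socket.socketpair()
    write_frame(parent, {"id": 1, "body": {"type": "GetConnection", "conn_id": "test"}})
    length = int.from_bytes(child.recv(4), byteorder="big")
    print(msgspec.msgpack.decode(child.recv(length)))

Both the supervisor and the task process speak this same length-then-body format,
which is what lets the frame reader above resume cleanly across partial reads.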