jedcunningham commented on code in PR #66412:
URL: https://github.com/apache/airflow/pull/66412#discussion_r3191440581
##########
airflow-core/tests/unit/jobs/test_triggerer_job.py:
##########
@@ -1853,3 +1906,175 @@ def
test_make_trigger_span_sets_only_trigger_name_without_ti(self):
assert attrs["airflow.trigger.name"] == "OnlyTrigger"
assert "airflow.dag_id" not in attrs
assert "airflow.task_id" not in attrs
+
+
+def _read_frame_sync(sock) -> _RequestFrame | None:
+ """Read a length-prefixed msgpack frame from a blocking socket."""
+ lb = b""
+ while len(lb) < 4:
+ chunk = sock.recv(4 - len(lb))
+ if not chunk:
+ return None
+ lb += chunk
+ n = int.from_bytes(lb, "big")
+ data = b""
+ while len(data) < n:
+ chunk = sock.recv(n - len(data))
+ if not chunk:
+ return None
+ data += chunk
+ return msgspec.msgpack.decode(data, type=_RequestFrame)
+
+
+@pytest_asyncio.fixture
+async def decoder_pair():
+ """Yield (decoder, server_sock). Caller owns closing."""
+ server_sock, client_sock = socketpair()
+ reader, writer = await asyncio.open_connection(sock=client_sock)
+ decoder = TriggerCommsDecoder(async_writer=writer, async_reader=reader,
socket=client_sock)
+ await decoder.start_reader()
+ yield decoder, server_sock
+ if decoder._reader_task:
+ if not decoder._reader_task.done():
+ decoder._reader_task.cancel()
+ with contextlib.suppress(asyncio.CancelledError, Exception):
+ await decoder._reader_task
+ writer.close()
+ server_sock.close()
+
+
+@pytest.mark.asyncio
+@pytest.mark.execution_timeout(15)
+async def test_all_send_paths_concurrent(decoder_pair):
+ """
+ All four send() paths running concurrently with responses returned out of
order:
+
+ 1. asend() directly from async code — pure-async path
+ 2. send() via asyncio.to_thread() — mirrors
apache/airflow#63913:
+
sync_to_async(hook_class)() → get_connection()
+ →
SUPERVISOR_COMMS.send() from a thread pool thread
+ 3. send() from the event-loop thread — mirrors
apache/airflow#63760:
+ via greenback async_to_sync raised
RuntimeError in same thread
+ 4. async_to_sync(asend)() from a thread — trigger code that wraps
an async fn which
+ internally calls asend;
bridges via wrap_future
+
+ The concurrent mix with shuffled responses also covers
apache/airflow#65286: the
+ _thread_lock + async_to_sync approach stalled the triggerer under this
exact load pattern.
+ """
+ decoder, server_sock = decoder_pair
+ N = 5
+ N_TOTAL = N * 4
+
+ def supervisor():
+ frames = []
+ for _ in range(N_TOTAL):
+ f = _read_frame_sync(server_sock)
+ if f is None:
+ break
+ frames.append(f)
+ random.shuffle(frames)
+ for f in frames:
+ server_sock.sendall(
+ _ResponseFrame(
+ id=f.id,
+ body={"type": "TriggerStateSync", "to_create": [],
"to_cancel": []},
+ ).as_bytes()
+ )
+
+ sup = threading.Thread(target=supervisor, daemon=True)
+ sup.start()
+
+ async def async_send(idx):
+ return await decoder.asend(messages.TriggerStateChanges(events=None,
finished=[idx], failures=None))
+
+ async def from_thread_send(idx):
+ # In production this path is taken by asgiref's own thread pool
(sync_to_async),
+ # which is invisible to asyncio's default executor. We avoid
asyncio.to_thread()
+ # here because on Python < 3.12 loop.shutdown_default_executor() has
no timeout
+ # and hangs if any executor threads are still alive at loop teardown.
+ # TODO: simplify with asyncio.to_thread() when Python 3.12 is the
minimum.
+ loop = asyncio.get_running_loop()
+ fut: asyncio.Future[messages.TriggerStateSync] = loop.create_future()
+
+ def sync_send():
+ try:
+ result = decoder.send(
+ messages.TriggerStateChanges(events=None, finished=[N +
idx], failures=None)
+ )
+ loop.call_soon_threadsafe(fut.set_result, result)
+ except Exception as exc:
+ loop.call_soon_threadsafe(fut.set_exception, exc)
+
+ threading.Thread(target=sync_send, daemon=True).start()
+ return await fut
+
+ async def greenback_send(idx):
+ await greenback.ensure_portal()
+ return decoder.send(messages.TriggerStateChanges(events=None,
finished=[2 * N + idx], failures=None))
+
+ async def async_to_sync_send(idx):
+ # Same executor-avoidance reason as from_thread_send above.
+ # TODO: simplify with asyncio.to_thread() when Python 3.12 is the
minimum.
+ loop = asyncio.get_running_loop()
+ fut: asyncio.Future[messages.TriggerStateSync] = loop.create_future()
+
+ def thread_fn():
+ try:
+ result = async_to_sync(decoder.asend)(
+ messages.TriggerStateChanges(events=None, finished=[3 * N
+ idx], failures=None)
+ )
+ loop.call_soon_threadsafe(fut.set_result, result)
+ except Exception as exc:
+ loop.call_soon_threadsafe(fut.set_exception, exc)
+
+ threading.Thread(target=thread_fn, daemon=True).start()
+ return await fut
+
+ results = await asyncio.gather(
+ *[asyncio.create_task(async_send(i)) for i in range(N)],
+ *[asyncio.create_task(from_thread_send(i)) for i in range(N)],
+ *[asyncio.create_task(greenback_send(i)) for i in range(N)],
+ *[asyncio.create_task(async_to_sync_send(i)) for i in range(N)],
+ return_exceptions=True,
+ )
+
+ sup.join(timeout=5)
+
+ errors = [r for r in results if isinstance(r, Exception)]
+ assert not errors, f"errors: {errors}"
+ assert len(results) == N_TOTAL
+ assert all(isinstance(r, messages.TriggerStateSync) for r in results)
+
+
+@pytest.mark.asyncio
+async def test_connection_close_cancels_pending(decoder_pair):
+ """When the connection closes while asend() is awaiting, the future is
cancelled."""
+ decoder, server_sock = decoder_pair
+
+ task = asyncio.create_task(
+ decoder.asend(messages.TriggerStateChanges(events=None, finished=[1],
failures=None))
+ )
+ await asyncio.sleep(0)
+
+ server_sock.close()
+
+ with pytest.raises((asyncio.CancelledError, Exception)):
+ await asyncio.wait_for(task, timeout=5)
Review Comment:
@parkhojeong want to open a follow up to change this? Thanks :)
##########
airflow-core/src/airflow/jobs/triggerer_job_runner.py:
##########
@@ -951,46 +977,80 @@ class TriggerCommsDecoder(CommsDecoder[ToTriggerRunner,
ToTriggerSupervisor]):
factory=lambda: TypeAdapter(ToTriggerRunner), repr=False
)
- def _read_frame(self):
- from asgiref.sync import async_to_sync
-
- with self._thread_lock:
- return async_to_sync(self._aread_frame)()
-
- def send(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None:
- from asgiref.sync import async_to_sync
-
- with self._thread_lock:
- return async_to_sync(self.asend)(msg)
+ _pending: dict[int, asyncio.Future] = attrs.field(factory=dict, repr=False)
+ _loop: asyncio.AbstractEventLoop | None = attrs.field(default=None,
repr=False)
+ _loop_thread_id: int | None = attrs.field(default=None, repr=False)
+ _reader_task: asyncio.Task | None = attrs.field(default=None, repr=False)
async def _aread_frame(self):
try:
len_bytes = await self._async_reader.readexactly(4)
except ConnectionResetError:
asyncio.current_task().cancel("Supervisor closed")
+ raise
length = int.from_bytes(len_bytes, byteorder="big")
if length >= 2**32:
raise OverflowError(f"Refusing to receive messages larger than
4GiB {length=}")
-
buffer = await self._async_reader.readexactly(length)
return self.resp_decoder.decode(buffer)
- async def _aget_response(self, expect_id: int) -> ToTriggerRunner | None:
- frame = await self._aread_frame()
- if frame.id != expect_id:
- # Given the lock we take out in `asend`, this _shouldn't_ be
possible, but I'd rather fail with
- # this explicit error return the wrong type of message back to a
Trigger
- raise RuntimeError(f"Response read out of order! Got {frame.id=},
{expect_id=}")
- return self._from_frame(frame)
+ async def _reader_loop(self) -> None:
+ try:
+ while True:
+ frame = await self._aread_frame()
+ future = self._pending.pop(frame.id, None)
+ if future is not None and not future.done():
+ future.set_result(frame)
+ else:
+ self.log.warning("Got response for unknown request frame",
frame_id=frame.id)
+ finally:
+ for fut in self._pending.values():
+ if not fut.done():
+ fut.cancel("Reader loop exited")
+ self._pending.clear()
- async def asend(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None:
- frame = _RequestFrame(id=next(self.id_counter), body=msg.model_dump())
- bytes = frame.as_bytes()
+ async def start_reader(self) -> None:
+ self._loop = asyncio.get_running_loop()
+ self._loop_thread_id = threading.get_ident()
+ self._reader_task = asyncio.create_task(self._reader_loop(),
name="trigger-comms-reader")
- async with self._async_lock:
- self._async_writer.write(bytes)
+ def send(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None:
+ if self._loop is None:
+ raise RuntimeError("start_reader() must be called before send()")
+ if threading.get_ident() == self._loop_thread_id:
+ # Called from the event loop thread itself (e.g. a trigger calling
a sync SDK method
+ # directly from async def run()).
run_coroutine_threadsafe(...).result() would deadlock
+ # here because .result() blocks the thread the event loop runs on.
+ # greenback.await_() teleports the coroutine back into the running
loop instead.
+ if not greenback.has_portal():
+ raise RuntimeError(
+ "Sync SDK methods (e.g. get_connection(), get_variable())
cannot be called "
+ "from a trigger's async def run() when
AIRFLOW_DISABLE_GREENBACK_PORTAL is "
+ "set. Either remove that environment variable, or use the
async equivalent "
+ "(e.g. aget_connection(), aget_variable())."
+ )
+ return greenback.await_(self.asend(msg))
+ return asyncio.run_coroutine_threadsafe(self.asend(msg),
self._loop).result()
- return await self._aget_response(frame.id)
+ async def asend(self, msg: ToTriggerSupervisor) -> ToTriggerRunner | None:
+ if self._loop is None:
+ raise RuntimeError("start_reader() must be called before asend()")
+ current_loop = asyncio.get_running_loop()
+ if self._loop is not None and current_loop is not self._loop:
Review Comment:
@parkhojeong want to open a follow up to change this? Thanks :)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]