This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1ab2474aeeb Fix bug with in-process request handling for `dag.test` (#50419)
1ab2474aeeb is described below
commit 1ab2474aeeba23be3248cb5101e2fe7740647182
Author: Kaxil Naik <[email protected]>
AuthorDate: Sat May 10 11:20:18 2025 +0530
Fix bug with in-process request handling for `dag.test` (#50419)
---
.../src/airflow/sdk/execution_time/supervisor.py | 17 ++--
.../task_sdk/execution_time/test_supervisor.py | 92 +++++++++++++++++++++-
2 files changed, 103 insertions(+), 6 deletions(-)
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 0e7b7f54cd1..90ab13305bc 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -1320,15 +1320,22 @@ def set_supervisor_comms(temp_comms):
"""
from airflow.sdk.execution_time import task_runner
- old = getattr(task_runner, "SUPERVISOR_COMMS", None)
- task_runner.SUPERVISOR_COMMS = temp_comms
+ sentinel = object()
+ old = getattr(task_runner, "SUPERVISOR_COMMS", sentinel)
+
+ if temp_comms is not None:
+ task_runner.SUPERVISOR_COMMS = temp_comms
+ elif old is not sentinel:
+ delattr(task_runner, "SUPERVISOR_COMMS")
+
try:
yield
finally:
- if old is not None:
- task_runner.SUPERVISOR_COMMS = old
+ if old is sentinel:
+ if hasattr(task_runner, "SUPERVISOR_COMMS"):
+ delattr(task_runner, "SUPERVISOR_COMMS")
else:
- delattr(task_runner, "SUPERVISOR_COMMS")
+ task_runner.SUPERVISOR_COMMS = old
def run_task_in_process(ti: TaskInstance, task) -> TaskRunResult:
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 5690f9b418e..1e6aec6bed0 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -51,6 +51,7 @@ from airflow.sdk.api.datamodels._generated import (
TaskInstanceState,
)
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
+from airflow.sdk.execution_time import task_runner
from airflow.sdk.execution_time.comms import (
AssetEventsResult,
AssetResult,
@@ -92,7 +93,15 @@ from airflow.sdk.execution_time.comms import (
XComResult,
)
from airflow.sdk.execution_time.secrets_masker import SecretsMasker
-from airflow.sdk.execution_time.supervisor import BUFFER_SIZE,
ActivitySubprocess, mkpipe, supervise
+from airflow.sdk.execution_time.supervisor import (
+ BUFFER_SIZE,
+ ActivitySubprocess,
+ InProcessSupervisorComms,
+ InProcessTestSupervisor,
+ mkpipe,
+ set_supervisor_comms,
+ supervise,
+)
from airflow.sdk.execution_time.task_runner import CommsDecoder
from airflow.utils import timezone, timezone as tz
@@ -1600,3 +1609,84 @@ class TestHandleRequest:
"message": str(error),
"detail": error.response.json(),
}
+
+
+class TestSetSupervisorComms:
+ class DummyComms:
+ pass
+
+ @pytest.fixture(autouse=True)
+ def cleanup_supervisor_comms(self):
+ # Remove any task_runner.SUPERVISOR_COMMS before and after each test so every test starts and ends with it unset
+ if hasattr(task_runner, "SUPERVISOR_COMMS"):
+ delattr(task_runner, "SUPERVISOR_COMMS")
+ yield
+ if hasattr(task_runner, "SUPERVISOR_COMMS"):
+ delattr(task_runner, "SUPERVISOR_COMMS")
+
+ def test_set_supervisor_comms_overrides_and_restores(self):
+ task_runner.SUPERVISOR_COMMS = self.DummyComms()
+ original = task_runner.SUPERVISOR_COMMS
+ replacement = self.DummyComms()
+
+ with set_supervisor_comms(replacement):
+ assert task_runner.SUPERVISOR_COMMS is replacement
+ assert task_runner.SUPERVISOR_COMMS is original
+
+ def test_set_supervisor_comms_sets_temporarily_when_not_set(self):
+ assert not hasattr(task_runner, "SUPERVISOR_COMMS")
+ replacement = self.DummyComms()
+
+ with set_supervisor_comms(replacement):
+ assert task_runner.SUPERVISOR_COMMS is replacement
+ assert not hasattr(task_runner, "SUPERVISOR_COMMS")
+
+ def test_set_supervisor_comms_unsets_temporarily_when_not_set(self):
+ assert not hasattr(task_runner, "SUPERVISOR_COMMS")
+
+ # Passing None while the attribute is unset must be a no-op: nothing is deleted on entry and nothing is restored on exit
+ with set_supervisor_comms(None):
+ assert not hasattr(task_runner, "SUPERVISOR_COMMS")
+
+ assert not hasattr(task_runner, "SUPERVISOR_COMMS")
+
+
+class TestInProcessTestSupervisor:
+ def test_inprocess_supervisor_comms_roundtrip(self):
+ """
+ Test that InProcessSupervisorComms correctly sends a message to the supervisor,
+ and that the supervisor's response is received via the message queue.
+
+ This verifies the end-to-end communication flow:
+ - send_request() dispatches a message to the supervisor
+ - the supervisor handles the request and appends a response via send_msg()
+ - get_message() returns the enqueued response
+
+ This test overrides the supervisor's `_handle_request()` in a subclass to return
+ a canned VariableResult for the requested key, avoiding full task execution.
+ """
+
+ class MinimalSupervisor(InProcessTestSupervisor):
+ def _handle_request(self, msg, log):
+ resp = VariableResult(key=msg.key, value="value")
+ self.send_msg(resp)
+
+ supervisor = MinimalSupervisor(
+ id="test",
+ pid=123,
+ requests_fd=-1,
+ process=MagicMock(),
+ process_log=MagicMock(),
+ client=MagicMock(),
+ )
+ comms = InProcessSupervisorComms(supervisor=supervisor)
+ supervisor.comms = comms
+
+ test_msg = GetVariable(key="test_key")
+
+ comms.send_request(log=MagicMock(), msg=test_msg)
+
+ # The VariableResult produced by the overridden handler should come back through get_message()
+ response = comms.get_message()
+ assert isinstance(response, VariableResult)
+ assert response.value == "value"