This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1ab2474aeeb Fix bug with in-process request handling for `dag.test` (#50419)
1ab2474aeeb is described below

commit 1ab2474aeeba23be3248cb5101e2fe7740647182
Author: Kaxil Naik <[email protected]>
AuthorDate: Sat May 10 11:20:18 2025 +0530

    Fix bug with in-process request handling for `dag.test` (#50419)
---
 .../src/airflow/sdk/execution_time/supervisor.py   | 17 ++--
 .../task_sdk/execution_time/test_supervisor.py     | 92 +++++++++++++++++++++-
 2 files changed, 103 insertions(+), 6 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 0e7b7f54cd1..90ab13305bc 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -1320,15 +1320,22 @@ def set_supervisor_comms(temp_comms):
     """
     from airflow.sdk.execution_time import task_runner
 
-    old = getattr(task_runner, "SUPERVISOR_COMMS", None)
-    task_runner.SUPERVISOR_COMMS = temp_comms
+    sentinel = object()
+    old = getattr(task_runner, "SUPERVISOR_COMMS", sentinel)
+
+    if temp_comms is not None:
+        task_runner.SUPERVISOR_COMMS = temp_comms
+    elif old is not sentinel:
+        delattr(task_runner, "SUPERVISOR_COMMS")
+
     try:
         yield
     finally:
-        if old is not None:
-            task_runner.SUPERVISOR_COMMS = old
+        if old is sentinel:
+            if hasattr(task_runner, "SUPERVISOR_COMMS"):
+                delattr(task_runner, "SUPERVISOR_COMMS")
         else:
-            delattr(task_runner, "SUPERVISOR_COMMS")
+            task_runner.SUPERVISOR_COMMS = old
 
 
 def run_task_in_process(ti: TaskInstance, task) -> TaskRunResult:
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 5690f9b418e..1e6aec6bed0 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -51,6 +51,7 @@ from airflow.sdk.api.datamodels._generated import (
     TaskInstanceState,
 )
 from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
+from airflow.sdk.execution_time import task_runner
 from airflow.sdk.execution_time.comms import (
     AssetEventsResult,
     AssetResult,
@@ -92,7 +93,15 @@ from airflow.sdk.execution_time.comms import (
     XComResult,
 )
 from airflow.sdk.execution_time.secrets_masker import SecretsMasker
-from airflow.sdk.execution_time.supervisor import BUFFER_SIZE, ActivitySubprocess, mkpipe, supervise
+from airflow.sdk.execution_time.supervisor import (
+    BUFFER_SIZE,
+    ActivitySubprocess,
+    InProcessSupervisorComms,
+    InProcessTestSupervisor,
+    mkpipe,
+    set_supervisor_comms,
+    supervise,
+)
 from airflow.sdk.execution_time.task_runner import CommsDecoder
 from airflow.utils import timezone, timezone as tz
 
@@ -1600,3 +1609,84 @@ class TestHandleRequest:
             "message": str(error),
             "detail": error.response.json(),
         }
+
+
+class TestSetSupervisorComms:
+    class DummyComms:
+        pass
+
+    @pytest.fixture(autouse=True)
+    def cleanup_supervisor_comms(self):
+        # Ensure clean state before/after test
+        if hasattr(task_runner, "SUPERVISOR_COMMS"):
+            delattr(task_runner, "SUPERVISOR_COMMS")
+        yield
+        if hasattr(task_runner, "SUPERVISOR_COMMS"):
+            delattr(task_runner, "SUPERVISOR_COMMS")
+
+    def test_set_supervisor_comms_overrides_and_restores(self):
+        task_runner.SUPERVISOR_COMMS = self.DummyComms()
+        original = task_runner.SUPERVISOR_COMMS
+        replacement = self.DummyComms()
+
+        with set_supervisor_comms(replacement):
+            assert task_runner.SUPERVISOR_COMMS is replacement
+        assert task_runner.SUPERVISOR_COMMS is original
+
+    def test_set_supervisor_comms_sets_temporarily_when_not_set(self):
+        assert not hasattr(task_runner, "SUPERVISOR_COMMS")
+        replacement = self.DummyComms()
+
+        with set_supervisor_comms(replacement):
+            assert task_runner.SUPERVISOR_COMMS is replacement
+        assert not hasattr(task_runner, "SUPERVISOR_COMMS")
+
+    def test_set_supervisor_comms_unsets_temporarily_when_not_set(self):
+        assert not hasattr(task_runner, "SUPERVISOR_COMMS")
+
+        # This will delete an attribute that isn't set, and restore it likewise
+        with set_supervisor_comms(None):
+            assert not hasattr(task_runner, "SUPERVISOR_COMMS")
+
+        assert not hasattr(task_runner, "SUPERVISOR_COMMS")
+
+
+class TestInProcessTestSupervisor:
+    def test_inprocess_supervisor_comms_roundtrip(self):
+        """
+        Test that InProcessSupervisorComms correctly sends a message to the supervisor,
+        and that the supervisor's response is received via the message queue.
+
+        This verifies the end-to-end communication flow:
+        - send_request() dispatches a message to the supervisor
+        - the supervisor handles the request and appends a response via send_msg()
+        - get_message() returns the enqueued response
+
+        This test mocks the supervisor's `_handle_request()` method to simulate
+        a simple echo-style response, avoiding full task execution.
+        """
+
+        class MinimalSupervisor(InProcessTestSupervisor):
+            def _handle_request(self, msg, log):
+                resp = VariableResult(key=msg.key, value="value")
+                self.send_msg(resp)
+
+        supervisor = MinimalSupervisor(
+            id="test",
+            pid=123,
+            requests_fd=-1,
+            process=MagicMock(),
+            process_log=MagicMock(),
+            client=MagicMock(),
+        )
+        comms = InProcessSupervisorComms(supervisor=supervisor)
+        supervisor.comms = comms
+
+        test_msg = GetVariable(key="test_key")
+
+        comms.send_request(log=MagicMock(), msg=test_msg)
+
+        # Ensure we got back what we expect
+        response = comms.get_message()
+        assert isinstance(response, VariableResult)
+        assert response.value == "value"

Reply via email to