kaxil commented on code in PR #61627:
URL: https://github.com/apache/airflow/pull/61627#discussion_r3295434806


##########
task-sdk/tests/task_sdk/execution_time/test_supervisor.py:
##########
@@ -488,6 +488,98 @@ def on_kill(self) -> None:
         captured = capfd.readouterr()
         assert "On kill hook called!" in captured.out
 
+    def test_on_kill_hook_called_when_supervisor_receives_sigterm(
+        self,
+        client_with_ti_start,
+        mocked_parse,
+        make_ti_context,
+        mock_supervisor_comms,
+        create_runtime_ti,
+        make_ti_context_dict,
+        capfd,
+    ):
+        """Test that SIGTERM to the supervisor process is forwarded to the task subprocess.
+
+        This simulates what happens when Kubernetes sends SIGTERM to the worker pod:
+        the supervisor should forward the signal to the child process so that the
+        operator's on_kill() hook is triggered for resource cleanup.
+        """
+        import threading
+
+        ti_id = "4d828a62-a417-4936-a7a6-2b3fabacecab"
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == f"/task-instances/{ti_id}/run":
+                return httpx.Response(200, json=make_ti_context_dict())
+            return httpx.Response(status_code=204)
+
+        def subprocess_main():
+            CommsDecoder()._get_response()
+
+            class CustomOperator(BaseOperator):
+                def execute(self, context):
+                    for i in range(30):
+                        print(f"Iteration {i}")
+                        sleep(1)
+
+                def on_kill(self) -> None:
+                    print("On kill hook called via signal forwarding!")
+
+            task = CustomOperator(task_id="test-signal-forward")
+            runtime_ti = create_runtime_ti(
+                dag_id="c",
+                task=task,
+                conf={},
+            )
+            run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+
+        proc = ActivitySubprocess.start(
+            dag_rel_path=os.devnull,
+            bundle_info=FAKE_BUNDLE,
+            what=TaskInstance(
+                id=ti_id,
+                task_id="b",
+                dag_id="c",
+                run_id="d",
+                try_number=1,
+                dag_version_id=uuid7(),
+            ),
+            client=make_client(transport=httpx.MockTransport(handle_request)),
+            target=subprocess_main,
+        )
+
+        # Install signal forwarding handler (same mechanism as supervise() does)
+        prev_sigterm = signal.getsignal(signal.SIGTERM)
+
+        def _forward_signal(signum, frame):
+            with suppress(ProcessLookupError):
+                os.kill(proc.pid, signum)
+
+        signal.signal(signal.SIGTERM, _forward_signal)

Review Comment:
   Re-flagging from prior review: this test still installs its own 
`_forward_signal` handler in the test process and SIGTERMs itself. The 
production handler registered inside `supervise_task()` is never exercised. If 
the forwarding logic in `supervisor.py` regressed to a no-op, this test would 
still pass because the test's own handler does the forwarding.
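   To make the failure mode concrete, here is a minimal stdlib-only sketch of the same shape. It assumes POSIX and the main thread. `production_forwarding_handler` is a hypothetical stand-in for the handler `supervise_task()` registers, deliberately reduced to a no-op. `pr_style_check` is likewise hypothetical and mimics the test: it installs its own handler and signals itself.

```python
import signal

forwarded = []  # signals that reached the (simulated) task side


def production_forwarding_handler(signum, frame):
    """Hypothetical stand-in for the handler supervise_task() registers.

    Deliberately regressed to a no-op: it forwards nothing. The check below
    never installs it, so its body cannot affect the outcome.
    """


def pr_style_check() -> bool:
    """Mimic the PR's test shape: install the test's own handler, then SIGTERM self."""
    prev = signal.getsignal(signal.SIGTERM)
    # The test's own re-implementation of forwarding (like _forward_signal in the PR).
    signal.signal(signal.SIGTERM, lambda signum, frame: forwarded.append(signum))
    try:
        # In the main thread, raise_signal() runs the Python-level handler before returning.
        signal.raise_signal(signal.SIGTERM)
    finally:
        signal.signal(signal.SIGTERM, prev)
    return signal.SIGTERM in forwarded


print(pr_style_check())  # True, although production_forwarding_handler is a no-op
```

   The check passes regardless of what the production handler does, which is exactly why it cannot catch a regression.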
   
   To actually cover the production path, the SIGTERM needs to arrive while the
handler that `supervise_task()` registers with `signal.signal(...)` is the one
installed. One way: run `supervise_task()` in a subprocess, SIGTERM that
subprocess (not the test process), then assert on its captured stdout. The
current shape just re-implements the feature and asserts that the
re-implementation works.
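   A minimal stdlib-only sketch of that shape, assuming a POSIX host. `SUPERVISOR` and `CHILD` are hypothetical stand-ins (not Airflow code) for `supervise_task()` and the task subprocess. `run_supervisor_and_sigterm` and its `forward` flag are also illustrative; `forward=False` simulates the handler regressing to a no-op. The key property is that SIGTERM goes to the supervisor process, so only the handler registered inside that process can deliver it to the child.

```python
import os
import signal
import subprocess
import sys
import textwrap

# Hypothetical stand-in for the task subprocess: installs an on_kill-style
# SIGTERM handler, reports readiness, then "works" until it is killed.
CHILD = textwrap.dedent("""
    import signal, sys, time

    def on_kill(signum, frame):
        print("On kill hook called via signal forwarding!", flush=True)
        sys.exit(0)

    signal.signal(signal.SIGTERM, on_kill)
    print("ready", flush=True)
    time.sleep(30)
""")

# Hypothetical stand-in for the production supervisor: the forwarding handler
# is registered *inside* the process under test, not by the test itself.
SUPERVISOR = textwrap.dedent("""
    import os, signal, subprocess, sys

    FORWARD = sys.argv[1] == "forward"
    child = subprocess.Popen([sys.executable, "-c", sys.argv[2]],
                             stdout=subprocess.PIPE, text=True)

    def _forward_signal(signum, frame):
        # Relay the signal to the task process, or do nothing to simulate a regression.
        if FORWARD:
            try:
                os.kill(child.pid, signum)
            except ProcessLookupError:
                pass

    signal.signal(signal.SIGTERM, _forward_signal)
    assert child.stdout.readline().strip() == "ready"  # child's handler is installed
    print("ready", flush=True)
    sys.stdout.write(child.stdout.read())  # relay the child's remaining output
    sys.stdout.flush()
    child.wait()
""")


def run_supervisor_and_sigterm(forward: bool, timeout: float = 10.0) -> str:
    """Start the supervisor in its own process, SIGTERM *it*, return the rest of its stdout."""
    proc = subprocess.Popen(
        [sys.executable, "-c", SUPERVISOR, "forward" if forward else "noop", CHILD],
        stdout=subprocess.PIPE,
        text=True,
        start_new_session=True,  # own process group, so cleanup can kill both levels
    )
    # Only signal once both handlers exist, so the SIGTERM cannot race registration.
    assert proc.stdout.readline().strip() == "ready"
    proc.send_signal(signal.SIGTERM)  # the supervisor receives it, not the test process
    try:
        proc.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        # Signal was swallowed: kill supervisor and child together.
        os.killpg(proc.pid, signal.SIGKILL)
        proc.wait()
    out = proc.stdout.read()
    proc.stdout.close()
    return out
```

   With `forward=True` the returned output contains the `on_kill` message. With `forward=False` the supervisor ignores the signal, the call times out, and the message is absent, so this shape does distinguish a working handler from a no-op. In a real test, the supervisor script would call the production entry point instead of defining its own handler.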



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
