kaxil commented on code in PR #61627:
URL: https://github.com/apache/airflow/pull/61627#discussion_r2940846223


##########
task-sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -2097,7 +2097,34 @@ def supervise(
             sentry_integration=sentry_integration,
         )
 
-        exit_code = process.wait()
+        # Forward termination signals to the task subprocess so that the operator's
+        # on_kill() hook is invoked on graceful shutdown (e.g. K8s pod SIGTERM).
+        # Without this, the supervisor exits on SIGTERM without notifying the child,
+        # leaving spawned resources (pods, subprocesses, etc.) running.
+        prev_sigterm = signal.getsignal(signal.SIGTERM)

Review Comment:
   `signal.signal()` returns the previous handler, so you can skip the separate
`getsignal()` calls. Something like
`prev_sigterm = signal.signal(signal.SIGTERM, _forward_signal)` on line 2119
(and the same for SIGINT). That way you save two lines and two redundant calls.
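For reference, a minimal standalone sketch of that pattern, assuming POSIX and that it runs on the main thread (Python only allows `signal.signal()` there). `forward_signals_to` is a hypothetical helper, not the code in this PR: it takes the previous handlers straight from `signal.signal()`'s return value and restores them on exit.

```python
import contextlib
import os
import signal
from collections.abc import Iterator


@contextlib.contextmanager
def forward_signals_to(
    pid: int, signums: tuple[int, ...] = (signal.SIGTERM, signal.SIGINT)
) -> Iterator[None]:
    """Forward ``signums`` to process ``pid`` while the block runs (hypothetical helper)."""

    def _forward_signal(signum, frame):
        try:
            os.kill(pid, signum)
        except ProcessLookupError:
            pass  # the child has already exited; nothing to forward to

    # signal.signal() returns the handler it replaces, so a separate
    # signal.getsignal() call is not needed to remember it.
    previous = {signum: signal.signal(signum, _forward_signal) for signum in signums}
    try:
        yield
    finally:
        for signum, handler in previous.items():
            # None means the old handler was not installed from Python;
            # falling back to SIG_DFL there is an assumption of this sketch.
            signal.signal(signum, signal.SIG_DFL if handler is None else handler)
```

Keeping the restore in a `finally` means the forwarding handler cannot outlive the wait on the child, even if the wait raises.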



##########
task-sdk/tests/task_sdk/execution_time/test_supervisor.py:
##########
@@ -485,6 +485,102 @@ def on_kill(self) -> None:
         captured = capfd.readouterr()
         assert "On kill hook called!" in captured.out
 
+    def test_on_kill_hook_called_when_supervisor_receives_sigterm(
+        self,
+        client_with_ti_start,
+        mocked_parse,
+        make_ti_context,
+        mock_supervisor_comms,
+        create_runtime_ti,
+        make_ti_context_dict,
+        capfd,
+    ):
+        """Test that SIGTERM to the supervisor process is forwarded to the task subprocess.
+
+        This simulates what happens when Kubernetes sends SIGTERM to the worker pod:
+        the supervisor should forward the signal to the child process so that the
+        operator's on_kill() hook is triggered for resource cleanup.
+        """
+        import threading
+
+        ti_id = "4d828a62-a417-4936-a7a6-2b3fabacecab"
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == f"/task-instances/{ti_id}/run":
+                return httpx.Response(200, json=make_ti_context_dict())
+            return httpx.Response(status_code=204)
+
+        def subprocess_main():
+            CommsDecoder()._get_response()
+
+            class CustomOperator(BaseOperator):
+                def execute(self, context):
+                    for i in range(30):
+                        print(f"Iteration {i}")
+                        sleep(1)
+
+                def on_kill(self) -> None:
+                    print("On kill hook called via signal forwarding!")
+
+            task = CustomOperator(task_id="test-signal-forward")
+            runtime_ti = create_runtime_ti(
+                dag_id="c",
+                task=task,
+                conf={},
+            )
+            run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+
+        proc = ActivitySubprocess.start(
+            dag_rel_path=os.devnull,
+            bundle_info=FAKE_BUNDLE,
+            what=TaskInstance(
+                id=ti_id,
+                task_id="b",
+                dag_id="c",
+                run_id="d",
+                try_number=1,
+                dag_version_id=uuid7(),
+            ),
+            client=make_client(transport=httpx.MockTransport(handle_request)),
+            target=subprocess_main,
+        )
+
+        # Install signal forwarding handler (same mechanism as supervise() does)
+        prev_sigterm = signal.getsignal(signal.SIGTERM)

Review Comment:
   This test re-implements the signal forwarding logic instead of exercising 
the production code path in `supervise()`. If someone breaks the handler 
registration in `supervise()`, this test still passes. One approach: have the 
subprocess send SIGTERM to its parent (`os.kill(os.getppid(), signal.SIGTERM)`) 
while `supervise()` is running, similar to how the existing 
`test_on_kill_hook_called_when_sigkilled` test works from the child side. That 
would test the actual wiring.
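A stdlib-only sketch of that test shape, assuming POSIX. `run_supervised` is a hypothetical stand-in for `supervise()` and `CHILD` a hypothetical task script; neither is Airflow code. The child installs its own `on_kill`-style SIGTERM handler and then signals its parent with `os.kill(os.getppid(), signal.SIGTERM)`, so the assertion only passes if the parent's real handler registration forwards the signal back. The sketch installs the handler before spawning and replays any signal that arrives before the child's PID is known, so a fast child cannot race it.

```python
import os
import signal
import subprocess
import sys


def run_supervised(argv: list[str]) -> subprocess.CompletedProcess:
    """Toy stand-in for a supervise()-style entry point (hypothetical, not Airflow's).

    Spawns ``argv`` and forwards any SIGTERM this process receives to the child.
    """
    proc = None
    pending: list[int] = []

    def _forward_signal(signum, frame):
        if proc is None:
            pending.append(signum)  # child not spawned yet; replayed after Popen
            return
        try:
            os.kill(proc.pid, signum)
        except ProcessLookupError:
            pass

    # Install before spawning so a child that signals us immediately can never
    # trigger our default SIGTERM action (which would kill the "supervisor").
    prev = signal.signal(signal.SIGTERM, _forward_signal)
    try:
        proc = subprocess.Popen(argv, stdout=subprocess.PIPE, text=True)
        for signum in pending:
            os.kill(proc.pid, signum)
        out, _ = proc.communicate()
        return subprocess.CompletedProcess(argv, proc.returncode, out)
    finally:
        signal.signal(signal.SIGTERM, signal.SIG_DFL if prev is None else prev)


# The child plays the task: it registers its own "on_kill" handler, then signals
# its *parent*, the way Kubernetes would signal the supervisor.
CHILD = r"""
import os, signal, sys, time

def on_kill(signum, frame):
    print("On kill hook called!", flush=True)
    sys.exit(0)

signal.signal(signal.SIGTERM, on_kill)
os.kill(os.getppid(), signal.SIGTERM)
time.sleep(30)  # interrupted by the SIGTERM the parent forwards back
"""
```

Calling `run_supervised([sys.executable, "-c", CHILD])` and asserting on `"On kill hook called!"` in the captured stdout exercises the handler registration itself rather than a copy of it.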



##########
task-sdk/tests/task_sdk/execution_time/test_supervisor.py:
##########
@@ -485,6 +485,102 @@ def on_kill(self) -> None:
         captured = capfd.readouterr()
         assert "On kill hook called!" in captured.out
 
+    def test_on_kill_hook_called_when_supervisor_receives_sigterm(
+        self,
+        client_with_ti_start,
+        mocked_parse,
+        make_ti_context,
+        mock_supervisor_comms,
+        create_runtime_ti,
+        make_ti_context_dict,
+        capfd,
+    ):
+        """Test that SIGTERM to the supervisor process is forwarded to the task subprocess.
+
+        This simulates what happens when Kubernetes sends SIGTERM to the worker pod:
+        the supervisor should forward the signal to the child process so that the
+        operator's on_kill() hook is triggered for resource cleanup.
+        """
+        import threading
+
+        ti_id = "4d828a62-a417-4936-a7a6-2b3fabacecab"
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == f"/task-instances/{ti_id}/run":
+                return httpx.Response(200, json=make_ti_context_dict())
+            return httpx.Response(status_code=204)
+
+        def subprocess_main():
+            CommsDecoder()._get_response()
+
+            class CustomOperator(BaseOperator):
+                def execute(self, context):
+                    for i in range(30):
+                        print(f"Iteration {i}")
+                        sleep(1)
+
+                def on_kill(self) -> None:
+                    print("On kill hook called via signal forwarding!")
+
+            task = CustomOperator(task_id="test-signal-forward")
+            runtime_ti = create_runtime_ti(
+                dag_id="c",
+                task=task,
+                conf={},
+            )
+            run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+
+        proc = ActivitySubprocess.start(
+            dag_rel_path=os.devnull,
+            bundle_info=FAKE_BUNDLE,
+            what=TaskInstance(
+                id=ti_id,
+                task_id="b",
+                dag_id="c",
+                run_id="d",
+                try_number=1,
+                dag_version_id=uuid7(),
+            ),
+            client=make_client(transport=httpx.MockTransport(handle_request)),
+            target=subprocess_main,
+        )
+
+        # Install signal forwarding handler (same mechanism as supervise() does)
+        prev_sigterm = signal.getsignal(signal.SIGTERM)
+
+        def _forward_signal(signum, frame):
+            try:
+                os.kill(proc.pid, signum)
+            except ProcessLookupError:
+                pass
+
+        signal.signal(signal.SIGTERM, _forward_signal)
+
+        # Send SIGTERM to ourselves (the supervisor) from a background thread,
+        # giving the subprocess time to start executing first. Then forcefully
+        # terminate the subprocess so the test does not hang.
+        def send_signals():
+            sleep(2)

Review Comment:
   The `sleep(2)` calls for synchronization are likely to cause flakes on slow 
CI. If the subprocess hasn't started executing within 2s, or if `on_kill()` 
takes longer than 2s, the test fails or hangs. A sentinel output line that the 
parent waits for before sending SIGTERM would be more reliable than a fixed 
sleep.



-- 
This is an automated message from the Apache Git Service.