kaxil commented on code in PR #44465: URL: https://github.com/apache/airflow/pull/44465#discussion_r1864034492
########## task_sdk/tests/execution_time/test_supervisor.py: ########## @@ -430,6 +432,216 @@ def test_heartbeat_failures_handling(self, monkeypatch, mocker, captured_logs, t } in captured_logs +class TestWatchedSubprocessKill: + @pytest.fixture + def mock_process(self, mocker): + process = mocker.Mock(spec=psutil.Process) + process.pid = 12345 + return process + + @pytest.fixture + def watched_subprocess(self, mocker, mock_process): + proc = WatchedSubprocess( + ti_id=TI_ID, + pid=12345, + stdin=mocker.Mock(), + client=mocker.Mock(), + process=mock_process, + ) + # Mock the selector + mock_selector = mocker.Mock(spec=selectors.DefaultSelector) + mock_selector.select.return_value = [] + + # Set the selector on the process + proc.selector = mock_selector + return proc + + @pytest.mark.parametrize( + ["signal_to_send", "wait_side_effect", "expected_signals"], + [ + pytest.param( + signal.SIGINT, + [0], + [signal.SIGINT], + id="SIGINT-success-without-escalation", + ), + pytest.param( + signal.SIGINT, + [psutil.TimeoutExpired(0.1), 0], + [signal.SIGINT, signal.SIGTERM], + id="SIGINT-escalates-to-SIGTERM", + ), + pytest.param( + signal.SIGINT, + [ + psutil.TimeoutExpired(0.1), # SIGINT times out + psutil.TimeoutExpired(0.1), # SIGTERM times out + 0, # SIGKILL succeeds + ], + [signal.SIGINT, signal.SIGTERM, signal.SIGKILL], + id="SIGINT-escalates-to-SIGTERM-then-SIGKILL", + ), + pytest.param( + signal.SIGTERM, + [ + psutil.TimeoutExpired(0.1), # SIGTERM times out + 0, # SIGKILL succeeds + ], + [signal.SIGTERM, signal.SIGKILL], + id="SIGTERM-escalates-to-SIGKILL", + ), + pytest.param( + signal.SIGKILL, + [0], + [signal.SIGKILL], + id="SIGKILL-success-without-escalation", + ), + ], + ) + def test_force_kill_escalation( + self, + watched_subprocess, + mock_process, + mocker, + signal_to_send, + wait_side_effect, + expected_signals, + captured_logs, + ): + """Test escalation path for SIGINT, SIGTERM, and SIGKILL when force=True.""" + # Mock the process wait method to return the exit code or raise an exception + mock_process.wait.side_effect = wait_side_effect + + watched_subprocess.kill(signal_to_send=signal_to_send, escalation_delay=0.1, force=True) + + # Check that the correct signals were sent + mock_process.send_signal.assert_has_calls([mocker.call(sig) for sig in expected_signals]) + + # Check that the process was waited on for each signal + mock_process.wait.assert_has_calls([mocker.call(timeout=0)] * len(expected_signals)) + + ## Validate log messages + # If escalation occurred, we should see a warning log for each signal sent + if len(expected_signals) > 1: + assert { + "event": "Process did not terminate in time; escalating", + "level": "warning", Review Comment: This could be changed to debug too! -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@airflow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org