ashb commented on code in PR #53718:
URL: https://github.com/apache/airflow/pull/53718#discussion_r2236677730
##########
task-sdk/src/airflow/sdk/execution_time/task_runner.py:
##########
@@ -877,6 +879,11 @@ def run(
assert ti.task is not None
assert isinstance(ti.task, BaseOperator)
+ def _on_kill(signum, frame):
+ ti.task.on_kill()
+
+ signal.signal(signal.SIGTERM, _on_kill)
Review Comment:
I wonder if we need to put some `pid` checks in here, to make sure that we
don't run this in any subprocesses that the task might fork. What did we do in
2.x?
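One way such a `pid` check could look, as a minimal sketch rather than the PR's actual code: record the pid at install time and make the handler a no-op in any other process. `make_term_handler` and its `on_kill` argument are hypothetical names used only for illustration.

```python
import os
import signal


def make_term_handler(on_kill):
    """Build a SIGTERM handler that only fires in the process that created it.

    `on_kill` is a hypothetical zero-argument callback (for example
    `ti.task.on_kill`); it is an assumption of this sketch.
    """
    owner_pid = os.getpid()

    def _on_term(signum, frame):
        # A forked child inherits this handler, so compare pids and skip
        # the hook when we are not the process that installed it.
        if os.getpid() != owner_pid:
            return
        on_kill()

    return _on_term


def install(on_kill):
    # signal.signal() must be called from the main thread.
    signal.signal(signal.SIGTERM, make_term_handler(on_kill))
```

Whether a forked child should silently ignore the signal or exit instead is a separate design choice that this sketch does not settle.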
##########
task-sdk/src/airflow/sdk/execution_time/task_runner.py:
##########
@@ -877,6 +879,11 @@ def run(
assert ti.task is not None
assert isinstance(ti.task, BaseOperator)
+ def _on_kill(signum, frame):
+ ti.task.on_kill()
+
+ signal.signal(signal.SIGTERM, _on_kill)
Review Comment:
Nit, but an important one (the handler runs on SIGTERM, not SIGKILL, so the name should say so):
```suggestion
def _on_term(signum, frame):
    ti.task.on_kill()

signal.signal(signal.SIGTERM, _on_term)
```
##########
task-sdk/tests/task_sdk/execution_time/test_supervisor.py:
##########
@@ -338,6 +339,83 @@ def subprocess_main():
]
)
+ def test_on_kill_hook_called_when_sigkilled(
+ self,
+ client_with_ti_start,
+ mocked_parse,
+ make_ti_context,
+ mock_supervisor_comms,
+ create_runtime_ti,
+ make_ti_context_dict,
+ capfd,
+ ):
+ main_pid = os.getpid()
+ ti_id = "4d828a62-a417-4936-a7a6-2b3fabacecab"
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == f"/task-instances/{ti_id}/heartbeat":
+ return httpx.Response(
+ status_code=409,
+ json={
+ "detail": {
+ "reason": "not_running",
+ "message": "TI is no longer in the 'running' state. Task state might be externally set and task should terminate",
+ "current_state": "failed",
+ }
+ },
+ )
+ if request.url.path == f"/task-instances/{ti_id}/run":
+ return httpx.Response(200, json=make_ti_context_dict())
+ return httpx.Response(status_code=204)
+
+ def subprocess_main():
+ # Ensure we follow the "protocol" and get the startup message before we do anything
+ CommsDecoder()._get_response()
+
+ class CustomOperator(BaseOperator):
+ def execute(self, context):
+ for i in range(1000):
+ print(f"Iteration {i}")
+ sleep(1)
+
+ def on_kill(self) -> None:
+ print("On kill hook called!")
+
+ task = CustomOperator(task_id="print-params")
+ runtime_ti = create_runtime_ti(
+ dag_id="c",
+ task=task,
+ conf={
+ "x": 3,
+ "text": "Hello World!",
+ "flag": False,
+ "a_simple_list": ["one", "two", "three", "actually one value is made per line"],
+ },
+ )
+ run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+
+ assert os.getpid() != main_pid
+ os.kill(os.getpid(), signal.SIGTERM)
Review Comment:
This might have timing issues: because this is the last statement in the
function, it's possible (if unlikely) that the process will clean up and exit
before the signal is handled.
Something like this should protect against it?
```suggestion
os.kill(os.getpid(), signal.SIGTERM)
# Ensure that the signal is serviced before we finish and exit the subprocess.
sleep(0.5)
```
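As an alternative to a fixed `sleep(0.5)`, the wait could be bounded by a flag that the handler sets. This is only a sketch under assumptions: `wait_until` and `send_term_and_wait` are hypothetical helpers, not part of the PR or the Task SDK, and the signal part is POSIX-only and must run in the main thread.

```python
import os
import signal
import time


def wait_until(predicate, timeout=0.5, interval=0.01):
    """Poll predicate() until it returns true or `timeout` seconds pass."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if predicate():
            return True
        time.sleep(interval)
    # One last check so a flag set during the final sleep is not missed.
    return bool(predicate())


def send_term_and_wait(timeout=0.5):
    """Send SIGTERM to this process and wait until the handler has run."""
    state = {"handled": False}

    def _on_term(signum, frame):
        # Only flip a plain flag here; keep the handler trivial.
        state["handled"] = True

    previous = signal.signal(signal.SIGTERM, _on_term)
    try:
        os.kill(os.getpid(), signal.SIGTERM)
        return wait_until(lambda: state["handled"], timeout=timeout)
    finally:
        # Restore whatever handler was installed before.
        signal.signal(signal.SIGTERM, previous)
```

The polling loop returns as soon as the flag is set, so the common case does not pay the full timeout.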