This is an automated email from the ASF dual-hosted git repository. vatsrahul1001 pushed a commit to branch backport-173c2a1-v3-2-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 65504bd1d4fac592bdfc1e3ddfd1d46f9ce8d957 Author: Jarek Potiuk <[email protected]> AuthorDate: Tue May 19 20:35:42 2026 +0200 Recover stuck TIs when direct terminal-state API call fails (#66574) * Recover stuck TIs when direct terminal-state API call fails The supervisor's _handle_request for SucceedTask, RetryTask, DeferTask, and RescheduleTask set _terminal_state BEFORE calling the matching client.task_instances.{succeed,retry,defer,reschedule}() API. If that API call raised (transient network blip, server 5xx, etc.), _terminal_state was set on the supervisor but the server never saw the transition. The supervisor's update_task_state_if_needed then saw final_state in STATES_SENT_DIRECTLY and short-circuited the recovery finish() call -- leaving the TaskInstance stuck RUNNING on the server forever, blocking downstream dependencies and triggering false alerts. Two-part fix: 1. Make the direct API call FIRST. Only set _terminal_state and the new _terminal_state_synced_to_server flag after the call returns successfully. If the API raises, both stay unset and the exception propagates to handle_requests, where the existing catch-all sends an ErrorResponse to the task subprocess. 2. Have update_task_state_if_needed always call finish() when _terminal_state_synced_to_server is False, regardless of what final_state happens to return. The finish() API takes the state value, so a SUCCESS / DEFERRED / etc. transition that originally failed is re-attempted via finish() on subprocess exit. Pre-existing semantics for the no-direct-API states (FAILED, UP_FOR_RETRY without RetryTask, etc.) preserved -- those land in the same finish() branch. Tests added: - _terminal_state not set when succeed() raises. - update_task_state_if_needed calls finish() when synced flag is False, even with final_state == SUCCESS. - update_task_state_if_needed skips finish() when synced flag is True (preserves the existing happy-path optimisation). [Note: the description above reflects the first iteration of this change. The diff in this commit does not contain a _terminal_state_synced_to_server flag; instead it keeps the undelivered message in _pending_terminal_state_msg and, on subprocess exit, update_task_state_if_needed replays the dedicated succeed / retry / defer / reschedule call via _replay_pending_terminal_state_msg rather than falling back to finish(). The tests in the diff likewise assert that finish() is NOT called during recovery.] 
Reported by the L3 ASVS sweep at apache/tooling-agents#24 (FINDING-007). * Refactor terminal-state dispatch and parametrize tests across all 4 states Address review feedback on #66574: - Extract `_send_terminal_state_msg` helper so the per-msg-type dispatch for succeed / retry / defer / reschedule lives in one place. Both `_handle_request` and `_replay_pending_terminal_state_msg` now go through it instead of duplicating the four-branch isinstance chain. - Parametrize the two recovery tests over all four terminal-state message types (was only Succeed + Defer); add UP_FOR_RETRY and UP_FOR_RESCHEDULE coverage. * Narrow _pending_terminal_state_msg type to satisfy mypy The field was annotated as BaseModel | None, but _send_terminal_state_msg expects SucceedTask | RetryTask | DeferTask | RescheduleTask. mypy couldn't prove the narrowing at the _replay_pending_terminal_state_msg call site. Tighten the field type to the exact union the setter assigns and the consumer accepts. --------- Co-authored-by: vatsrahul1001 <[email protected]> Co-authored-by: Rahul Vats <[email protected]> (cherry picked from commit 173c2a1806dd087272ec287fb923917630ef8f81) --- .../src/airflow/sdk/execution_time/supervisor.py | 110 +++++++++++++---- .../task_sdk/execution_time/test_supervisor.py | 131 +++++++++++++++++++++ 2 files changed, 219 insertions(+), 22 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index cd68dc85255..a9b12dc1521 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1089,6 +1089,18 @@ class ActivitySubprocess(WatchedSubprocess): _terminal_state: str | None = attrs.field(default=None, init=False) _final_state: str | None = attrs.field(default=None, init=False) + # The terminal-state message currently being processed by `_handle_request`, + # captured BEFORE the dedicated API call (succeed / retry / defer / + # 
reschedule). If the API call raises (network blip, server 5xx, etc.), + # this attribute stays set and the dispatcher in + # `update_task_state_if_needed` re-issues the matching API call on + # subprocess exit — re-attempting the original transition rather than + # falling back to `finish()`, which doesn't accept SUCCESS / DEFERRED / + # SERVER_TERMINATED on the server side. Cleared (and `_terminal_state` + # set) only after the API call returns successfully. + _pending_terminal_state_msg: SucceedTask | RetryTask | DeferTask | RescheduleTask | None = attrs.field( + default=None, init=False + ) _last_successful_heartbeat: float = attrs.field(default=0, init=False) _last_heartbeat_attempt: float = attrs.field(default=0, init=False) @@ -1206,10 +1218,23 @@ class ActivitySubprocess(WatchedSubprocess): return self._exit_code def update_task_state_if_needed(self): - # If the process has finished non-directly patched state (directly means deferred, reschedule, etc.), - # update the state of the TaskInstance to reflect the final state of the process. - # For states like `deferred`, `up_for_reschedule`, the process will exit with 0, but the state will be updated - # by the subprocess in the `handle_requests` method. + # If a direct-state API call (succeed / retry / defer / reschedule) + # was attempted but raised, `_pending_terminal_state_msg` still holds + # the original request. Re-issue the matching dedicated API call so + # the server learns the terminal state we couldn't deliver earlier. + # Without this recovery, a transient API failure during the direct + # call would leave the TI stuck RUNNING on the server — `finish()` + # cannot substitute because the server-side `finish` endpoint does + # not accept SUCCESS / DEFERRED / SERVER_TERMINATED transitions. + if self._pending_terminal_state_msg is not None: + self._replay_pending_terminal_state_msg() + return + + # If the process has finished a non-directly-patched state (e.g. 
+ # FAILED, UP_FOR_RETRY without RetryTask), `finish()` is the + # dedicated endpoint for those transitions. For states already in + # STATES_SENT_DIRECTLY whose direct API call succeeded, no further + # action is needed. if self.final_state not in STATES_SENT_DIRECTLY: self.client.task_instances.finish( id=self.id, @@ -1218,6 +1243,58 @@ class ActivitySubprocess(WatchedSubprocess): rendered_map_index=self._rendered_map_index, ) + def _send_terminal_state_msg(self, msg: SucceedTask | RetryTask | DeferTask | RescheduleTask) -> None: + # Capture the message BEFORE the API call so the recovery dispatcher + # in `update_task_state_if_needed` can re-issue it if the call raises + # (network blip, transient server 5xx). Clear the pending slot and + # record the resulting state only after the call returns successfully. + self._pending_terminal_state_msg = msg + if isinstance(msg, SucceedTask): + self.client.task_instances.succeed( + id=self.id, + when=msg.end_date, + task_outlets=msg.task_outlets, + outlet_events=msg.outlet_events, + rendered_map_index=self._rendered_map_index, + ) + self._terminal_state = msg.state + elif isinstance(msg, RetryTask): + self.client.task_instances.retry( + id=self.id, + end_date=msg.end_date, + rendered_map_index=self._rendered_map_index, + retry_delay_seconds=getattr(msg, "retry_delay_seconds", None), + retry_reason=getattr(msg, "retry_reason", None), + ) + self._terminal_state = msg.state + elif isinstance(msg, DeferTask): + self.client.task_instances.defer(self.id, msg) + self._terminal_state = TaskInstanceState.DEFERRED + elif isinstance(msg, RescheduleTask): + self.client.task_instances.reschedule(self.id, msg) + self._terminal_state = TaskInstanceState.UP_FOR_RESCHEDULE + self._pending_terminal_state_msg = None + + def _replay_pending_terminal_state_msg(self) -> None: + """ + Re-issue the dedicated API call for an unsynced terminal-state msg. 
+ + Best-effort — if the second attempt also fails the exception is + logged and we move on; the supervisor's overall failure handling + (heartbeat, exit-code reporting) will eventually surface the issue. + """ + msg = self._pending_terminal_state_msg + if msg is None: + return + try: + self._send_terminal_state_msg(msg) + except Exception: + log.exception( + "Recovery retry of terminal-state API call failed; TI may be stuck on the server", + ti_id=self.id, + msg_type=type(msg).__name__, + ) + def _upload_logs(self): """ Upload all log files found to the remote storage. @@ -1389,29 +1466,20 @@ class ActivitySubprocess(WatchedSubprocess): resp: BaseModel | None = None dump_opts = {} if isinstance(msg, TaskState): + # No direct API call here — the recovery path in + # `update_task_state_if_needed` will call `finish()` for + # non-direct states (FAILED, etc.) once the subprocess exits. self._terminal_state = msg.state self._task_end_time_monotonic = time.monotonic() self._rendered_map_index = msg.rendered_map_index elif isinstance(msg, SucceedTask): - self._terminal_state = msg.state self._task_end_time_monotonic = time.monotonic() self._rendered_map_index = msg.rendered_map_index - self.client.task_instances.succeed( - id=self.id, - when=msg.end_date, - task_outlets=msg.task_outlets, - outlet_events=msg.outlet_events, - rendered_map_index=self._rendered_map_index, - ) + self._send_terminal_state_msg(msg) elif isinstance(msg, RetryTask): - self._terminal_state = msg.state self._task_end_time_monotonic = time.monotonic() self._rendered_map_index = msg.rendered_map_index - self.client.task_instances.retry( - id=self.id, - end_date=msg.end_date, - rendered_map_index=self._rendered_map_index, - ) + self._send_terminal_state_msg(msg) elif isinstance(msg, GetConnection): conn = self.client.connections.get(msg.conn_id) if isinstance(conn, ConnectionResponse): @@ -1463,12 +1531,10 @@ class ActivitySubprocess(WatchedSubprocess): ) resp = 
XComSequenceSliceResult.from_response(xcoms) elif isinstance(msg, DeferTask): - self._terminal_state = TaskInstanceState.DEFERRED self._rendered_map_index = msg.rendered_map_index - self.client.task_instances.defer(self.id, msg) + self._send_terminal_state_msg(msg) elif isinstance(msg, RescheduleTask): - self._terminal_state = TaskInstanceState.UP_FOR_RESCHEDULE - self.client.task_instances.reschedule(self.id, msg) + self._send_terminal_state_msg(msg) elif isinstance(msg, SkipDownstreamTasks): self.client.task_instances.skip_downstream_tasks(self.id, msg) elif isinstance(msg, SetXCom): diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index f61b257b71b..05aa87c5cfb 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -2793,6 +2793,137 @@ class TestHandleRequest: # Should not raise StopIteration (which would mean the loop crashed). 
generator.send(req2) + @pytest.mark.parametrize( + ("msg", "api_method", "expected_state"), + [ + pytest.param( + SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")), + "succeed", + TaskInstanceState.SUCCESS, + id="succeed", + ), + pytest.param( + RetryTask(end_date=timezone.parse("2024-10-31T12:00:00Z")), + "retry", + TaskInstanceState.UP_FOR_RETRY, + id="retry", + ), + pytest.param( + DeferTask( + next_method="execute_complete", + classpath="airflow.providers.standard.triggers.external_task.WorkflowTrigger", + trigger_kwargs={}, + ), + "defer", + TaskInstanceState.DEFERRED, + id="defer", + ), + pytest.param( + RescheduleTask( + reschedule_date=timezone.parse("2024-10-31T12:00:00Z"), + end_date=timezone.parse("2024-10-31T12:00:00Z"), + ), + "reschedule", + TaskInstanceState.UP_FOR_RESCHEDULE, + id="reschedule", + ), + ], + ) + def test_terminal_state_not_set_when_direct_api_fails( + self, watched_subprocess, mocker, msg, api_method, expected_state + ): + """`_terminal_state` must NOT be set when the dedicated terminal-state + API raises. + + The original message is captured in `_pending_terminal_state_msg` + BEFORE the API call so the recovery dispatcher in + `update_task_state_if_needed` can re-issue it on subprocess exit. + Covers all four terminal-state message types. + """ + watched_subprocess, _ = watched_subprocess + setattr( + watched_subprocess.client.task_instances, + api_method, + mocker.Mock(side_effect=httpx.ConnectError("connection refused")), + ) + + with pytest.raises(httpx.ConnectError): + watched_subprocess._handle_request(msg, mocker.Mock(), req_id=1) + + assert watched_subprocess._terminal_state is None + # Pending msg preserved so the recovery dispatcher can re-issue. 
+ assert watched_subprocess._pending_terminal_state_msg is msg + + @pytest.mark.parametrize( + ("msg", "api_method", "expected_state"), + [ + pytest.param( + SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")), + "succeed", + TaskInstanceState.SUCCESS, + id="succeed", + ), + pytest.param( + RetryTask(end_date=timezone.parse("2024-10-31T12:00:00Z")), + "retry", + TaskInstanceState.UP_FOR_RETRY, + id="retry", + ), + pytest.param( + DeferTask( + next_method="execute_complete", + classpath="airflow.providers.standard.triggers.external_task.WorkflowTrigger", + trigger_kwargs={}, + ), + "defer", + TaskInstanceState.DEFERRED, + id="defer", + ), + pytest.param( + RescheduleTask( + reschedule_date=timezone.parse("2024-10-31T12:00:00Z"), + end_date=timezone.parse("2024-10-31T12:00:00Z"), + ), + "reschedule", + TaskInstanceState.UP_FOR_RESCHEDULE, + id="reschedule", + ), + ], + ) + def test_update_task_state_replays_pending_terminal_state_call( + self, watched_subprocess, mocker, msg, api_method, expected_state + ): + """If a direct terminal-state API call was attempted and raised, the + recovery dispatcher must re-issue the dedicated endpoint (not + `finish()`, which the server-side endpoint refuses for SUCCESS / + DEFERRED / SERVER_TERMINATED). Covers all four message types. + """ + watched_subprocess, _ = watched_subprocess + watched_subprocess._exit_code = 0 + # Simulate the failure scenario: original API call raised, msg preserved. + watched_subprocess._pending_terminal_state_msg = msg + + watched_subprocess.update_task_state_if_needed() + + # Recovery re-issues the dedicated endpoint, NOT finish(). 
+ getattr(watched_subprocess.client.task_instances, api_method).assert_called_once() + watched_subprocess.client.task_instances.finish.assert_not_called() + assert watched_subprocess._terminal_state == expected_state + assert watched_subprocess._pending_terminal_state_msg is None + + def test_update_task_state_no_recovery_without_pending_msg(self, watched_subprocess, mocker): + """No replay when nothing was pending — preserves the original + STATES_SENT_DIRECTLY short-circuit for the happy path.""" + watched_subprocess, _ = watched_subprocess + watched_subprocess._exit_code = 0 + watched_subprocess._terminal_state = TaskInstanceState.SUCCESS + watched_subprocess._pending_terminal_state_msg = None + + watched_subprocess.update_task_state_if_needed() + + watched_subprocess.client.task_instances.finish.assert_not_called() + watched_subprocess.client.task_instances.succeed.assert_not_called() + class TestSetSupervisorComms: class DummyComms:
