This is an automated email from the ASF dual-hosted git repository.

kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1e8be208e3f Honor retry_policy on non-deferrable TriggerDagRunOperator wait failures (#68254)
1e8be208e3f is described below

commit 1e8be208e3f7b8eb3eb510ed1d94b7a947bff689
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Jun 9 03:41:22 2026 +0100

    Honor retry_policy on non-deferrable TriggerDagRunOperator wait failures (#68254)
    
    When a non-deferrable TriggerDagRunOperator(wait_for_completion=True) sees the
    triggered DagRun reach a failed state, the failure branch called
    _handle_current_task_failed directly, which only checks the standard retry count
    and never consults a configured retry_policy. The deferrable path raises
    AirflowException, which flows through run()'s exception handler and is evaluated
    by _apply_retry_policy_or_default, so a retry_policy was honored there but not on
    the polling path.
    
    Route the non-deferrable failed-state branch through
    _apply_retry_policy_or_default with a synthesized AirflowException mirroring the
    deferrable path, so the policy gets a vote on both. With no retry_policy set the
    behavior is unchanged (it falls back to _handle_current_task_failed).
---
 .../src/airflow/sdk/execution_time/task_runner.py  |  12 ++-
 .../task_sdk/execution_time/test_task_runner.py    | 101 +++++++++++++++++++++
 2 files changed, 112 insertions(+), 1 deletion(-)

diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index beda0728c7d..5ab862ef6e4 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -1815,7 +1815,17 @@ def _handle_trigger_dag_run(
                 log.error(
                     "DagRun finished with failed state.", 
dag_id=drte.trigger_dag_id, state=comms_msg.state
                 )
-                return _handle_current_task_failed(ti)
+                # Mirror the deferrable path (DagStateTrigger -> 
execute_complete raises
+                # AirflowException), which flows through run()'s exception 
handler and
+                # therefore honours a configured retry_policy. Synthesize the 
same
+                # exception here so non-deferrable waits evaluate the policy 
too,
+                # falling back to the standard retry-count check when none is 
set.
+                return _apply_retry_policy_or_default(
+                    ti,
+                    AirflowException(f"{drte.trigger_dag_id} failed with 
failed state {comms_msg.state}"),
+                    log,
+                    context,
+                )
             if comms_msg.state in drte.allowed_states:
                 log.info(
                     "DagRun finished with allowed state.", 
dag_id=drte.trigger_dag_id, state=comms_msg.state
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py 
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index f3b07482b4f..29bc2558c0c 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -50,6 +50,7 @@ from airflow._shared.observability.traces import (
 from airflow.api_fastapi.execution_api.routes.task_instances import 
_emit_task_span
 from airflow.listeners import hookimpl
 from airflow.providers.standard.operators.python import PythonOperator
+from airflow.providers.standard.operators.trigger_dagrun import 
TriggerDagRunOperator
 from airflow.sdk import (
     DAG,
     BaseOperator,
@@ -76,6 +77,7 @@ from airflow.sdk.bases.xcom import BaseXCom
 from airflow.sdk.definitions._internal.types import NOTSET, 
SET_DURING_EXECUTION, is_arg_set
 from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, 
AssetUriRef, Dataset, Model
 from airflow.sdk.definitions.param import DagParam
+from airflow.sdk.definitions.retry_policy import ExceptionRetryPolicy, 
RetryAction, RetryRule
 from airflow.sdk.exceptions import (
     AirflowException,
     AirflowFailException,
@@ -5076,6 +5078,105 @@ class TestTriggerDagRunOperator:
 
         assert state == TaskInstanceState.UP_FOR_RETRY
 
+    def 
test_handle_trigger_dag_run_wait_for_completion_failed_state_retry_policy_fail(
+        self, create_runtime_ti, mock_supervisor_comms
+    ):
+        """A retry_policy returning FAIL fails the task even when retries are 
still available.
+
+        This mirrors the deferrable path, where DagStateTrigger -> 
execute_complete raises
+        AirflowException and the configured retry_policy is consulted via 
run()'s handler.
+        """
+        task = TriggerDagRunOperator(
+            task_id="test_task",
+            trigger_dag_id="test_dag",
+            trigger_run_id="test_run_id",
+            poke_interval=5,
+            wait_for_completion=True,
+            allowed_states=[DagRunState.SUCCESS],
+            failed_states=[DagRunState.FAILED],
+            deferrable=False,
+            retry_policy=ExceptionRetryPolicy(
+                rules=[
+                    RetryRule(exception=AirflowException, 
action=RetryAction.FAIL, reason="not retryable")
+                ],
+            ),
+        )
+        ti = create_runtime_ti(
+            
dag_id="test_handle_trigger_dag_run_wait_for_completion_failed_state_retry_policy_fail",
+            run_id="test_run",
+            task=task,
+            should_retry=True,
+        )
+
+        def _send_side_effect(*args, **kwargs):
+            msg = kwargs.get("msg")
+            if msg is None and args:
+                msg = args[0]
+            if isinstance(msg, TriggerDagRun):
+                return OKResponse(ok=True)
+            if isinstance(msg, GetDagRunState):
+                return DagRunStateResult(state=DagRunState.FAILED)
+            return None
+
+        mock_supervisor_comms.send.side_effect = _send_side_effect
+
+        log = mock.MagicMock()
+        with mock.patch("time.sleep", return_value=None):
+            state, _, _ = run(ti, ti.get_template_context(), log)
+
+        assert state == TaskInstanceState.FAILED
+
+    def 
test_handle_trigger_dag_run_wait_for_completion_failed_state_retry_policy_delay(
+        self, create_runtime_ti, mock_supervisor_comms
+    ):
+        """A retry_policy returning RETRY forwards its custom delay and reason 
on the RetryTask."""
+        task = TriggerDagRunOperator(
+            task_id="test_task",
+            trigger_dag_id="test_dag",
+            trigger_run_id="test_run_id",
+            poke_interval=5,
+            wait_for_completion=True,
+            allowed_states=[DagRunState.SUCCESS],
+            failed_states=[DagRunState.FAILED],
+            deferrable=False,
+            retry_policy=ExceptionRetryPolicy(
+                rules=[
+                    RetryRule(
+                        exception=AirflowException,
+                        action=RetryAction.RETRY,
+                        retry_delay=timedelta(minutes=7),
+                        reason="backing off",
+                    ),
+                ],
+            ),
+        )
+        ti = create_runtime_ti(
+            
dag_id="test_handle_trigger_dag_run_wait_for_completion_failed_state_retry_policy_delay",
+            run_id="test_run",
+            task=task,
+            should_retry=True,
+        )
+
+        def _send_side_effect(*args, **kwargs):
+            msg = kwargs.get("msg")
+            if msg is None and args:
+                msg = args[0]
+            if isinstance(msg, TriggerDagRun):
+                return OKResponse(ok=True)
+            if isinstance(msg, GetDagRunState):
+                return DagRunStateResult(state=DagRunState.FAILED)
+            return None
+
+        mock_supervisor_comms.send.side_effect = _send_side_effect
+
+        log = mock.MagicMock()
+        with mock.patch("time.sleep", return_value=None):
+            state, msg, _ = run(ti, ti.get_template_context(), log)
+
+        assert state == TaskInstanceState.UP_FOR_RETRY
+        assert msg.retry_delay_seconds == timedelta(minutes=7).total_seconds()
+        assert msg.retry_reason == "backing off"
+
     @pytest.mark.parametrize(
         ("allowed_states", "failed_states", "intermediate_state"),
         [

Reply via email to