This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4dc98451ba8 fix: handle unmapped task deadlock when upstream tasks are removed (#62034)
4dc98451ba8 is described below
commit 4dc98451ba83a956a8c96072df171eddcf1b8775
Author: Zhen-Lun (Kevin) Hong <[email protected]>
AuthorDate: Sun Jun 14 19:36:14 2026 +0800
fix: handle unmapped task deadlock when upstream tasks are removed (#62034)
* fix: prevent deadlock when the number of mapped tasks is reduced
* chore: add unit tests
* chore: add test to check rerunning with an upstream task removed
* add unit test of unmapped tasks
---
.../src/airflow/ti_deps/deps/trigger_rule_dep.py | 23 +---
airflow-core/tests/unit/models/test_dagrun.py | 121 +++++++++++++++++++++
.../unit/ti_deps/deps/test_trigger_rule_dep.py | 90 +++++++++++++++
3 files changed, 217 insertions(+), 17 deletions(-)
diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
index 9f9f9bd77aa..54d5a4d9309 100644
--- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -503,9 +503,7 @@ class TriggerRuleDep(BaseTIDep):
)
)
elif trigger_rule == TR.ALL_SUCCESS:
- num_failures = upstream - success
- if ti.map_index > -1:
- num_failures -= removed
+ num_failures = upstream - success - removed
if num_failures > 0:
yield self._failing_status(
reason=(
@@ -516,9 +514,7 @@ class TriggerRuleDep(BaseTIDep):
)
)
elif trigger_rule == TR.ALL_FAILED:
- num_success = upstream - failed - upstream_failed
- if ti.map_index > -1:
- num_success -= removed
+ num_success = upstream - failed - upstream_failed - removed
if num_success > 0:
yield self._failing_status(
reason=(
@@ -539,9 +535,7 @@ class TriggerRuleDep(BaseTIDep):
)
)
elif trigger_rule == TR.NONE_FAILED:
- num_failures = upstream - success - skipped
- if ti.map_index > -1:
- num_failures -= removed
+ num_failures = upstream - success - skipped - removed
if num_failures > 0:
yield self._failing_status(
reason=(
@@ -552,9 +546,7 @@ class TriggerRuleDep(BaseTIDep):
)
)
elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
- num_failures = upstream - success - skipped
- if ti.map_index > -1:
- num_failures -= removed
+ num_failures = upstream - success - skipped - removed
if num_failures > 0:
yield self._failing_status(
reason=(
@@ -614,11 +606,8 @@ class TriggerRuleDep(BaseTIDep):
)
elif trigger_rule == TR.ALL_DONE_MIN_ONE_SUCCESS:
# For this trigger rule, skipped tasks are not considered "done"
- non_skipped_done = success + failed + upstream_failed + removed
- non_skipped_upstream = upstream - skipped
- if ti.map_index > -1:
- non_skipped_upstream -= removed
- non_skipped_done -= removed
+ non_skipped_done = success + failed + upstream_failed
+ non_skipped_upstream = upstream - skipped - removed
if skipped > 0:
yield self._failing_status(
diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py
index a8a831fb6fa..a8d384d9480 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -3229,6 +3229,127 @@ def test_mapped_task_rerun_with_different_length_of_args(session, dag_maker, rer
assert len(success_tis) == rerun_length
+def test_mapped_task_length_reduction_rerun_downstream_not_deadlocked(session, dag_maker):  # Regression test (#62034): a rerun that maps fewer "work" TIs must still schedule "finish".
+ @task
+ def producer():
+ context = get_current_context()
+ if context["ti"].try_number == 0:  # first run maps 3 items; the rerun (try_number bumped below) maps 2
+ return [i for i in range(3)]
+ return [i for i in range(2)]
+
+ @task
+ def work(arg):
+ return arg
+
+ @task
+ def finish(data):
+ return sum(data)
+
+ def _task_ids(tis):  # (task_id, map_index) pairs, for compact assertions
+ return [(ti.task_id, ti.map_index) for ti in tis]
+
+ with dag_maker(session=session):
+ produced = producer()
+ mapped = work.expand(arg=produced)
+ done = finish(produced)
+ mapped >> done  # the unmapped "finish" also waits on every mapped "work" TI
+
+ dr: DagRun = dag_maker.create_dagrun()
+
+ # First run with 3 mapped task instances.
+ dag_maker.run_ti("producer", dr)
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ assert _task_ids(decision.schedulable_tis) == [("work", 0), ("work", 1), ("work", 2)]
+
+ for ti in decision.schedulable_tis:
+ dag_maker.run_ti(ti.task_id, dr, map_index=ti.map_index)
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ assert _task_ids(decision.schedulable_tis) == [("finish", -1)]
+ dag_maker.run_ti("finish", dr)
+
+ # Clear and rerun with one fewer mapped task instance.
+ clear_task_instances(dr.get_task_instances(session=session), session=session)
+ ti = dr.get_task_instance(task_id="producer", session=session)
+ ti.try_number += 1  # steers producer() into its 2-item branch
+ session.merge(ti)
+
+ dag_maker.run_ti("producer", dr)
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ assert _task_ids(decision.schedulable_tis) == [("work", 0), ("work", 1)]
+
+ mapped_states = session.execute(
+ select(TI.map_index, TI.state)
+ .where(TI.task_id == "work", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
+ .order_by(TI.map_index)
+ ).all()
+ assert mapped_states == [
+ (0, State.NONE),
+ (1, State.NONE),
+ (2, TaskInstanceState.REMOVED),  # map index left over from the 3-item first run
+ ]
+
+ for ti in decision.schedulable_tis:
+ dag_maker.run_ti(ti.task_id, dr, map_index=ti.map_index)
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ assert _task_ids(decision.schedulable_tis) == [("finish", -1)]  # REMOVED work[2] must not block "finish"
+
+ dag_maker.run_ti("finish", dr)
+ finish_ti = dr.get_task_instance(task_id="finish", map_index=-1, session=session)
+ assert finish_ti
+ assert finish_ti.state == TaskInstanceState.SUCCESS
+
+
+def test_rerun_with_upstream_task_removed(session, dag_maker):  # Regression test (#62034): an upstream dropped from the DAG must not block its downstream on rerun.
+ def _task_ids(tis):  # (task_id, map_index) pairs, for compact assertions
+ return [(ti.task_id, ti.map_index) for ti in tis]
+
+ with dag_maker("test", session=session):
+ upstream_1 = EmptyOperator(task_id="upstream_1")
+ upstream_2 = EmptyOperator(task_id="upstream_2")
+ downstream = EmptyOperator(task_id="downstream")
+ [upstream_1, upstream_2] >> downstream
+
+ dr: DagRun = dag_maker.create_dagrun()
+
+ dag_maker.run_ti("upstream_1", dr)
+ dag_maker.run_ti("upstream_2", dr)
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ assert _task_ids(decision.schedulable_tis) == [("downstream", -1)]
+
+ dag_maker.run_ti("downstream", dr)
+ dr.update_state(session=session)
+ assert dr.state == DagRunState.SUCCESS
+
+ # Rerun with upstream_1 removed
+ with dag_maker("test", session=session, serialized=True) as dag:  # same dag_id without upstream_1 -> DAG version 2
+ upstream_2 = EmptyOperator(task_id="upstream_2")
+ downstream = EmptyOperator(task_id="downstream")
+ upstream_2 >> downstream
+
+ latest_version = DagVersion.get_latest_version(dag.dag_id)
+ assert latest_version.version_number == 2
+
+ clear_task_instances(
+ dr.get_task_instances(session=session),
+ session=session,
+ run_on_latest_version=True,  # rerun on version 2, which no longer has upstream_1
+ )
+
+ upstream_1 = dr.get_task_instance(task_id="upstream_1", map_index=-1, session=session)
+ assert upstream_1.state == TaskInstanceState.REMOVED
+
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ assert _task_ids(decision.schedulable_tis) == [("upstream_2", -1)]
+
+ dag_maker.run_ti("upstream_2", dr)
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ assert _task_ids(decision.schedulable_tis) == [("downstream", -1)]  # REMOVED upstream_1 must not block "downstream"
+
+ dag_maker.run_ti("downstream", dr)
+ dr.update_state(session=session)
+ assert dr.state == DagRunState.SUCCESS
+
+
def test_operator_mapped_task_group_receives_value(dag_maker, session):
with dag_maker(session=session):
diff --git a/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py
index bd4a576f9f1..c5f805cd553 100644
--- a/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py
+++ b/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py
@@ -1561,6 +1561,96 @@ class TestTriggerRuleDep:
expected_ti_state=expected_ti_state,
)
+ @pytest.mark.parametrize("flag_upstream_failed", [True, False])
+ @pytest.mark.parametrize(
+ ("trigger_rule", "upstream_states"),
+ [
+ (
+ TriggerRule.ALL_SUCCESS,
+ _UpstreamTIStates(
+ success=3,
+ skipped=0,
+ failed=0,
+ upstream_failed=0,
+ removed=2,
+ done=5,
+ skipped_setup=0,
+ success_setup=0,
+ ),
+ ),
+ (
+ TriggerRule.ALL_FAILED,
+ _UpstreamTIStates(
+ success=0,
+ skipped=0,
+ failed=3,
+ upstream_failed=0,
+ removed=2,
+ done=5,
+ skipped_setup=0,
+ success_setup=0,
+ ),
+ ),
+ (
+ TriggerRule.NONE_FAILED,
+ _UpstreamTIStates(
+ success=3,
+ skipped=0,
+ failed=0,
+ upstream_failed=0,
+ removed=2,
+ done=5,
+ skipped_setup=0,
+ success_setup=0,
+ ),
+ ),
+ (
+ TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,
+ _UpstreamTIStates(
+ success=3,
+ skipped=0,
+ failed=0,
+ upstream_failed=0,
+ removed=2,
+ done=5,
+ skipped_setup=0,
+ success_setup=0,
+ ),
+ ),
+ (
+ TriggerRule.ALL_DONE_MIN_ONE_SUCCESS,
+ _UpstreamTIStates(
+ success=3,
+ skipped=0,
+ failed=0,
+ upstream_failed=0,
+ removed=2,
+ done=5,
+ skipped_setup=0,
+ success_setup=0,
+ ),
+ ),
+ ],
+ )
+ def test_non_mapped_task_ignores_removed_upstream_tis(
+ self,
+ monkeypatch,
+ session,
+ get_task_instance,
+ flag_upstream_failed,
+ trigger_rule,
+ upstream_states,
+ ):
+ """
+ Non-mapped trigger-rule checks should exclude removed upstream task instances: each case reports 3 upstreams in the state the rule needs plus 2 REMOVED ones (done=5), and the REMOVED ones must not count against the rule.
+ """
+ ti = get_task_instance(
+ trigger_rule,
+ normal_tasks=["upstream_1", "upstream_2", "upstream_3", "upstream_4", "upstream_5"],
+ )
+ monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: upstream_states)  # stub: calculate() now always returns the counts above
+ _test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed)
+
def test_upstream_in_mapped_group_triggers_only_relevant(dag_maker, session):
from airflow.sdk import task, task_group