This is an automated email from the ASF dual-hosted git repository.
shahar pushed a commit to branch v2-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v2-10-test by this push:
new 4b27c3fb087 [v2-10-test] Fix premature evaluation in mapped task group
(#44937)
4b27c3fb087 is described below
commit 4b27c3fb087a61ca42be8d09b1462b731aaa640d
Author: Shahar Epstein <[email protected]>
AuthorDate: Tue Dec 17 23:07:42 2024 +0200
[v2-10-test] Fix premature evaluation in mapped task group (#44937)
* Fix docstrings and warnings in trigger_rule_dep.py
* Fix premature evaluation of tasks in mapped task group
Co-authored-by: Tzu-ping Chung <[email protected]>
Co-authored-by: Ephraim Anierobi <[email protected]>
* Add newsfragment
---------
Co-authored-by: Tzu-ping Chung <[email protected]>
Co-authored-by: Ephraim Anierobi <[email protected]>
---
airflow/ti_deps/deps/trigger_rule_dep.py | 42 ++++++++++------
newsfragments/44937.bugfix.rst | 1 +
tests/models/test_mappedoperator.py | 86 +++++++++++++++++++++++++++++++-
3 files changed, 113 insertions(+), 16 deletions(-)
diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py
b/airflow/ti_deps/deps/trigger_rule_dep.py
index 76291c8a057..6e00f718be2 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -27,6 +27,7 @@ from sqlalchemy import and_, func, or_, select
from airflow.models.taskinstance import PAST_DEPENDS_MET
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.state import TaskInstanceState
+from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.trigger_rule import TriggerRule as TR
if TYPE_CHECKING:
@@ -63,8 +64,7 @@ class _UpstreamTIStates(NamedTuple):
``counter`` is inclusive of ``setup_counter`` -- e.g. if there are 2
skipped upstreams, one
of which is a setup, then counter will show 2 skipped and setup
counter will show 1.
- :param ti: the ti that we want to calculate deps for
- :param finished_tis: all the finished tasks of the dag_run
+ :param finished_upstreams: all the finished upstreams of the dag_run
"""
counter: dict[str, int] = Counter()
setup_counter: dict[str, int] = Counter()
@@ -143,6 +143,19 @@ class TriggerRuleDep(BaseTIDep):
return ti.task.get_mapped_ti_count(ti.run_id, session=session)
+ def _iter_expansion_dependencies(task_group: MappedTaskGroup) ->
Iterator[str]:
+ from airflow.models.mappedoperator import MappedOperator
+
+ if isinstance(ti.task, MappedOperator):
+ for op in ti.task.iter_mapped_dependencies():
+ yield op.task_id
+ if task_group and task_group.iter_mapped_task_groups():
+ yield from (
+ op.task_id
+ for tg in task_group.iter_mapped_task_groups()
+ for op in tg.iter_mapped_dependencies()
+ )
+
@functools.lru_cache
def _get_relevant_upstream_map_indexes(upstream_id: str) -> int |
range | None:
"""
@@ -156,6 +169,13 @@ class TriggerRuleDep(BaseTIDep):
assert ti.task
assert isinstance(ti.task.dag, DAG)
+ if isinstance(ti.task.task_group, MappedTaskGroup):
+ is_fast_triggered = ti.task.trigger_rule in (TR.ONE_SUCCESS,
TR.ONE_FAILED, TR.ONE_DONE)
+ if is_fast_triggered and upstream_id not in set(
+ _iter_expansion_dependencies(task_group=ti.task.task_group)
+ ):
+ return None
+
try:
expanded_ti_count = _get_expanded_ti_count()
except (NotFullyPopulated, NotMapped):
@@ -217,7 +237,7 @@ class TriggerRuleDep(BaseTIDep):
for upstream_id in relevant_tasks:
map_indexes = _get_relevant_upstream_map_indexes(upstream_id)
if map_indexes is None: # All tis of this upstream are
dependencies.
- yield (TaskInstance.task_id == upstream_id)
+ yield TaskInstance.task_id == upstream_id
continue
# At this point we know we want to depend on only selected tis
# of this upstream task. Since the upstream may not have been
@@ -237,11 +257,9 @@ class TriggerRuleDep(BaseTIDep):
def _evaluate_setup_constraint(*, relevant_setups) ->
Iterator[tuple[TIDepStatus, bool]]:
"""
- Evaluate whether ``ti``'s trigger rule was met.
+ Evaluate whether ``ti``'s trigger rule was met as part of the
setup constraint.
- :param ti: Task instance to evaluate the trigger rule of.
- :param dep_context: The current dependency context.
- :param session: Database session.
+ :param relevant_setups: Relevant setups for the current task
instance.
"""
if TYPE_CHECKING:
assert ti.task
@@ -327,13 +345,7 @@ class TriggerRuleDep(BaseTIDep):
)
def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
- """
- Evaluate whether ``ti``'s trigger rule was met.
-
- :param ti: Task instance to evaluate the trigger rule of.
- :param dep_context: The current dependency context.
- :param session: Database session.
- """
+ """Evaluate whether ``ti``'s trigger rule in direct relatives was
met."""
if TYPE_CHECKING:
assert ti.task
@@ -433,7 +445,7 @@ class TriggerRuleDep(BaseTIDep):
)
if not past_depends_met:
yield self._failing_status(
- reason=("Task should be skipped but the past
depends are not met")
+ reason="Task should be skipped but the past
depends are not met"
)
return
changed = ti.set_state(new_state, session)
diff --git a/newsfragments/44937.bugfix.rst b/newsfragments/44937.bugfix.rst
new file mode 100644
index 00000000000..d50da4de82f
--- /dev/null
+++ b/newsfragments/44937.bugfix.rst
@@ -0,0 +1 @@
+Fix premature evaluation of tasks in mapped task group. The bug originates
in ``TriggerRuleDep`` when dealing with a ``TriggerRule`` that is fast-triggered
(i.e., ``ONE_FAILED``, ``ONE_SUCCESS``, or ``ONE_DONE``). Please note
that at the time of merging, this fix has been applied only for Airflow versions >
2.10.4 and < 3, and should be ported to v3 after merging PR #40460.
diff --git a/tests/models/test_mappedoperator.py
b/tests/models/test_mappedoperator.py
index cf547912fb9..d1e896200c7 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -37,7 +37,7 @@ from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
from airflow.models.xcom_arg import XComArg
from airflow.operators.python import PythonOperator
-from airflow.utils.state import TaskInstanceState
+from airflow.utils.state import State, TaskInstanceState
from airflow.utils.task_group import TaskGroup
from airflow.utils.task_instance_session import
set_current_task_instance_session
from airflow.utils.trigger_rule import TriggerRule
@@ -1784,3 +1784,87 @@ class TestMappedSetupTeardown:
"group.last": {0: "success", 1: "skipped", 2: "success"},
}
assert states == expected
+
+
+def
test_mapped_tasks_in_mapped_task_group_waits_for_upstreams_to_complete(dag_maker,
session):
+ """Test that one failed trigger rule works well in mapped task group"""
+ with dag_maker() as dag:
+
+ @dag.task
+ def t1():
+ return [1, 2, 3]
+
+ @task_group("tg1")
+ def tg1(a):
+ @dag.task()
+ def t2(a):
+ return a
+
+ @dag.task(trigger_rule=TriggerRule.ONE_FAILED)
+ def t3(a):
+ return a
+
+ t2(a) >> t3(a)
+
+ t = t1()
+ tg1.expand(a=t)
+
+ dr = dag_maker.create_dagrun()
+ ti = dr.get_task_instance(task_id="t1")
+ ti.run()
+ dr.task_instance_scheduling_decisions()
+ ti3 = dr.get_task_instance(task_id="tg1.t3")
+ assert not ti3.state
+
+
+def
test_mapped_tasks_in_mapped_task_group_waits_for_upstreams_to_complete__mapped_skip_with_all_success(
+ dag_maker, session
+):
+ with dag_maker():
+
+ @task
+ def make_list():
+ return [4, 42, 2]
+
+ @task
+ def double(n):
+ if n == 42:
+ raise AirflowSkipException("42")
+ return n * 2
+
+ @task
+ def last(n):
+ print(n)
+
+ @task_group
+ def group(n: int) -> None:
+ last(double(n))
+
+ list = make_list()
+ group.expand(n=list)
+
+ dr = dag_maker.create_dagrun()
+
+ def _one_scheduling_decision_iteration() -> dict[tuple[str, int],
TaskInstance]:
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ return {(ti.task_id, ti.map_index): ti for ti in
decision.schedulable_tis}
+
+ tis = _one_scheduling_decision_iteration()
+ tis["make_list", -1].run()
+ assert tis["make_list", -1].state == State.SUCCESS
+
+ tis = _one_scheduling_decision_iteration()
+ tis["group.double", 0].run()
+ tis["group.double", 1].run()
+ tis["group.double", 2].run()
+
+ assert tis["group.double", 0].state == State.SUCCESS
+ assert tis["group.double", 1].state == State.SKIPPED
+ assert tis["group.double", 2].state == State.SUCCESS
+
+ tis = _one_scheduling_decision_iteration()
+ tis["group.last", 0].run()
+ tis["group.last", 2].run()
+ assert tis["group.last", 0].state == State.SUCCESS
+ assert dr.get_task_instance("group.last", map_index=1,
session=session).state == State.SKIPPED
+ assert tis["group.last", 2].state == State.SUCCESS