This is an automated email from the ASF dual-hosted git repository.

shahar pushed a commit to branch v2-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v2-10-test by this push:
     new 4b27c3fb087 [v2-10-test] Fix premature evaluation in mapped task group 
(#44937)
4b27c3fb087 is described below

commit 4b27c3fb087a61ca42be8d09b1462b731aaa640d
Author: Shahar Epstein <[email protected]>
AuthorDate: Tue Dec 17 23:07:42 2024 +0200

    [v2-10-test] Fix premature evaluation in mapped task group (#44937)
    
    * Fix docstrings and warnings in trigger_rule_dep.py
    
    * Fix premature evaluation of tasks in mapped task group
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
    Co-authored-by: Ephraim Anierobi <[email protected]>
    
    * Add newsfragment
    
    ---------
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
    Co-authored-by: Ephraim Anierobi <[email protected]>
---
 airflow/ti_deps/deps/trigger_rule_dep.py | 42 ++++++++++------
 newsfragments/44937.bugfix.rst           |  1 +
 tests/models/test_mappedoperator.py      | 86 +++++++++++++++++++++++++++++++-
 3 files changed, 113 insertions(+), 16 deletions(-)

diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py 
b/airflow/ti_deps/deps/trigger_rule_dep.py
index 76291c8a057..6e00f718be2 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -27,6 +27,7 @@ from sqlalchemy import and_, func, or_, select
 from airflow.models.taskinstance import PAST_DEPENDS_MET
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.utils.state import TaskInstanceState
+from airflow.utils.task_group import MappedTaskGroup
 from airflow.utils.trigger_rule import TriggerRule as TR
 
 if TYPE_CHECKING:
@@ -63,8 +64,7 @@ class _UpstreamTIStates(NamedTuple):
         ``counter`` is inclusive of ``setup_counter`` -- e.g. if there are 2 
skipped upstreams, one
         of which is a setup, then counter will show 2 skipped and setup 
counter will show 1.
 
-        :param ti: the ti that we want to calculate deps for
-        :param finished_tis: all the finished tasks of the dag_run
+        :param finished_upstreams: all the finished upstreams of the dag_run
         """
         counter: dict[str, int] = Counter()
         setup_counter: dict[str, int] = Counter()
@@ -143,6 +143,19 @@ class TriggerRuleDep(BaseTIDep):
 
             return ti.task.get_mapped_ti_count(ti.run_id, session=session)
 
+        def _iter_expansion_dependencies(task_group: MappedTaskGroup) -> 
Iterator[str]:
+            from airflow.models.mappedoperator import MappedOperator
+
+            if isinstance(ti.task, MappedOperator):
+                for op in ti.task.iter_mapped_dependencies():
+                    yield op.task_id
+            if task_group and task_group.iter_mapped_task_groups():
+                yield from (
+                    op.task_id
+                    for tg in task_group.iter_mapped_task_groups()
+                    for op in tg.iter_mapped_dependencies()
+                )
+
         @functools.lru_cache
         def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | 
range | None:
             """
@@ -156,6 +169,13 @@ class TriggerRuleDep(BaseTIDep):
                 assert ti.task
                 assert isinstance(ti.task.dag, DAG)
 
+            if isinstance(ti.task.task_group, MappedTaskGroup):
+                is_fast_triggered = ti.task.trigger_rule in (TR.ONE_SUCCESS, 
TR.ONE_FAILED, TR.ONE_DONE)
+                if is_fast_triggered and upstream_id not in set(
+                    _iter_expansion_dependencies(task_group=ti.task.task_group)
+                ):
+                    return None
+
             try:
                 expanded_ti_count = _get_expanded_ti_count()
             except (NotFullyPopulated, NotMapped):
@@ -217,7 +237,7 @@ class TriggerRuleDep(BaseTIDep):
             for upstream_id in relevant_tasks:
                 map_indexes = _get_relevant_upstream_map_indexes(upstream_id)
                 if map_indexes is None:  # All tis of this upstream are 
dependencies.
-                    yield (TaskInstance.task_id == upstream_id)
+                    yield TaskInstance.task_id == upstream_id
                     continue
                 # At this point we know we want to depend on only selected tis
                 # of this upstream task. Since the upstream may not have been
@@ -237,11 +257,9 @@ class TriggerRuleDep(BaseTIDep):
 
         def _evaluate_setup_constraint(*, relevant_setups) -> 
Iterator[tuple[TIDepStatus, bool]]:
             """
-            Evaluate whether ``ti``'s trigger rule was met.
+            Evaluate whether ``ti``'s trigger rule was met as part of the 
setup constraint.
 
-            :param ti: Task instance to evaluate the trigger rule of.
-            :param dep_context: The current dependency context.
-            :param session: Database session.
+            :param relevant_setups: Relevant setups for the current task 
instance.
             """
             if TYPE_CHECKING:
                 assert ti.task
@@ -327,13 +345,7 @@ class TriggerRuleDep(BaseTIDep):
                 )
 
         def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
-            """
-            Evaluate whether ``ti``'s trigger rule was met.
-
-            :param ti: Task instance to evaluate the trigger rule of.
-            :param dep_context: The current dependency context.
-            :param session: Database session.
-            """
+            """Evaluate whether ``ti``'s trigger rule in direct relatives was 
met."""
             if TYPE_CHECKING:
                 assert ti.task
 
@@ -433,7 +445,7 @@ class TriggerRuleDep(BaseTIDep):
                     )
                     if not past_depends_met:
                         yield self._failing_status(
-                            reason=("Task should be skipped but the past 
depends are not met")
+                            reason="Task should be skipped but the past 
depends are not met"
                         )
                         return
                 changed = ti.set_state(new_state, session)
diff --git a/newsfragments/44937.bugfix.rst b/newsfragments/44937.bugfix.rst
new file mode 100644
index 00000000000..d50da4de82f
--- /dev/null
+++ b/newsfragments/44937.bugfix.rst
@@ -0,0 +1 @@
+Fix premature evaluation of tasks in mapped task group. The origin of the 
bug is in ``TriggerRuleDep``, when dealing with a ``TriggerRule`` that is 
fast-triggered (i.e., ``ONE_FAILED``, ``ONE_SUCCESS``, or ``ONE_DONE``). Please note 
that at the time of merging, this fix has been applied only for Airflow versions > 
2.10.4 and < 3, and should be ported to v3 after merging PR #40460.
diff --git a/tests/models/test_mappedoperator.py 
b/tests/models/test_mappedoperator.py
index cf547912fb9..d1e896200c7 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -37,7 +37,7 @@ from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskmap import TaskMap
 from airflow.models.xcom_arg import XComArg
 from airflow.operators.python import PythonOperator
-from airflow.utils.state import TaskInstanceState
+from airflow.utils.state import State, TaskInstanceState
 from airflow.utils.task_group import TaskGroup
 from airflow.utils.task_instance_session import 
set_current_task_instance_session
 from airflow.utils.trigger_rule import TriggerRule
@@ -1784,3 +1784,87 @@ class TestMappedSetupTeardown:
             "group.last": {0: "success", 1: "skipped", 2: "success"},
         }
         assert states == expected
+
+
+def 
test_mapped_tasks_in_mapped_task_group_waits_for_upstreams_to_complete(dag_maker,
 session):
+    """Test that one failed trigger rule works well in mapped task group"""
+    with dag_maker() as dag:
+
+        @dag.task
+        def t1():
+            return [1, 2, 3]
+
+        @task_group("tg1")
+        def tg1(a):
+            @dag.task()
+            def t2(a):
+                return a
+
+            @dag.task(trigger_rule=TriggerRule.ONE_FAILED)
+            def t3(a):
+                return a
+
+            t2(a) >> t3(a)
+
+        t = t1()
+        tg1.expand(a=t)
+
+    dr = dag_maker.create_dagrun()
+    ti = dr.get_task_instance(task_id="t1")
+    ti.run()
+    dr.task_instance_scheduling_decisions()
+    ti3 = dr.get_task_instance(task_id="tg1.t3")
+    assert not ti3.state
+
+
+def 
test_mapped_tasks_in_mapped_task_group_waits_for_upstreams_to_complete__mapped_skip_with_all_success(
+    dag_maker, session
+):
+    with dag_maker():
+
+        @task
+        def make_list():
+            return [4, 42, 2]
+
+        @task
+        def double(n):
+            if n == 42:
+                raise AirflowSkipException("42")
+            return n * 2
+
+        @task
+        def last(n):
+            print(n)
+
+        @task_group
+        def group(n: int) -> None:
+            last(double(n))
+
+        list = make_list()
+        group.expand(n=list)
+
+    dr = dag_maker.create_dagrun()
+
+    def _one_scheduling_decision_iteration() -> dict[tuple[str, int], 
TaskInstance]:
+        decision = dr.task_instance_scheduling_decisions(session=session)
+        return {(ti.task_id, ti.map_index): ti for ti in 
decision.schedulable_tis}
+
+    tis = _one_scheduling_decision_iteration()
+    tis["make_list", -1].run()
+    assert tis["make_list", -1].state == State.SUCCESS
+
+    tis = _one_scheduling_decision_iteration()
+    tis["group.double", 0].run()
+    tis["group.double", 1].run()
+    tis["group.double", 2].run()
+
+    assert tis["group.double", 0].state == State.SUCCESS
+    assert tis["group.double", 1].state == State.SKIPPED
+    assert tis["group.double", 2].state == State.SUCCESS
+
+    tis = _one_scheduling_decision_iteration()
+    tis["group.last", 0].run()
+    tis["group.last", 2].run()
+    assert tis["group.last", 0].state == State.SUCCESS
+    assert dr.get_task_instance("group.last", map_index=1, 
session=session).state == State.SKIPPED
+    assert tis["group.last", 2].state == State.SUCCESS

Reply via email to