This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 69938fd163 Fix pre-mature evaluation of tasks in mapped task group (#34337)
69938fd163 is described below

commit 69938fd163045d750b8c218500d79bc89858f9c1
Author: Ephraim Anierobi <splendidzig...@gmail.com>
AuthorDate: Wed Nov 1 21:37:15 2023 +0100

    Fix pre-mature evaluation of tasks in mapped task group (#34337)
    
    * Fix pre-mature evaluation of tasks in mapped task group
    
    Getting the relevant upstream map indexes of a task instance in a mapped task group
    should only be done once the task has expanded. If the task has not expanded yet,
    we should return None so that the task waits for its upstreams before trying
    to run.
    This issue is most noticeable when the trigger rule is ONE_FAILED, because the
    task instance is then prematurely marked as SKIPPED.
    This commit fixes this issue.
    closes: https://github.com/apache/airflow/issues/34023
    
    * fixup! Fix pre-mature evaluation of tasks in mapped task group
    
    * fixup! fixup! Fix pre-mature evaluation of tasks in mapped task group
    
    * fixup! fixup! fixup! Fix pre-mature evaluation of tasks in mapped task group
    
    * Fix tests
---
 airflow/ti_deps/deps/trigger_rule_dep.py    | 18 +++++++++++
 tests/models/test_mappedoperator.py         |  4 +--
 tests/ti_deps/deps/test_trigger_rule_dep.py | 47 +++++++++++++++++++++++++----
 3 files changed, 61 insertions(+), 8 deletions(-)

diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py
index ca2a6100a2..6203b2a79b 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -27,6 +27,7 @@ from sqlalchemy import and_, func, or_, select
 from airflow.models.taskinstance import PAST_DEPENDS_MET
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.utils.state import TaskInstanceState
+from airflow.utils.task_group import MappedTaskGroup
 from airflow.utils.trigger_rule import TriggerRule as TR
 
 if TYPE_CHECKING:
@@ -132,6 +133,20 @@ class TriggerRuleDep(BaseTIDep):
             """
             return ti.task.get_mapped_ti_count(ti.run_id, session=session)
 
+        def _iter_expansion_dependencies() -> Iterator[str]:
+            """Yield task IDs from the mapped dependencies of ``ti.task`` and its mapped task groups."""
+            from airflow.models.mappedoperator import MappedOperator
+
+            if isinstance(ti.task, MappedOperator):
+                yield from (op.task_id for op in ti.task.iter_mapped_dependencies())
+            task_group = ti.task.task_group
+            if task_group is not None:  # no pre-check of the iterator: an empty one yields nothing
+                yield from (
+                    op.task_id
+                    for tg in task_group.iter_mapped_task_groups()
+                    for op in tg.iter_mapped_dependencies()
+                )
+
         @functools.lru_cache
         def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | 
range | None:
             """Get the given task's map indexes relevant to the current ti.
@@ -142,6 +157,9 @@ class TriggerRuleDep(BaseTIDep):
             """
             if TYPE_CHECKING:
                 assert isinstance(ti.task.dag, DAG)
+            if isinstance(ti.task.task_group, MappedTaskGroup):
+                if upstream_id not in set(_iter_expansion_dependencies()):
+                    return None
             try:
                 expanded_ti_count = _get_expanded_ti_count()
             except (NotFullyPopulated, NotMapped):
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 7244c55774..5c2e23c1f9 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -1305,8 +1305,8 @@ class TestMappedSetupTeardown:
         states = self.get_states(dr)
         expected = {
             "file_transforms.my_setup": {0: "success", 1: "failed", 2: 
"skipped"},
-            "file_transforms.my_work": {0: "success", 1: "upstream_failed", 2: 
"skipped"},
-            "file_transforms.my_teardown": {0: "success", 1: 
"upstream_failed", 2: "skipped"},
+            "file_transforms.my_work": {2: "upstream_failed", 1: 
"upstream_failed", 0: "upstream_failed"},
+            "file_transforms.my_teardown": {2: "success", 1: "success", 0: 
"success"},
         }
 
         assert states == expected
diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py
index 00cbcd449a..1bc8808cb8 100644
--- a/tests/ti_deps/deps/test_trigger_rule_dep.py
+++ b/tests/ti_deps/deps/test_trigger_rule_dep.py
@@ -1165,19 +1165,23 @@ def test_upstream_in_mapped_group_triggers_only_relevant(dag_maker, session):
     tis = _one_scheduling_decision_iteration()
     assert sorted(tis) == [("tg.t1", 0), ("tg.t1", 1), ("tg.t1", 2)]
 
-    # After running the first t1, the first t2 becomes immediately available.
+    # After running the first t1, the remaining t1 must be run before t2 is 
available.
     tis["tg.t1", 0].run()
     tis = _one_scheduling_decision_iteration()
-    assert sorted(tis) == [("tg.t1", 1), ("tg.t1", 2), ("tg.t2", 0)]
+    assert sorted(tis) == [("tg.t1", 1), ("tg.t1", 2)]
 
-    # Similarly for the subsequent t2 instances.
+    # After running all t1, t2 is available.
+    tis["tg.t1", 1].run()
     tis["tg.t1", 2].run()
     tis = _one_scheduling_decision_iteration()
-    assert sorted(tis) == [("tg.t1", 1), ("tg.t2", 0), ("tg.t2", 2)]
+    assert sorted(tis) == [("tg.t2", 0), ("tg.t2", 1), ("tg.t2", 2)]
 
-    # But running t2 partially does not make t3 available.
-    tis["tg.t1", 1].run()
+    # Similarly for t2 instances. They both have to complete before t3 is 
available
     tis["tg.t2", 0].run()
+    tis = _one_scheduling_decision_iteration()
+    assert sorted(tis) == [("tg.t2", 1), ("tg.t2", 2)]
+
+    # But running t2 partially does not make t3 available.
     tis["tg.t2", 2].run()
     tis = _one_scheduling_decision_iteration()
     assert sorted(tis) == [("tg.t2", 1)]
@@ -1407,3 +1411,34 @@ class TestTriggerRuleDepSetupConstraint:
             (status,) = self.get_dep_statuses(dr, "w2", 
flag_upstream_failed=True, session=session)
         assert status.reason.startswith("All setup tasks must complete 
successfully")
         assert self.get_ti(dr, "w2").state == expected
+
+
+def test_mapped_tasks_in_mapped_task_group_waits_for_upstreams_to_complete(dag_maker, session):
+    """Test that a ONE_FAILED task in a mapped task group is not given a state before its upstreams run."""
+    with dag_maker() as dag:
+
+        @dag.task
+        def t1():
+            return [1, 2, 3]
+
+        @task_group("tg1")
+        def tg1(a):
+            @dag.task()
+            def t2(a):
+                return a
+
+            @dag.task(trigger_rule=TriggerRule.ONE_FAILED)
+            def t3(a):
+                return a
+
+            t2(a) >> t3(a)
+
+        t = t1()
+        tg1.expand(a=t)
+
+    dr = dag_maker.create_dagrun()
+    ti = dr.get_task_instance(task_id="t1")
+    ti.run()
+    dr.task_instance_scheduling_decisions()
+    ti3 = dr.get_task_instance(task_id="tg1.t3")
+    assert not ti3.state  # t2 has not run yet, so t3 must not be decided (e.g. SKIPPED)

Reply via email to