Nataneljpwd commented on code in PR #64181:
URL: https://github.com/apache/airflow/pull/64181#discussion_r2986204125


##########
airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py:
##########
@@ -619,6 +619,60 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
                     reason=f"No strategy to evaluate trigger rule 
'{trigger_rule_str}'."
                 )
 
+        def _evaluate_teardown_scope() -> Iterator[TIDepStatus]:
+            """Ensure all tasks between setup(s) and this teardown have 
completed."""
+            if not task.dag:
+                return
+
+            setup_task_ids = {t.task_id for t in task.upstream_list if 
t.is_setup}
+            if not setup_task_ids:
+                return

Review Comment:
   This early return might not be needed: if there are no setup tasks, the for loop below simply does nothing.



##########
airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py:
##########
@@ -619,6 +619,60 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
                     reason=f"No strategy to evaluate trigger rule 
'{trigger_rule_str}'."
                 )
 
+        def _evaluate_teardown_scope() -> Iterator[TIDepStatus]:
+            """Ensure all tasks between setup(s) and this teardown have 
completed."""
+            if not task.dag:
+                return
+
+            setup_task_ids = {t.task_id for t in task.upstream_list if 
t.is_setup}
+            if not setup_task_ids:
+                return
+
+            all_upstream_ids = task.get_flat_relative_ids(upstream=True)
+            indirect_upstream_ids = all_upstream_ids - task.upstream_task_ids
+
+            if not indirect_upstream_ids:
+                return
+
+            in_scope_ids = set()
+            for setup_id in setup_task_ids:
+                setup_obj = task.dag.get_task(setup_id)
+                in_scope_ids.update(indirect_upstream_ids & 
setup_obj.get_flat_relative_ids(upstream=False))
+
+            if not in_scope_ids:
+                return
+
+            in_scope_tasks = {tid: task.dag.get_task(tid) for tid in 
in_scope_ids}
+
+            finished_upstream_tis = (
+                x
+                for x in 
dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
+                if _is_relevant_upstream(upstream=x, relevant_ids=in_scope_ids)
+            )
+            done = sum(1 for _ in finished_upstream_tis)

Review Comment:
   Why is the separate `finished_upstream_tis` generator needed? It is only used for the sum, so why not compute `done` directly, e.g. `sum(1 for x in ... if ...)`, instead of first yielding `x` and then summing `1 for _ in ...`?
   There is no need to loop twice here.



##########
airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py:
##########
@@ -619,6 +619,60 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
                     reason=f"No strategy to evaluate trigger rule 
'{trigger_rule_str}'."
                 )
 
+        def _evaluate_teardown_scope() -> Iterator[TIDepStatus]:
+            """Ensure all tasks between setup(s) and this teardown have 
completed."""
+            if not task.dag:
+                return
+
+            setup_task_ids = {t.task_id for t in task.upstream_list if 
t.is_setup}
+            if not setup_task_ids:
+                return
+
+            all_upstream_ids = task.get_flat_relative_ids(upstream=True)
+            indirect_upstream_ids = all_upstream_ids - task.upstream_task_ids
+
+            if not indirect_upstream_ids:
+                return
+
+            in_scope_ids = set()
+            for setup_id in setup_task_ids:
+                setup_obj = task.dag.get_task(setup_id)
+                in_scope_ids.update(indirect_upstream_ids & 
setup_obj.get_flat_relative_ids(upstream=False))
+
+            if not in_scope_ids:
+                return
+
+            in_scope_tasks = {tid: task.dag.get_task(tid) for tid in 
in_scope_ids}
+
+            finished_upstream_tis = (
+                x
+                for x in 
dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
+                if _is_relevant_upstream(upstream=x, relevant_ids=in_scope_ids)
+            )
+            done = sum(1 for _ in finished_upstream_tis)
+
+            if not any(t.get_needs_expansion() for t in 
in_scope_tasks.values()):
+                expected = len(in_scope_tasks)
+            else:
+                task_id_counts = session.execute(
+                    select(TaskInstance.task_id, 
func.count(TaskInstance.task_id))

Review Comment:
   Why do we select the task_id here?



##########
airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py:
##########
@@ -619,6 +619,60 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
                     reason=f"No strategy to evaluate trigger rule 
'{trigger_rule_str}'."
                 )
 
+        def _evaluate_teardown_scope() -> Iterator[TIDepStatus]:
+            """Ensure all tasks between setup(s) and this teardown have 
completed."""
+            if not task.dag:
+                return
+
+            setup_task_ids = {t.task_id for t in task.upstream_list if 
t.is_setup}
+            if not setup_task_ids:
+                return
+
+            all_upstream_ids = task.get_flat_relative_ids(upstream=True)
+            indirect_upstream_ids = all_upstream_ids - task.upstream_task_ids
+
+            if not indirect_upstream_ids:
+                return
+
+            in_scope_ids = set()
+            for setup_id in setup_task_ids:
+                setup_obj = task.dag.get_task(setup_id)
+                in_scope_ids.update(indirect_upstream_ids & 
setup_obj.get_flat_relative_ids(upstream=False))
+
+            if not in_scope_ids:
+                return
+
+            in_scope_tasks = {tid: task.dag.get_task(tid) for tid in 
in_scope_ids}
+
+            finished_upstream_tis = (
+                x
+                for x in 
dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
+                if _is_relevant_upstream(upstream=x, relevant_ids=in_scope_ids)
+            )
+            done = sum(1 for _ in finished_upstream_tis)
+
+            if not any(t.get_needs_expansion() for t in 
in_scope_tasks.values()):
+                expected = len(in_scope_tasks)
+            else:
+                task_id_counts = session.execute(
+                    select(TaskInstance.task_id, 
func.count(TaskInstance.task_id))
+                    .where(TaskInstance.dag_id == ti.dag_id, 
TaskInstance.run_id == ti.run_id)
+                    
.where(or_(*_iter_upstream_conditions(relevant_tasks=in_scope_tasks)))
+                    .group_by(TaskInstance.task_id)
+                ).all()
+                expected = sum(count for _, count in task_id_counts)

Review Comment:
   I think performance would be better if the summation were also done in SQL. It looks like you only need the total count from this query, and for large DAGs (thousands of tasks) that would be much faster.



##########
airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py:
##########
@@ -619,6 +619,60 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
                     reason=f"No strategy to evaluate trigger rule 
'{trigger_rule_str}'."
                 )
 
+        def _evaluate_teardown_scope() -> Iterator[TIDepStatus]:
+            """Ensure all tasks between setup(s) and this teardown have 
completed."""
+            if not task.dag:
+                return

Review Comment:
   Why is this check added?



##########
airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py:
##########
@@ -619,6 +619,60 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
                     reason=f"No strategy to evaluate trigger rule 
'{trigger_rule_str}'."
                 )
 
+        def _evaluate_teardown_scope() -> Iterator[TIDepStatus]:
+            """Ensure all tasks between setup(s) and this teardown have 
completed."""
+            if not task.dag:
+                return
+
+            setup_task_ids = {t.task_id for t in task.upstream_list if 
t.is_setup}
+            if not setup_task_ids:
+                return
+
+            all_upstream_ids = task.get_flat_relative_ids(upstream=True)
+            indirect_upstream_ids = all_upstream_ids - task.upstream_task_ids
+
+            if not indirect_upstream_ids:
+                return
+
+            in_scope_ids = set()
+            for setup_id in setup_task_ids:
+                setup_obj = task.dag.get_task(setup_id)
+                in_scope_ids.update(indirect_upstream_ids & 
setup_obj.get_flat_relative_ids(upstream=False))
+
+            if not in_scope_ids:
+                return
+

Review Comment:
   Same as above



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to