This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-7-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit b5dbf97317c4ea0dde6a2b006d34b694c95d2f9a
Author: Daniel Standish <15932138+dstand...@users.noreply.github.com>
AuthorDate: Sun Sep 3 08:48:51 2023 -0700

    Ensure that tasks wait for running indirect setup (#33903)
    
    * move internal functions to methods -- no behavior change
    
    * add setup constraint logic
    
    * comments
    
    * simplify
    
    * simplify
    
    * fix
    
    * fix
    
    * update tests
    
    * static checks
    
    * add constraint that setup tasks followed by ALL_SUCCESS rule
    
    * add todo
    
    * docs
    
    * docs
    
    * add test
    
    * fix static check
    
    (cherry picked from commit e75cecab72bd59ae9fa04631bdc7a27a745d61fe)
---
 airflow/models/abstractoperator.py               |  11 +
 airflow/models/dag.py                            |   8 +
 airflow/ti_deps/deps/trigger_rule_dep.py         | 572 ++++++++++++++---------
 docs/apache-airflow/howto/setup-and-teardown.rst |   8 +
 tests/models/test_dag.py                         |  14 +
 tests/models/test_taskinstance.py                |  61 ++-
 tests/ti_deps/deps/test_trigger_rule_dep.py      | 164 +++++++
 7 files changed, 602 insertions(+), 236 deletions(-)
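
For orientation, here is a minimal sketch (not part of the patch; the DAG and task names
are illustrative) of the behavior this change enforces: a task that is only indirectly
downstream of a setup now waits for that setup and inherits its failure or skip.

    # Hypothetical example DAG, assuming the Airflow 2.7 setup/teardown API.
    import pendulum
    from airflow.decorators import dag, task

    @dag(start_date=pendulum.datetime(2023, 9, 1), schedule=None, catchup=False)
    def indirect_setup_example():
        @task
        def create_cluster():  # setup task
            ...

        @task
        def run_query():  # directly downstream of the setup
            ...

        @task
        def summarize():  # only indirectly downstream of the setup
            ...

        # With this change, clearing `summarize` makes it wait for `create_cluster`,
        # and a failed or skipped setup marks it upstream_failed or skipped.
        create_cluster().as_setup() >> run_query() >> summarize()

    indirect_setup_example()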

diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index ba357c0bd1..675550c82c 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -293,6 +293,17 @@ class AbstractOperator(Templater, DAGNode):
                     if t.is_teardown and not t == self:
                         yield t
 
+    def get_upstreams_only_setups(self) -> Iterable[Operator]:
+        """
+        Return relevant upstream setups.
+
+        This method is meant to be used when we are checking task dependencies where we need
+        to wait for all the upstream setups to complete before we can run the task.
+        """
+        for task in self.get_upstreams_only_setups_and_teardowns():
+            if task.is_setup:
+                yield task
+
     def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
         """Return mapped nodes that are direct dependencies of the current task.
 
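
A rough usage sketch for the new helper (the DAG below is illustrative and not part of
the patch; only get_upstreams_only_setups itself comes from this commit):

    import pendulum
    from airflow import DAG
    from airflow.operators.empty import EmptyOperator

    with DAG("upstream_setups_sketch", start_date=pendulum.datetime(2023, 9, 1), schedule=None):
        s1 = EmptyOperator(task_id="s1").as_setup()
        w1 = EmptyOperator(task_id="w1")
        w2 = EmptyOperator(task_id="w2")
        s1 >> w1 >> w2

    # w2 is only indirectly downstream of the setup, yet s1 is still reported as a
    # relevant upstream setup for it (this is what the trigger rule dep relies on).
    print([t.task_id for t in w2.get_upstreams_only_setups()])  # expected: ["s1"]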
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 75fee04145..b5f31f536f 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -128,6 +128,7 @@ from airflow.utils.sqlalchemy import (
     with_row_locks,
 )
 from airflow.utils.state import DagRunState, TaskInstanceState
+from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType
 
 if TYPE_CHECKING:
@@ -717,6 +718,13 @@ class DAG(LoggingMixin):
         :meta private:
         """
         for task in self.tasks:
+            if task.is_setup:
+                for down_task in task.downstream_list:
+                    if not down_task.is_teardown and down_task.trigger_rule != TriggerRule.ALL_SUCCESS:
+                        # todo: we can relax this to allow out-of-scope tasks to have other trigger rules
+                        # this is required to ensure consistent behavior of dag
+                        # when clearing an indirect setup
+                        raise ValueError("Setup tasks must be followed with trigger rule ALL_SUCCESS.")
             FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule)
 
     def __repr__(self):
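
A quick sketch of what this validation now rejects (the DAG is illustrative; the error
message is the one added above and exercised by the new test further below):

    import pendulum
    import pytest
    from airflow import DAG
    from airflow.operators.empty import EmptyOperator
    from airflow.utils.trigger_rule import TriggerRule

    with DAG("setup_trigger_rule_sketch", start_date=pendulum.datetime(2023, 9, 1), schedule=None) as dag:
        s1 = EmptyOperator(task_id="s1").as_setup()
        # a non-teardown directly downstream of a setup may no longer use a
        # trigger rule other than ALL_SUCCESS
        w1 = EmptyOperator(task_id="w1", trigger_rule=TriggerRule.ONE_FAILED)
        s1 >> w1

    with pytest.raises(ValueError, match="Setup tasks must be followed with trigger rule ALL_SUCCESS."):
        dag.validate_setup_teardown()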
diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py
index dbdf692e76..609f3b539e 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 import collections
 import collections.abc
 import functools
-from typing import TYPE_CHECKING, Iterator, NamedTuple
+from typing import TYPE_CHECKING, Iterator, KeysView, NamedTuple
 
 from sqlalchemy import and_, func, or_, select
 
@@ -34,6 +34,7 @@ if TYPE_CHECKING:
     from sqlalchemy.orm import Session
     from sqlalchemy.sql.expression import ColumnOperators
 
+    from airflow import DAG
     from airflow.models.taskinstance import TaskInstance
 
 
@@ -121,10 +122,6 @@ class TriggerRuleDep(BaseTIDep):
         from airflow.models.operator import needs_expansion
         from airflow.models.taskinstance import TaskInstance
 
-        task = ti.task
-        upstream_tasks = {t.task_id: t for t in task.upstream_list}
-        trigger_rule = task.trigger_rule
-
         @functools.lru_cache
         def _get_expanded_ti_count() -> int:
             """Get how many tis the current task is supposed to be expanded into.
@@ -132,7 +129,7 @@ class TriggerRuleDep(BaseTIDep):
             This extra closure allows us to query the database only when needed,
             and at most once.
             """
-            return task.get_mapped_ti_count(ti.run_id, session=session)
+            return ti.task.get_mapped_ti_count(ti.run_id, session=session)
 
         @functools.lru_cache
         def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None:
@@ -142,24 +139,34 @@ class TriggerRuleDep(BaseTIDep):
             and at most once for each task (instead of once for each expanded
             task instance of the same task).
             """
+            if TYPE_CHECKING:
+                assert isinstance(ti.task.dag, DAG)
             try:
                 expanded_ti_count = _get_expanded_ti_count()
             except (NotFullyPopulated, NotMapped):
                 return None
             return ti.get_relevant_upstream_map_indexes(
-                upstream_tasks[upstream_id],
-                expanded_ti_count,
+                upstream=ti.task.dag.task_dict[upstream_id],
+                ti_count=expanded_ti_count,
                 session=session,
             )
 
-        def _is_relevant_upstream(upstream: TaskInstance) -> bool:
-            """Whether a task instance is a "relevant upstream" of the current task."""
+        def _is_relevant_upstream(upstream: TaskInstance, relevant_ids: set[str] | KeysView[str]) -> bool:
+            """
+            Whether a task instance is a "relevant upstream" of the current task.
+
+            This will return false if upstream.task_id is not in relevant_ids,
+            or if both of the following are true:
+                1. upstream.task_id in relevant_ids is True
+                2. ti is in a mapped task group and upstream has a map index
+                  that ti does not depend on.
+            """
             # Not actually an upstream task.
-            if upstream.task_id not in task.upstream_task_ids:
+            if upstream.task_id not in relevant_ids:
                 return False
             # The current task is not in a mapped task group. All tis from an
             # upstream task are relevant.
-            if task.get_closest_mapped_task_group() is None:
+            if ti.task.get_closest_mapped_task_group() is None:
                 return True
             # The upstream ti is not expanded. The upstream may be mapped or
             # not, but the ti is relevant either way.
@@ -167,7 +174,7 @@ class TriggerRuleDep(BaseTIDep):
                 return True
             # Now we need to perform fine-grained check on whether this specific
             # upstream ti's map index is relevant.
-            relevant = _get_relevant_upstream_map_indexes(upstream.task_id)
+            relevant = _get_relevant_upstream_map_indexes(upstream_id=upstream.task_id)
             if relevant is None:
                 return True
             if relevant == upstream.map_index:
@@ -176,31 +183,17 @@ class TriggerRuleDep(BaseTIDep):
                 return True
             return False
 
-        finished_upstream_tis = (
-            finished_ti
-            for finished_ti in dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
-            if _is_relevant_upstream(finished_ti)
-        )
-        upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis)
-
-        success = upstream_states.success
-        skipped = upstream_states.skipped
-        failed = upstream_states.failed
-        upstream_failed = upstream_states.upstream_failed
-        removed = upstream_states.removed
-        done = upstream_states.done
-        success_setup = upstream_states.success_setup
-        skipped_setup = upstream_states.skipped_setup
-
-        def _iter_upstream_conditions() -> Iterator[ColumnOperators]:
+        def _iter_upstream_conditions(relevant_tasks: dict) -> Iterator[ColumnOperators]:
             # Optimization: If the current task is not in a mapped task group,
             # it depends on all upstream task instances.
-            if task.get_closest_mapped_task_group() is None:
-                yield TaskInstance.task_id.in_(upstream_tasks)
+            from airflow.models.taskinstance import TaskInstance
+
+            if ti.task.get_closest_mapped_task_group() is None:
+                yield TaskInstance.task_id.in_(relevant_tasks.keys())
                 return
             # Otherwise we need to figure out which map indexes are depended on
             # for each upstream by the current task instance.
-            for upstream_id in upstream_tasks:
+            for upstream_id in relevant_tasks.keys():
                 map_indexes = _get_relevant_upstream_map_indexes(upstream_id)
                 if map_indexes is None:  # All tis of this upstream are dependencies.
                     yield (TaskInstance.task_id == upstream_id)
@@ -221,27 +214,49 @@ class TriggerRuleDep(BaseTIDep):
                 else:
                     yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index == map_indexes)
 
-        # Optimization: Don't need to hit the database if all upstreams are
-        # "simple" tasks (no task or task group mapping involved).
-        if not any(needs_expansion(t) for t in upstream_tasks.values()):
-            upstream = len(upstream_tasks)
-            upstream_setup = sum(1 for x in upstream_tasks.values() if x.is_setup)
-        else:
-            task_id_counts = session.execute(
-                select(TaskInstance.task_id, func.count(TaskInstance.task_id))
-                .where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
-                .where(or_(*_iter_upstream_conditions()))
-                .group_by(TaskInstance.task_id)
-            ).all()
-            upstream = sum(count for _, count in task_id_counts)
-            upstream_setup = sum(c for t, c in task_id_counts if upstream_tasks[t].is_setup)
-
-        upstream_done = done >= upstream
-
-        changed = False
-        new_state = None
-        if dep_context.flag_upstream_failed:
-            if trigger_rule == TR.ALL_SUCCESS:
+        def _evaluate_setup_constraint(*, relevant_setups) -> Iterator[tuple[TIDepStatus, bool]]:
+            """Evaluate whether the indirect setup constraint for ``ti`` is met.
+
+            :param relevant_setups: Setup tasks relevant to ``ti``; only those that are not
+                direct upstreams of ``ti`` are checked here.
+            """
+            task = ti.task
+
+            indirect_setups = {k: v for k, v in relevant_setups.items() if k not in task.upstream_task_ids}
+            finished_upstream_tis = (
+                x
+                for x in dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
+                if _is_relevant_upstream(upstream=x, relevant_ids=indirect_setups.keys())
+            )
+            upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis)
+
+            # all of these counts reflect indirect setups which are relevant for this ti
+            success = upstream_states.success
+            skipped = upstream_states.skipped
+            failed = upstream_states.failed
+            upstream_failed = upstream_states.upstream_failed
+            removed = upstream_states.removed
+
+            # Optimization: Don't need to hit the database if all upstreams are
+            # "simple" tasks (no task or task group mapping involved).
+            if not any(needs_expansion(t) for t in indirect_setups.values()):
+                upstream = len(indirect_setups)
+            else:
+                task_id_counts = session.execute(
+                    select(TaskInstance.task_id, func.count(TaskInstance.task_id))
+                    .where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
+                    .where(or_(*_iter_upstream_conditions(relevant_tasks=indirect_setups)))
+                    .group_by(TaskInstance.task_id)
+                ).all()
+                upstream = sum(count for _, count in task_id_counts)
+
+            new_state = None
+            changed = False
+
+            # if there's a failure, we mark upstream_failed; if there's a skip, we mark skipped
+            # in either case, we don't wait for all relevant setups to complete
+            if dep_context.flag_upstream_failed:
                 if upstream_failed or failed:
                     new_state = TaskInstanceState.UPSTREAM_FAILED
                 elif skipped:
@@ -249,196 +264,297 @@ class TriggerRuleDep(BaseTIDep):
                 elif removed and success and ti.map_index > -1:
                     if ti.map_index >= success:
                         new_state = TaskInstanceState.REMOVED
-            elif trigger_rule == TR.ALL_FAILED:
-                if success or skipped:
-                    new_state = TaskInstanceState.SKIPPED
-            elif trigger_rule == TR.ONE_SUCCESS:
-                if upstream_done and done == skipped:
-                    # if upstream is done and all are skipped mark as skipped
-                    new_state = TaskInstanceState.SKIPPED
-                elif upstream_done and success <= 0:
-                    # if upstream is done and there are no success mark as 
upstream failed
-                    new_state = TaskInstanceState.UPSTREAM_FAILED
-            elif trigger_rule == TR.ONE_FAILED:
-                if upstream_done and not (failed or upstream_failed):
-                    new_state = TaskInstanceState.SKIPPED
-            elif trigger_rule == TR.ONE_DONE:
-                if upstream_done and not (failed or success):
-                    new_state = TaskInstanceState.SKIPPED
-            elif trigger_rule == TR.NONE_FAILED:
-                if upstream_failed or failed:
-                    new_state = TaskInstanceState.UPSTREAM_FAILED
-            elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
-                if upstream_failed or failed:
-                    new_state = TaskInstanceState.UPSTREAM_FAILED
-                elif skipped == upstream:
-                    new_state = TaskInstanceState.SKIPPED
-            elif trigger_rule == TR.NONE_SKIPPED:
-                if skipped:
-                    new_state = TaskInstanceState.SKIPPED
-            elif trigger_rule == TR.ALL_SKIPPED:
-                if success or failed:
-                    new_state = TaskInstanceState.SKIPPED
-            elif trigger_rule == TR.ALL_DONE_SETUP_SUCCESS:
-                if upstream_done and upstream_setup and skipped_setup >= 
upstream_setup:
-                    # when there is an upstream setup and they have all 
skipped, then skip
-                    new_state = TaskInstanceState.SKIPPED
-                elif upstream_done and upstream_setup and success_setup == 0:
-                    # when there is an upstream setup, if none succeeded, mark 
upstream failed
-                    # if at least one setup ran, we'll let it run
-                    new_state = TaskInstanceState.UPSTREAM_FAILED
-        if new_state is not None:
-            if new_state == TaskInstanceState.SKIPPED and 
dep_context.wait_for_past_depends_before_skipping:
-                past_depends_met = ti.xcom_pull(
-                    task_ids=ti.task_id, key=PAST_DEPENDS_MET, 
session=session, default=False
-                )
-                if not past_depends_met:
-                    yield self._failing_status(
-                        reason=("Task should be skipped but the past depends 
are not met")
+
+            if new_state is not None:
+                if (
+                    new_state == TaskInstanceState.SKIPPED
+                    and dep_context.wait_for_past_depends_before_skipping
+                ):
+                    past_depends_met = ti.xcom_pull(
+                        task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
                     )
-                    return
-            changed = ti.set_state(new_state, session)
+                    if not past_depends_met:
+                        yield self._failing_status(
+                            reason="Task should be skipped but the past depends are not met"
+                        ), changed
+                        return
+                changed = ti.set_state(new_state, session)
 
-        if changed:
-            dep_context.have_changed_ti_states = True
+            if changed:
+                dep_context.have_changed_ti_states = True
 
-        if trigger_rule == TR.ONE_SUCCESS:
-            if success <= 0:
+            non_successes = upstream - success
+            if ti.map_index > -1:
+                non_successes -= removed
+            if non_successes > 0:
                 yield self._failing_status(
                     reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires one upstream task success, "
-                        f"but none were found. upstream_states={upstream_states}, "
+                        f"All setup tasks must complete successfully. Relevant setups: {relevant_setups}: "
+                        f"upstream_states={upstream_states}, "
                         f"upstream_task_ids={task.upstream_task_ids}"
+                    ),
+                ), changed
+
+        def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
+            """Evaluate whether ``ti``'s trigger rule was met.
+
+            :param ti: Task instance to evaluate the trigger rule of.
+            :param dep_context: The current dependency context.
+            :param session: Database session.
+            """
+            task = ti.task
+            upstream_tasks = {t.task_id: t for t in task.upstream_list}
+            trigger_rule = task.trigger_rule
+
+            finished_upstream_tis = (
+                finished_ti
+                for finished_ti in dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
+                if _is_relevant_upstream(upstream=finished_ti, relevant_ids=ti.task.upstream_task_ids)
+            )
+            upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis)
+
+            success = upstream_states.success
+            skipped = upstream_states.skipped
+            failed = upstream_states.failed
+            upstream_failed = upstream_states.upstream_failed
+            removed = upstream_states.removed
+            done = upstream_states.done
+            success_setup = upstream_states.success_setup
+            skipped_setup = upstream_states.skipped_setup
+
+            # Optimization: Don't need to hit the database if all upstreams are
+            # "simple" tasks (no task or task group mapping involved).
+            if not any(needs_expansion(t) for t in upstream_tasks.values()):
+                upstream = len(upstream_tasks)
+                upstream_setup = sum(1 for x in upstream_tasks.values() if x.is_setup)
+            else:
+                task_id_counts = session.execute(
+                    select(TaskInstance.task_id, func.count(TaskInstance.task_id))
+                    .where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
+                    .where(or_(*_iter_upstream_conditions(relevant_tasks=upstream_tasks)))
+                    .group_by(TaskInstance.task_id)
+                ).all()
+                upstream = sum(count for _, count in task_id_counts)
+                upstream_setup = sum(c for t, c in task_id_counts if upstream_tasks[t].is_setup)
+
+            upstream_done = done >= upstream
+
+            changed = False
+            new_state = None
+            if dep_context.flag_upstream_failed:
+                if trigger_rule == TR.ALL_SUCCESS:
+                    if upstream_failed or failed:
+                        new_state = TaskInstanceState.UPSTREAM_FAILED
+                    elif skipped:
+                        new_state = TaskInstanceState.SKIPPED
+                    elif removed and success and ti.map_index > -1:
+                        if ti.map_index >= success:
+                            new_state = TaskInstanceState.REMOVED
+                elif trigger_rule == TR.ALL_FAILED:
+                    if success or skipped:
+                        new_state = TaskInstanceState.SKIPPED
+                elif trigger_rule == TR.ONE_SUCCESS:
+                    if upstream_done and done == skipped:
+                        # if upstream is done and all are skipped mark as skipped
+                        new_state = TaskInstanceState.SKIPPED
+                    elif upstream_done and success <= 0:
+                        # if upstream is done and there are no success mark as upstream failed
+                        new_state = TaskInstanceState.UPSTREAM_FAILED
+                elif trigger_rule == TR.ONE_FAILED:
+                    if upstream_done and not (failed or upstream_failed):
+                        new_state = TaskInstanceState.SKIPPED
+                elif trigger_rule == TR.ONE_DONE:
+                    if upstream_done and not (failed or success):
+                        new_state = TaskInstanceState.SKIPPED
+                elif trigger_rule == TR.NONE_FAILED:
+                    if upstream_failed or failed:
+                        new_state = TaskInstanceState.UPSTREAM_FAILED
+                elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
+                    if upstream_failed or failed:
+                        new_state = TaskInstanceState.UPSTREAM_FAILED
+                    elif skipped == upstream:
+                        new_state = TaskInstanceState.SKIPPED
+                elif trigger_rule == TR.NONE_SKIPPED:
+                    if skipped:
+                        new_state = TaskInstanceState.SKIPPED
+                elif trigger_rule == TR.ALL_SKIPPED:
+                    if success or failed:
+                        new_state = TaskInstanceState.SKIPPED
+                elif trigger_rule == TR.ALL_DONE_SETUP_SUCCESS:
+                    if upstream_done and upstream_setup and skipped_setup >= upstream_setup:
+                        # when there is an upstream setup and they have all skipped, then skip
+                        new_state = TaskInstanceState.SKIPPED
+                    elif upstream_done and upstream_setup and success_setup == 0:
+                        # when there is an upstream setup, if none succeeded, mark upstream failed
+                        # if at least one setup ran, we'll let it run
+                        new_state = TaskInstanceState.UPSTREAM_FAILED
+            if new_state is not None:
+                if (
+                    new_state == TaskInstanceState.SKIPPED
+                    and dep_context.wait_for_past_depends_before_skipping
+                ):
+                    past_depends_met = ti.xcom_pull(
+                        task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
                     )
-                )
-        elif trigger_rule == TR.ONE_FAILED:
-            if not failed and not upstream_failed:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires one 
upstream task failure, "
-                        f"but none were found. 
upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
+                    if not past_depends_met:
+                        yield self._failing_status(
+                            reason=("Task should be skipped but the past depends are not met")
+                        )
+                        return
+                changed = ti.set_state(new_state, session)
+
+            if changed:
+                dep_context.have_changed_ti_states = True
+
+            if trigger_rule == TR.ONE_SUCCESS:
+                if success <= 0:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires one upstream task success, "
+                            f"but none were found. upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
                     )
-                )
-        elif trigger_rule == TR.ONE_DONE:
-            if success + failed <= 0:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}'"
-                        "requires at least one upstream task failure or 
success"
-                        f"but none were failed or success. 
upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
+            elif trigger_rule == TR.ONE_FAILED:
+                if not failed and not upstream_failed:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires one upstream task failure, "
+                            f"but none were found. upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
                     )
-                )
-        elif trigger_rule == TR.ALL_SUCCESS:
-            num_failures = upstream - success
-            if ti.map_index > -1:
-                num_failures -= removed
-            if num_failures > 0:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires all 
upstream tasks to have "
-                        f"succeeded, but found {num_failures} non-success(es). 
"
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
+            elif trigger_rule == TR.ONE_DONE:
+                if success + failed <= 0:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}'"
+                            "requires at least one upstream task failure or success"
+                            f"but none were failed or success. upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
                     )
-                )
-        elif trigger_rule == TR.ALL_FAILED:
-            num_success = upstream - failed - upstream_failed
-            if ti.map_index > -1:
-                num_success -= removed
-            if num_success > 0:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires all 
upstream tasks to have failed, "
-                        f"but found {num_success} non-failure(s). "
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
+            elif trigger_rule == TR.ALL_SUCCESS:
+                num_failures = upstream - success
+                if ti.map_index > -1:
+                    num_failures -= removed
+                if num_failures > 0:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
+                            f"succeeded, but found {num_failures} non-success(es). "
+                            f"upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
                     )
-                )
-        elif trigger_rule == TR.ALL_DONE:
-            if not upstream_done:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires all 
upstream tasks to have "
-                        f"completed, but found {len(upstream_tasks) - done} 
task(s) that were not done. "
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
+            elif trigger_rule == TR.ALL_FAILED:
+                num_success = upstream - failed - upstream_failed
+                if ti.map_index > -1:
+                    num_success -= removed
+                if num_success > 0:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks "
+                            f"to have failed, but found {num_success} non-failure(s). "
+                            f"upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
                     )
-                )
-        elif trigger_rule == TR.NONE_FAILED:
-            num_failures = upstream - success - skipped
-            if ti.map_index > -1:
-                num_failures -= removed
-            if num_failures > 0:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires all 
upstream tasks to have "
-                        f"succeeded or been skipped, but found {num_failures} 
non-success(es). "
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
+            elif trigger_rule == TR.ALL_DONE:
+                if not upstream_done:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
+                            f"completed, but found {len(upstream_tasks) - done} task(s) that were "
+                            f"not done. upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
                     )
-                )
-        elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
-            num_failures = upstream - success - skipped
-            if ti.map_index > -1:
-                num_failures -= removed
-            if num_failures > 0:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires all 
upstream tasks to have "
-                        f"succeeded or been skipped, but found {num_failures} 
non-success(es). "
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
+            elif trigger_rule == TR.NONE_FAILED:
+                num_failures = upstream - success - skipped
+                if ti.map_index > -1:
+                    num_failures -= removed
+                if num_failures > 0:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
+                            f"succeeded or been skipped, but found {num_failures} non-success(es). "
+                            f"upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
                     )
-                )
-        elif trigger_rule == TR.NONE_SKIPPED:
-            if not upstream_done or (skipped > 0):
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires all 
upstream tasks to not have been "
-                        f"skipped, but found {skipped} task(s) skipped. "
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
+            elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
+                num_failures = upstream - success - skipped
+                if ti.map_index > -1:
+                    num_failures -= removed
+                if num_failures > 0:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
+                            f"succeeded or been skipped, but found {num_failures} non-success(es). "
+                            f"upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
                     )
-                )
-        elif trigger_rule == TR.ALL_SKIPPED:
-            num_non_skipped = upstream - skipped
-            if num_non_skipped > 0:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires all 
upstream tasks to have been "
-                        f"skipped, but found {num_non_skipped} task(s) in non 
skipped state. "
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
+            elif trigger_rule == TR.NONE_SKIPPED:
+                if not upstream_done or (skipped > 0):
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to not "
+                            f"have been skipped, but found {skipped} task(s) skipped. "
+                            f"upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
                     )
-                )
-        elif trigger_rule == TR.ALL_DONE_SETUP_SUCCESS:
-            if not upstream_done:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires all 
upstream tasks to have "
-                        f"completed, but found {len(upstream_tasks) - done} 
task(s) that were not done. "
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
+            elif trigger_rule == TR.ALL_SKIPPED:
+                num_non_skipped = upstream - skipped
+                if num_non_skipped > 0:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have been "
+                            f"skipped, but found {num_non_skipped} task(s) in non skipped state. "
+                            f"upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
                     )
-                )
-            elif upstream_setup is None:  # for now, None only happens in 
mapped case
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' cannot have 
mapped tasks as upstream. "
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
+            elif trigger_rule == TR.ALL_DONE_SETUP_SUCCESS:
+                if not upstream_done:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires all upstream tasks to have "
+                            f"completed, but found {len(upstream_tasks) - done} task(s) that were not done. "
+                            f"upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
                     )
-                )
-            elif upstream_setup and not success_setup >= 1:
-                yield self._failing_status(
-                    reason=(
-                        f"Task's trigger rule '{trigger_rule}' requires at 
least one upstream setup task be "
-                        f"successful, but found {upstream_setup - 
success_setup} task(s) that were not. "
-                        f"upstream_states={upstream_states}, "
-                        f"upstream_task_ids={task.upstream_task_ids}"
+                elif upstream_setup is None:  # for now, None only happens in mapped case
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' cannot have mapped tasks as upstream. "
+                            f"upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
+                    )
+                elif upstream_setup and not success_setup >= 1:
+                    yield self._failing_status(
+                        reason=(
+                            f"Task's trigger rule '{trigger_rule}' requires at least one upstream setup task "
+                            f"be successful, but found {upstream_setup - success_setup} task(s) that were "
+                            f"not. upstream_states={upstream_states}, "
+                            f"upstream_task_ids={task.upstream_task_ids}"
+                        )
                     )
-                )
-        else:
-            yield self._failing_status(reason=f"No strategy to evaluate trigger rule '{trigger_rule}'.")
+            else:
+                yield self._failing_status(reason=f"No strategy to evaluate trigger rule '{trigger_rule}'.")
+
+        if not ti.task.is_teardown:
+            # a teardown cannot have any indirect setups
+            relevant_setups = {t.task_id: t for t in ti.task.get_upstreams_only_setups()}
+            if relevant_setups:
+                for status, changed in _evaluate_setup_constraint(relevant_setups=relevant_setups):
+                    yield status
+                    if not status.passed and changed:
+                        # no need to evaluate trigger rule; we've already marked as skipped or failed
+                        return
+
+        yield from _evaluate_direct_relatives()
diff --git a/docs/apache-airflow/howto/setup-and-teardown.rst b/docs/apache-airflow/howto/setup-and-teardown.rst
index 355442a299..7afb3c4a35 100644
--- a/docs/apache-airflow/howto/setup-and-teardown.rst
+++ b/docs/apache-airflow/howto/setup-and-teardown.rst
@@ -125,6 +125,14 @@ In that example, we (in our pretend docs land) actually wanted to delete the clu
     create_cluster >> run_query >> other_task
     run_query >> EmptyOperator(task_id="cluster_teardown").as_teardown(setups=create_cluster)
 
+Implicit ALL_SUCCESS constraint
+"""""""""""""""""""""""""""""""
+
+Any task in the scope of a setup has an implicit "all_success" constraint on its setups.
+This is necessary to ensure that if a task with indirect setups is cleared, it will
+wait for them to complete.  If a setup fails or is skipped, the work tasks that depend
+on it will be marked as failed or skipped.  We also require that any non-teardown task
+directly downstream of a setup have the trigger rule ALL_SUCCESS.
 
 Controlling dag run state
 """""""""""""""""""""""""
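
A self-contained sketch of the constraint described in the new paragraph above (names
mirror the docs example; the DAG itself is illustrative and not part of the patch):

    import pendulum
    from airflow import DAG
    from airflow.operators.empty import EmptyOperator

    with DAG("implicit_all_success_sketch", start_date=pendulum.datetime(2023, 9, 1), schedule=None):
        create_cluster = EmptyOperator(task_id="create_cluster").as_setup()
        run_query = EmptyOperator(task_id="run_query")
        other_task = EmptyOperator(task_id="other_task")
        create_cluster >> run_query >> other_task
        run_query >> EmptyOperator(task_id="cluster_teardown").as_teardown(setups=create_cluster)

    # other_task is only indirectly downstream of create_cluster, but it sits in the
    # setup's scope, so it implicitly requires create_cluster to succeed: a failed or
    # skipped setup marks it upstream_failed or skipped, and clearing other_task also
    # waits for create_cluster to complete again.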
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index df07f857dc..35ed9c0fc5 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -75,6 +75,7 @@ from airflow.utils.session import create_session, provide_session
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.task_group import TaskGroup, TaskGroupContext
 from airflow.utils.timezone import datetime as datetime_tz
+from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import DagRunType
 from airflow.utils.weight_rule import WeightRule
 from tests.models import DEFAULT_DATE
@@ -4033,3 +4034,16 @@ class TestTaskClearingSetupTeardownBehavior:
         assert self.cleared_upstream(s1) == {s1, t1}
         assert self.cleared_downstream(s1) == {s1, t1}
         assert self.cleared_neither(s1) == {s1, t1}
+
+    def test_validate_setup_teardown_trigger_rule(self):
+        with DAG(
+            dag_id="direct_setup_trigger_rule", start_date=pendulum.now(), schedule=None, catchup=False
+        ) as dag:
+            s1, w1 = self.make_tasks(dag, "s1, w1")
+            s1 >> w1
+            dag.validate_setup_teardown()
+            w1.trigger_rule = TriggerRule.ONE_FAILED
+            with pytest.raises(
+                Exception, match="Setup tasks must be followed with trigger rule ALL_SUCCESS."
+            ):
+                dag.validate_setup_teardown()
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index e50917e101..a2e72d614e 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -1236,18 +1236,36 @@ class TestTaskInstance:
                 2,
                 _UpstreamTIStates(6, 0, 1, 0, 0, 7, 1, 0),
                 True,
-                None,
+                (True, None),  # is_teardown=True, expect_state=None
                 True,
-                id="one setup failed one setup success --> should run",
+                id="is teardown one setup failed one setup success",
+            ),
+            param(
+                "all_done_setup_success",
+                2,
+                _UpstreamTIStates(6, 0, 1, 0, 0, 7, 1, 0),
+                True,
+                (False, "upstream_failed"),  # is_teardown=False, expect_state="upstream_failed"
+                True,
+                id="not teardown one setup failed one setup success",
             ),
             param(
                 "all_done_setup_success",
                 2,
                 _UpstreamTIStates(6, 1, 0, 0, 0, 7, 1, 1),
                 True,
-                None,
+                (True, None),  # is_teardown=True, expect_state=None
                 True,
-                id="one setup success one setup skipped --> should run",
+                id="is teardown one setup success one setup skipped",
+            ),
+            param(
+                "all_done_setup_success",
+                2,
+                _UpstreamTIStates(6, 1, 0, 0, 0, 7, 1, 1),
+                True,
+                (False, "skipped"),  # is_teardown=False, expect_state="skipped"
+                True,
+                id="not teardown one setup success one setup skipped",
             ),
             param(
                 "all_done_setup_success",
@@ -1263,18 +1281,36 @@ class TestTaskInstance:
                 1,
                 _UpstreamTIStates(3, 0, 1, 0, 0, 4, 1, 0),
                 True,
-                None,
+                (True, None),  # is_teardown=True, expect_state=None
                 False,
-                id="not all done, one failed",
+                id="is teardown not all done one failed",
+            ),
+            param(
+                "all_done_setup_success",
+                1,
+                _UpstreamTIStates(3, 0, 1, 0, 0, 4, 1, 0),
+                True,
+                (False, "upstream_failed"),  # is_teardown=False, expect_state="upstream_failed"
+                False,
+                id="not teardown not all done one failed",
             ),
             param(
                 "all_done_setup_success",
                 1,
                 _UpstreamTIStates(3, 1, 0, 0, 0, 4, 1, 0),
                 True,
-                None,
+                (True, None),  # is_teardown=True, expect_state=None
                 False,
-                id="not all done, one skipped",
+                id="is teardown not all done one skipped",
+            ),
+            param(
+                "all_done_setup_success",
+                1,
+                _UpstreamTIStates(3, 1, 0, 0, 0, 4, 1, 0),
+                True,
+                (False, "skipped"),  # is_teardown=False, expect_state="skipped"
+                False,
+                id="not teardown not all done one skipped",
             ),
         ],
     )
@@ -1289,6 +1325,13 @@ class TestTaskInstance:
         expect_state: State,
         expect_passed: bool,
     ):
+        # this allows us to change the expected state depending on whether the
+        # task is a teardown
+        set_teardown = False
+        if isinstance(expect_state, tuple):
+            set_teardown, expect_state = expect_state
+            assert isinstance(set_teardown, bool)
+
         monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: upstream_states)
 
         # sanity checks
@@ -1299,6 +1342,8 @@ class TestTaskInstance:
 
         with dag_maker() as dag:
             downstream = EmptyOperator(task_id="downstream", trigger_rule=trigger_rule)
+            if set_teardown:
+                downstream.as_teardown()
             for i in range(5):
                 task = EmptyOperator(task_id=f"work_{i}", dag=dag)
                 task.set_downstream(downstream)
diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py
index faa70b5a49..9b7114a400 100644
--- a/tests/ti_deps/deps/test_trigger_rule_dep.py
+++ b/tests/ti_deps/deps/test_trigger_rule_dep.py
@@ -1215,3 +1215,167 @@ def test_mapped_task_check_before_expand(dag_maker, session):
     results = list(result_iterator)
     assert len(results) == 1
     assert results[0].passed is False
+
+
+class TestTriggerRuleDepSetupConstraint:
+    @staticmethod
+    def get_ti(dr, task_id):
+        return next(ti for ti in dr.task_instances if ti.task_id == task_id)
+
+    def get_dep_statuses(self, dr, task_id, flag_upstream_failed=False, session=None):
+        return list(
+            TriggerRuleDep()._get_dep_statuses(
+                ti=self.get_ti(dr, task_id),
+                dep_context=DepContext(flag_upstream_failed=flag_upstream_failed),
+                session=session,
+            )
+        )
+
+    def test_setup_constraint_blocks_execution(self, dag_maker, session):
+        with dag_maker(session=session):
+
+            @task
+            def t1():
+                return 1
+
+            @task
+            def t2():
+                return 2
+
+            @task
+            def t3():
+                return 3
+
+            t1_task = t1()
+            t2_task = t2()
+            t3_task = t3()
+            t1_task >> t2_task >> t3_task
+            t1_task.as_setup()
+        dr = dag_maker.create_dagrun()
+
+        # setup constraint is not applied to t2 because it has a direct setup
+        # so even though the setup is not done, the check passes
+        # but the trigger rule check fails because of the normal trigger rule dep behavior
+        statuses = self.get_dep_statuses(dr, "t2", session=session)
+        assert len(statuses) == 1
+        assert statuses[0].passed is False
+        assert statuses[0].reason.startswith("Task's trigger rule 'all_success' requires all upstream tasks")
+
+        # t3 has an indirect setup so the setup check fails
+        # trigger rule also fails
+        statuses = self.get_dep_statuses(dr, "t3", session=session)
+        assert len(statuses) == 2
+        assert statuses[0].passed is False
+        assert statuses[0].reason.startswith("All setup tasks must complete successfully")
+        assert statuses[1].passed is False
+        assert statuses[1].reason.startswith("Task's trigger rule 'all_success' requires all upstream tasks")
+
+    @pytest.mark.parametrize(
+        "setup_state, expected", [(None, None), ("failed", "upstream_failed"), ("skipped", "skipped")]
+    )
+    def test_setup_constraint_changes_state_appropriately(self, dag_maker, session, setup_state, expected):
+        with dag_maker(session=session):
+
+            @task
+            def t1():
+                return 1
+
+            @task
+            def t2():
+                return 2
+
+            @task
+            def t3():
+                return 3
+
+            t1_task = t1()
+            t2_task = t2()
+            t3_task = t3()
+            t1_task >> t2_task >> t3_task
+            t1_task.as_setup()
+        dr = dag_maker.create_dagrun()
+
+        # if the setup fails then now, in processing the trigger rule dep, the ti states
+        # will be updated
+        if setup_state:
+            self.get_ti(dr, "t1").state = setup_state
+        session.commit()
+        (status,) = self.get_dep_statuses(dr, "t2", flag_upstream_failed=True, session=session)
+        assert status.passed is False
+        # t2 fails on the non-setup-related trigger rule constraint since it has
+        # a direct setup
+        assert status.reason.startswith("Task's trigger rule 'all_success' requires")
+        assert self.get_ti(dr, "t2").state == expected
+        assert self.get_ti(dr, "t3").state is None  # hasn't been evaluated yet
+
+        # unlike t2, t3 fails on the setup constraint, and the normal trigger rule
+        # constraint is not actually evaluated, since it ain't gonna run anyway
+        if setup_state is None:
+            # when state is None, setup constraint doesn't mutate ti state, so we get
+            # two failure reasons -- setup constraint and trigger rule
+            (status, _) = self.get_dep_statuses(dr, "t3", flag_upstream_failed=True, session=session)
+        else:
+            (status,) = self.get_dep_statuses(dr, "t3", flag_upstream_failed=True, session=session)
+        assert status.reason.startswith("All setup tasks must complete successfully")
+        assert self.get_ti(dr, "t3").state == expected
+
+    @pytest.mark.parametrize(
+        "setup_state, expected", [(None, None), ("failed", "upstream_failed"), ("skipped", "skipped")]
+    )
+    def test_setup_constraint_will_fail_or_skip_fast(self, dag_maker, session, setup_state, expected):
+        """
+        When a setup fails or skips, the tasks that depend on it will immediately fail or skip
+        and not, for example, wait for all setups to complete before determining what is
+        the appropriate state.  This is a bit of a race condition, but it's consistent
+        with the behavior for many-to-one direct upstream task relationships, and it's
+        required if you want to fail fast.
+
+        So in this test we verify that if even one setup is failed or skipped, the
+        state will propagate to the in-scope work tasks.
+        """
+        with dag_maker(session=session):
+
+            @task
+            def s1():
+                return 1
+
+            @task
+            def s2():
+                return 1
+
+            @task
+            def w1():
+                return 2
+
+            @task
+            def w2():
+                return 3
+
+            s1 = s1().as_setup()
+            s2 = s2().as_setup()
+            [s1, s2] >> w1() >> w2()
+        dr = dag_maker.create_dagrun()
+
+        # if the setup fails then now, in processing the trigger rule dep, the ti states
+        # will be updated
+        if setup_state:
+            self.get_ti(dr, "s2").state = setup_state
+        session.commit()
+        (status,) = self.get_dep_statuses(dr, "w1", flag_upstream_failed=True, session=session)
+        assert status.passed is False
+        # w1 fails on the non-setup-related trigger rule constraint since it has
+        # a direct setup
+        assert status.reason.startswith("Task's trigger rule 'all_success' requires")
+        assert self.get_ti(dr, "w1").state == expected
+        assert self.get_ti(dr, "w2").state is None  # hasn't been evaluated yet
+
+        # unlike w1, w2 fails on the setup constraint, and the normal trigger rule
+        # constraint is not actually evaluated, since it ain't gonna run anyway
+        if setup_state is None:
+            # when state is None, setup constraint doesn't mutate ti state, so we get
+            # two failure reasons -- setup constraint and trigger rule
+            (status, _) = self.get_dep_statuses(dr, "w2", flag_upstream_failed=True, session=session)
+        else:
+            (status,) = self.get_dep_statuses(dr, "w2", flag_upstream_failed=True, session=session)
+        assert status.reason.startswith("All setup tasks must complete successfully")
+        assert self.get_ti(dr, "w2").state == expected
