This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 27001a2371 Improve support for lists in setup/teardown context manager (#31616)
27001a2371 is described below

commit 27001a23718d6b8b5118eb130be84713af9a4477
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Wed May 31 10:22:35 2023 +0100

    Improve support for lists in setup/teardown context manager (#31616)
    
    * Improve support for lists in setup/teardown context manager
    
    This commit addresses an issue with using lists in the setup/teardown
    context manager.
    To allow a list of tasks on the right-hand side of the context manager,
    a new class, ContextWrapper (exported under the alias context_wrapper),
    has been introduced. It is a subclass of 'list' that also works as a
    context manager. It is needed because `>>` and `<<` return their
    right-hand operand, so an expression such as
    `setup >> [teardown1, teardown2]` evaluates to a plain list, which
    cannot be used in a `with` statement. Users need to wrap a list in
    context_wrapper whenever they intend to use it on the right-hand side of
    the context manager.
    
    * Fix typing
---
 airflow/decorators/setup_teardown.py    |  33 ++++-
 airflow/utils/setup_teardown.py         | 104 ++++++++++-----
 tests/decorators/test_setup_teardown.py | 229 ++++++++++++++++++++++++++++++++
 3 files changed, 334 insertions(+), 32 deletions(-)

diff --git a/airflow/decorators/setup_teardown.py 
b/airflow/decorators/setup_teardown.py
index 17c87c7025..e810c75a88 100644
--- a/airflow/decorators/setup_teardown.py
+++ b/airflow/decorators/setup_teardown.py
@@ -19,9 +19,11 @@ from __future__ import annotations
 import types
 from typing import Callable
 
-from airflow import AirflowException
+from airflow import AirflowException, XComArg
 from airflow.decorators import python_task
 from airflow.decorators.task_group import _TaskGroupFactory
+from airflow.models import BaseOperator
+from airflow.utils.setup_teardown import SetupTeardownContext
 
 
 def setup_task(func: Callable) -> Callable:
@@ -48,3 +50,32 @@ def teardown_task(_func=None, *, on_failure_fail_dagrun: 
bool = False) -> Callab
     if _func is None:
         return teardown
     return teardown(_func)
+
+
+class ContextWrapper(list):
+    """A list subclass that has a context manager that pushes setup/teardown 
tasks to the context."""
+
+    def __init__(self, tasks: list[BaseOperator | XComArg]):
+        self.tasks = tasks
+        super().__init__(tasks)
+
+    def __enter__(self):
+        operators = []
+        for task in self.tasks:
+            if isinstance(task, BaseOperator):
+                operators.append(task)
+                if not task.is_setup and not task.is_teardown:
+                    raise AirflowException("Only setup/teardown tasks can be 
used as context managers.")
+            elif not task.operator.is_setup and not task.operator.is_teardown:
+                raise AirflowException("Only setup/teardown tasks can be used 
as context managers.")
+        if not operators:
+            # means we have XComArgs
+            operators = [task.operator for task in self.tasks]
+        SetupTeardownContext.push_setup_teardown_task(operators)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        SetupTeardownContext.set_work_task_roots_and_leaves()
+
+
+context_wrapper = ContextWrapper
diff --git a/airflow/utils/setup_teardown.py b/airflow/utils/setup_teardown.py
index 696bb5ba9b..9b28085880 100644
--- a/airflow/utils/setup_teardown.py
+++ b/airflow/utils/setup_teardown.py
@@ -26,107 +26,149 @@ if TYPE_CHECKING:
 class SetupTeardownContext:
     """Context manager for setup/teardown tasks."""
 
-    _context_managed_setup_task: Operator | None = None
-    _previous_context_managed_setup_task: list[Operator] = []
-    _context_managed_teardown_task: Operator | None = None
-    _previous_context_managed_teardown_task: list[Operator] = []
+    _context_managed_setup_task: Operator | list[Operator] | None = None
+    _previous_context_managed_setup_task: list[Operator | list[Operator]] = []
+    _context_managed_teardown_task: Operator | list[Operator] | None = None
+    _previous_context_managed_teardown_task: list[Operator | list[Operator]] = 
[]
     active: bool = False
-    context_map: dict[Operator, list[Operator]] = {}
+    context_map: dict[Operator | tuple[Operator], list[Operator]] = {}
 
     @classmethod
-    def push_context_managed_setup_task(cls, task: Operator):
+    def push_context_managed_setup_task(cls, task: Operator | list[Operator]):
         if cls._context_managed_setup_task:
             
cls._previous_context_managed_setup_task.append(cls._context_managed_setup_task)
         cls._context_managed_setup_task = task
 
     @classmethod
-    def push_context_managed_teardown_task(cls, task: Operator):
+    def push_context_managed_teardown_task(cls, task: Operator | 
list[Operator]):
         if cls._context_managed_teardown_task:
             
cls._previous_context_managed_teardown_task.append(cls._context_managed_teardown_task)
         cls._context_managed_teardown_task = task
 
     @classmethod
-    def pop_context_managed_setup_task(cls) -> Operator | None:
+    def pop_context_managed_setup_task(cls) -> Operator | list[Operator] | 
None:
         old_setup_task = cls._context_managed_setup_task
         if cls._previous_context_managed_setup_task:
             cls._context_managed_setup_task = 
cls._previous_context_managed_setup_task.pop()
-            if cls._context_managed_setup_task and old_setup_task:
-                cls._context_managed_setup_task.set_downstream(old_setup_task)
+            setup_task = cls._context_managed_setup_task
+            if setup_task and old_setup_task:
+                if isinstance(setup_task, list):
+                    for task in setup_task:
+                        task.set_downstream(old_setup_task)
+                else:
+                    setup_task.set_downstream(old_setup_task)
         else:
             cls._context_managed_setup_task = None
         return old_setup_task
 
     @classmethod
     def update_context_map(cls, operator):
-        setup_task = SetupTeardownContext.get_context_managed_setup_task()
-        teardown_task = 
SetupTeardownContext.get_context_managed_teardown_task()
         ctx = SetupTeardownContext.context_map
-        if setup_task:
+        if setup_task := SetupTeardownContext.get_context_managed_setup_task():
+            if isinstance(setup_task, list):
+                setup_task = tuple(setup_task)
             if ctx.get(setup_task) is None:
                 ctx[setup_task] = [operator]
             else:
                 ctx[setup_task].append(operator)
-        if teardown_task:
+        if teardown_task := 
SetupTeardownContext.get_context_managed_teardown_task():
+            if isinstance(teardown_task, list):
+                teardown_task = tuple(teardown_task)
             if ctx.get(teardown_task) is None:
                 ctx[teardown_task] = [operator]
             else:
                 ctx[teardown_task].append(operator)
 
     @classmethod
-    def pop_context_managed_teardown_task(cls) -> Operator | None:
+    def pop_context_managed_teardown_task(cls) -> Operator | list[Operator] | 
None:
         old_teardown_task = cls._context_managed_teardown_task
         if cls._previous_context_managed_teardown_task:
             cls._context_managed_teardown_task = 
cls._previous_context_managed_teardown_task.pop()
-            if cls._context_managed_teardown_task and old_teardown_task:
-                
cls._context_managed_teardown_task.set_upstream(old_teardown_task)
+            teardown_task = cls._context_managed_teardown_task
+            if teardown_task and old_teardown_task:
+                if isinstance(teardown_task, list):
+                    for task in teardown_task:
+                        task.set_upstream(old_teardown_task)
+                else:
+                    teardown_task.set_upstream(old_teardown_task)
         else:
             cls._context_managed_teardown_task = None
         return old_teardown_task
 
     @classmethod
-    def get_context_managed_setup_task(cls) -> Operator | None:
+    def get_context_managed_setup_task(cls) -> Operator | list[Operator] | 
None:
         return cls._context_managed_setup_task
 
     @classmethod
-    def get_context_managed_teardown_task(cls) -> Operator | None:
+    def get_context_managed_teardown_task(cls) -> Operator | list[Operator] | 
None:
         return cls._context_managed_teardown_task
 
     @classmethod
-    def push_setup_teardown_task(cls, operator):
-        if operator.is_teardown:
+    def push_setup_teardown_task(cls, operator: Operator | list[Operator]):
+        if isinstance(operator, list):
+            first_task: Operator = operator[0]
+            if first_task.is_teardown:
+                if not all(task.is_teardown == first_task.is_teardown for task 
in operator):
+                    raise ValueError("All tasks in the list must be either 
setup or teardown tasks")
+                
SetupTeardownContext.push_context_managed_teardown_task(operator)
+                upstream_setup: list[Operator] = [task for task in 
first_task.upstream_list if task.is_setup]
+                if upstream_setup:
+                    
SetupTeardownContext.push_context_managed_setup_task(upstream_setup)
+            elif first_task.is_setup:
+                if not all(task.is_setup == first_task.is_setup for task in 
operator):
+                    raise ValueError("All tasks in the list must be either 
setup or teardown tasks")
+                SetupTeardownContext.push_context_managed_setup_task(operator)
+                downstream_teardown: list[Operator] = [
+                    task for task in first_task.downstream_list if 
task.is_teardown
+                ]
+                if downstream_teardown:
+                    
SetupTeardownContext.push_context_managed_teardown_task(downstream_teardown)
+        elif operator.is_teardown:
             SetupTeardownContext.push_context_managed_teardown_task(operator)
             upstream_setup = [task for task in operator.upstream_list if 
task.is_setup]
             if upstream_setup:
-                
SetupTeardownContext.push_context_managed_setup_task(upstream_setup[-1])
+                
SetupTeardownContext.push_context_managed_setup_task(upstream_setup)
         elif operator.is_setup:
             SetupTeardownContext.push_context_managed_setup_task(operator)
             downstream_teardown = [task for task in operator.downstream_list 
if task.is_teardown]
             if downstream_teardown:
-                
SetupTeardownContext.push_context_managed_teardown_task(downstream_teardown[0])
+                
SetupTeardownContext.push_context_managed_teardown_task(downstream_teardown)
         SetupTeardownContext.active = True
 
     @classmethod
     def set_work_task_roots_and_leaves(cls):
-        setup_task = cls.get_context_managed_setup_task()
-        teardown_task = cls.get_context_managed_teardown_task()
-        if setup_task:
+        if setup_task := cls.get_context_managed_setup_task():
+            if isinstance(setup_task, list):
+                setup_task = tuple(setup_task)
             tasks_in_context = cls.context_map.get(setup_task, [])
             if tasks_in_context:
                 roots = [task for task in tasks_in_context if not 
task.upstream_list]
                 if not roots:
-                    setup_task.set_downstream(tasks_in_context[0])
+                    setup_task >> tasks_in_context[0]
+                elif isinstance(setup_task, tuple):
+                    for task in setup_task:
+                        task >> roots
                 else:
-                    setup_task.set_downstream(roots)
-        if teardown_task:
+                    setup_task >> roots
+        if teardown_task := cls.get_context_managed_teardown_task():
+            if isinstance(teardown_task, list):
+                teardown_task = tuple(teardown_task)
             tasks_in_context = cls.context_map.get(teardown_task, [])
             if tasks_in_context:
                 leaves = [task for task in tasks_in_context if not 
task.downstream_list]
                 if not leaves:
-                    teardown_task.set_upstream(tasks_in_context[-1])
+                    teardown_task << tasks_in_context[-1]
+                elif isinstance(teardown_task, tuple):
+                    for task in teardown_task:
+                        task << leaves
                 else:
-                    teardown_task.set_upstream(leaves)
+                    teardown_task << leaves
         setup_task = SetupTeardownContext.pop_context_managed_setup_task()
         teardown_task = 
SetupTeardownContext.pop_context_managed_teardown_task()
+        if isinstance(setup_task, list):
+            setup_task = tuple(setup_task)
+        if isinstance(teardown_task, list):
+            teardown_task = tuple(teardown_task)
         SetupTeardownContext.active = False
         SetupTeardownContext.context_map.pop(setup_task, None)
         SetupTeardownContext.context_map.pop(teardown_task, None)
diff --git a/tests/decorators/test_setup_teardown.py 
b/tests/decorators/test_setup_teardown.py
index 082b5e7813..05ea076dde 100644
--- a/tests/decorators/test_setup_teardown.py
+++ b/tests/decorators/test_setup_teardown.py
@@ -21,6 +21,7 @@ import pytest
 
 from airflow import AirflowException
 from airflow.decorators import setup, task, task_group, teardown
+from airflow.decorators.setup_teardown import context_wrapper
 from airflow.operators.bash import BashOperator
 
 
@@ -847,3 +848,231 @@ class TestSetupTearDownTask:
             "mytask2",
         }
         assert dag.task_group.children["teardowntask2"].downstream_task_ids == 
{"teardowntask"}
+
+    def test_setup_decorator_context_manager_with_list_on_left(self, 
dag_maker):
+        @setup
+        def setuptask():
+            print("setup")
+
+        @setup
+        def setuptask2():
+            print("setup")
+
+        @task()
+        def mytask():
+            print("mytask")
+
+        @teardown
+        def teardowntask():
+            print("teardown")
+
+        with dag_maker() as dag:
+            with [setuptask(), setuptask2()] >> teardowntask():
+                mytask()
+
+        assert len(dag.task_group.children) == 4
+        assert not dag.task_group.children["setuptask"].upstream_task_ids
+        assert not dag.task_group.children["setuptask2"].upstream_task_ids
+        assert dag.task_group.children["setuptask"].downstream_task_ids == 
{"mytask", "teardowntask"}
+        assert dag.task_group.children["setuptask2"].downstream_task_ids == 
{"mytask", "teardowntask"}
+        assert dag.task_group.children["mytask"].upstream_task_ids == 
{"setuptask", "setuptask2"}
+        assert dag.task_group.children["mytask"].downstream_task_ids == 
{"teardowntask"}
+        assert dag.task_group.children["teardowntask"].upstream_task_ids == {
+            "setuptask",
+            "setuptask2",
+            "mytask",
+        }
+
+    def test_setup_decorator_context_manager_with_list_on_right(self, 
dag_maker):
+        @setup
+        def setuptask():
+            print("setup")
+
+        @setup
+        def setuptask2():
+            print("setup")
+
+        @task()
+        def mytask():
+            print("mytask")
+
+        @teardown
+        def teardowntask():
+            print("teardown")
+
+        with dag_maker() as dag:
+            with teardowntask() << context_wrapper([setuptask(), 
setuptask2()]):
+                mytask()
+
+        assert len(dag.task_group.children) == 4
+        assert not dag.task_group.children["setuptask"].upstream_task_ids
+        assert not dag.task_group.children["setuptask2"].upstream_task_ids
+        assert dag.task_group.children["setuptask"].downstream_task_ids == 
{"mytask", "teardowntask"}
+        assert dag.task_group.children["setuptask2"].downstream_task_ids == 
{"mytask", "teardowntask"}
+        assert dag.task_group.children["mytask"].upstream_task_ids == 
{"setuptask", "setuptask2"}
+        assert dag.task_group.children["mytask"].downstream_task_ids == 
{"teardowntask"}
+        assert dag.task_group.children["teardowntask"].upstream_task_ids == {
+            "setuptask",
+            "setuptask2",
+            "mytask",
+        }
+
+    def test_setup_decorator_context_manager_errors_with_mixed_up_tasks(self, 
dag_maker):
+        @setup
+        def setuptask():
+            print("setup")
+
+        @setup
+        def setuptask2():
+            print("setup")
+
+        @task()
+        def mytask():
+            print("mytask")
+
+        @teardown
+        def teardowntask():
+            print("teardown")
+
+        with pytest.raises(ValueError, match="All tasks in the list must be 
either setup or teardown tasks"):
+            with dag_maker():
+                with setuptask() << context_wrapper([teardowntask(), 
setuptask2()]):
+                    mytask()
+
+    def test_teardown_decorator_context_manager_with_list_on_left(self, 
dag_maker):
+        @setup
+        def setuptask():
+            print("setup")
+
+        @task()
+        def mytask():
+            print("mytask")
+
+        @teardown
+        def teardowntask():
+            print("teardown")
+
+        @teardown
+        def teardowntask2():
+            print("teardown")
+
+        with dag_maker() as dag:
+            with [teardowntask(), teardowntask2()] << setuptask():
+                mytask()
+
+        assert len(dag.task_group.children) == 4
+        assert not dag.task_group.children["setuptask"].upstream_task_ids
+        assert dag.task_group.children["setuptask"].downstream_task_ids == {
+            "mytask",
+            "teardowntask",
+            "teardowntask2",
+        }
+        assert dag.task_group.children["mytask"].upstream_task_ids == 
{"setuptask"}
+        assert dag.task_group.children["mytask"].downstream_task_ids == 
{"teardowntask", "teardowntask2"}
+        assert dag.task_group.children["teardowntask"].upstream_task_ids == {
+            "setuptask",
+            "mytask",
+        }
+        assert dag.task_group.children["teardowntask2"].upstream_task_ids == {
+            "setuptask",
+            "mytask",
+        }
+
+    def test_teardown_decorator_context_manager_with_list_on_right(self, 
dag_maker):
+        @setup
+        def setuptask():
+            print("setup")
+
+        @task()
+        def mytask():
+            print("mytask")
+
+        @teardown
+        def teardowntask():
+            print("teardown")
+
+        @teardown
+        def teardowntask2():
+            print("teardown")
+
+        with dag_maker() as dag:
+            with setuptask() >> context_wrapper([teardowntask(), 
teardowntask2()]):
+                mytask()
+
+        assert len(dag.task_group.children) == 4
+        assert not dag.task_group.children["setuptask"].upstream_task_ids
+        assert dag.task_group.children["setuptask"].downstream_task_ids == {
+            "mytask",
+            "teardowntask",
+            "teardowntask2",
+        }
+        assert dag.task_group.children["mytask"].upstream_task_ids == 
{"setuptask"}
+        assert dag.task_group.children["mytask"].downstream_task_ids == 
{"teardowntask", "teardowntask2"}
+        assert dag.task_group.children["teardowntask"].upstream_task_ids == {
+            "setuptask",
+            "mytask",
+        }
+        assert dag.task_group.children["teardowntask2"].upstream_task_ids == {
+            "setuptask",
+            "mytask",
+        }
+
+    def test_classic_operator_context_manager_with_list_on_left(self, 
dag_maker):
+        @task()
+        def mytask():
+            print("mytask")
+
+        with dag_maker() as dag:
+            teardowntask = BashOperator.as_teardown(task_id="teardowntask", 
bash_command="echo 1")
+            teardowntask2 = BashOperator.as_teardown(task_id="teardowntask2", 
bash_command="echo 1")
+            setuptask = BashOperator.as_setup(task_id="setuptask", 
bash_command="echo 1")
+            with [teardowntask, teardowntask2] << setuptask:
+                mytask()
+
+        assert len(dag.task_group.children) == 4
+        assert not dag.task_group.children["setuptask"].upstream_task_ids
+        assert dag.task_group.children["setuptask"].downstream_task_ids == {
+            "mytask",
+            "teardowntask",
+            "teardowntask2",
+        }
+        assert dag.task_group.children["mytask"].upstream_task_ids == 
{"setuptask"}
+        assert dag.task_group.children["mytask"].downstream_task_ids == 
{"teardowntask", "teardowntask2"}
+        assert dag.task_group.children["teardowntask"].upstream_task_ids == {
+            "setuptask",
+            "mytask",
+        }
+        assert dag.task_group.children["teardowntask2"].upstream_task_ids == {
+            "setuptask",
+            "mytask",
+        }
+
+    def test_classic_operator_context_manager_with_list_on_right(self, 
dag_maker):
+        @task()
+        def mytask():
+            print("mytask")
+
+        with dag_maker() as dag:
+            teardowntask = BashOperator.as_teardown(task_id="teardowntask", 
bash_command="echo 1")
+            teardowntask2 = BashOperator.as_teardown(task_id="teardowntask2", 
bash_command="echo 1")
+            setuptask = BashOperator.as_setup(task_id="setuptask", 
bash_command="echo 1")
+            with setuptask >> context_wrapper([teardowntask, teardowntask2]):
+                mytask()
+
+        assert len(dag.task_group.children) == 4
+        assert not dag.task_group.children["setuptask"].upstream_task_ids
+        assert dag.task_group.children["setuptask"].downstream_task_ids == {
+            "mytask",
+            "teardowntask",
+            "teardowntask2",
+        }
+        assert dag.task_group.children["mytask"].upstream_task_ids == 
{"setuptask"}
+        assert dag.task_group.children["mytask"].downstream_task_ids == 
{"teardowntask", "teardowntask2"}
+        assert dag.task_group.children["teardowntask"].upstream_task_ids == {
+            "setuptask",
+            "mytask",
+        }
+        assert dag.task_group.children["teardowntask2"].upstream_task_ids == {
+            "setuptask",
+            "mytask",
+        }

Reply via email to