This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 27001a2371 Improve support for lists in setup/teardown context manager
(#31616)
27001a2371 is described below
commit 27001a23718d6b8b5118eb130be84713af9a4477
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Wed May 31 10:22:35 2023 +0100
Improve support for lists in setup/teardown context manager (#31616)
* Improve support for lists in setup/teardown context manager
This commit addresses an issue with using lists in the
setup/teardown context manager.
To allow a list of tasks on the right-hand side of the
context manager expression,
a new wrapper called context_wrapper (an alias for the ContextWrapper class)
has been introduced. It is a subclass of 'list' that also
works as a context manager. Users need to wrap the list in
context_wrapper whenever it appears
on the right-hand side of the context manager expression.
* Fix typing
---
airflow/decorators/setup_teardown.py | 33 ++++-
airflow/utils/setup_teardown.py | 104 ++++++++++-----
tests/decorators/test_setup_teardown.py | 229 ++++++++++++++++++++++++++++++++
3 files changed, 334 insertions(+), 32 deletions(-)
diff --git a/airflow/decorators/setup_teardown.py
b/airflow/decorators/setup_teardown.py
index 17c87c7025..e810c75a88 100644
--- a/airflow/decorators/setup_teardown.py
+++ b/airflow/decorators/setup_teardown.py
@@ -19,9 +19,11 @@ from __future__ import annotations
import types
from typing import Callable
-from airflow import AirflowException
+from airflow import AirflowException, XComArg
from airflow.decorators import python_task
from airflow.decorators.task_group import _TaskGroupFactory
+from airflow.models import BaseOperator
+from airflow.utils.setup_teardown import SetupTeardownContext
def setup_task(func: Callable) -> Callable:
@@ -48,3 +50,32 @@ def teardown_task(_func=None, *, on_failure_fail_dagrun:
bool = False) -> Callab
if _func is None:
return teardown
return teardown(_func)
+
+
+class ContextWrapper(list):
+ """A list subclass that has a context manager that pushes setup/teardown
tasks to the context."""
+
+ def __init__(self, tasks: list[BaseOperator | XComArg]):
+ self.tasks = tasks
+ super().__init__(tasks)
+
+ def __enter__(self):
+ operators = []
+ for task in self.tasks:
+ if isinstance(task, BaseOperator):
+ operators.append(task)
+ if not task.is_setup and not task.is_teardown:
+ raise AirflowException("Only setup/teardown tasks can be
used as context managers.")
+ elif not task.operator.is_setup and not task.operator.is_teardown:
+ raise AirflowException("Only setup/teardown tasks can be used
as context managers.")
+ if not operators:
+ # means we have XComArgs
+ operators = [task.operator for task in self.tasks]
+ SetupTeardownContext.push_setup_teardown_task(operators)
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ SetupTeardownContext.set_work_task_roots_and_leaves()
+
+
+context_wrapper = ContextWrapper
diff --git a/airflow/utils/setup_teardown.py b/airflow/utils/setup_teardown.py
index 696bb5ba9b..9b28085880 100644
--- a/airflow/utils/setup_teardown.py
+++ b/airflow/utils/setup_teardown.py
@@ -26,107 +26,149 @@ if TYPE_CHECKING:
class SetupTeardownContext:
"""Context manager for setup/teardown tasks."""
- _context_managed_setup_task: Operator | None = None
- _previous_context_managed_setup_task: list[Operator] = []
- _context_managed_teardown_task: Operator | None = None
- _previous_context_managed_teardown_task: list[Operator] = []
+ _context_managed_setup_task: Operator | list[Operator] | None = None
+ _previous_context_managed_setup_task: list[Operator | list[Operator]] = []
+ _context_managed_teardown_task: Operator | list[Operator] | None = None
+ _previous_context_managed_teardown_task: list[Operator | list[Operator]] =
[]
active: bool = False
- context_map: dict[Operator, list[Operator]] = {}
+ context_map: dict[Operator | tuple[Operator], list[Operator]] = {}
@classmethod
- def push_context_managed_setup_task(cls, task: Operator):
+ def push_context_managed_setup_task(cls, task: Operator | list[Operator]):
if cls._context_managed_setup_task:
cls._previous_context_managed_setup_task.append(cls._context_managed_setup_task)
cls._context_managed_setup_task = task
@classmethod
- def push_context_managed_teardown_task(cls, task: Operator):
+ def push_context_managed_teardown_task(cls, task: Operator |
list[Operator]):
if cls._context_managed_teardown_task:
cls._previous_context_managed_teardown_task.append(cls._context_managed_teardown_task)
cls._context_managed_teardown_task = task
@classmethod
- def pop_context_managed_setup_task(cls) -> Operator | None:
+ def pop_context_managed_setup_task(cls) -> Operator | list[Operator] |
None:
old_setup_task = cls._context_managed_setup_task
if cls._previous_context_managed_setup_task:
cls._context_managed_setup_task =
cls._previous_context_managed_setup_task.pop()
- if cls._context_managed_setup_task and old_setup_task:
- cls._context_managed_setup_task.set_downstream(old_setup_task)
+ setup_task = cls._context_managed_setup_task
+ if setup_task and old_setup_task:
+ if isinstance(setup_task, list):
+ for task in setup_task:
+ task.set_downstream(old_setup_task)
+ else:
+ setup_task.set_downstream(old_setup_task)
else:
cls._context_managed_setup_task = None
return old_setup_task
@classmethod
def update_context_map(cls, operator):
- setup_task = SetupTeardownContext.get_context_managed_setup_task()
- teardown_task =
SetupTeardownContext.get_context_managed_teardown_task()
ctx = SetupTeardownContext.context_map
- if setup_task:
+ if setup_task := SetupTeardownContext.get_context_managed_setup_task():
+ if isinstance(setup_task, list):
+ setup_task = tuple(setup_task)
if ctx.get(setup_task) is None:
ctx[setup_task] = [operator]
else:
ctx[setup_task].append(operator)
- if teardown_task:
+ if teardown_task :=
SetupTeardownContext.get_context_managed_teardown_task():
+ if isinstance(teardown_task, list):
+ teardown_task = tuple(teardown_task)
if ctx.get(teardown_task) is None:
ctx[teardown_task] = [operator]
else:
ctx[teardown_task].append(operator)
@classmethod
- def pop_context_managed_teardown_task(cls) -> Operator | None:
+ def pop_context_managed_teardown_task(cls) -> Operator | list[Operator] |
None:
old_teardown_task = cls._context_managed_teardown_task
if cls._previous_context_managed_teardown_task:
cls._context_managed_teardown_task =
cls._previous_context_managed_teardown_task.pop()
- if cls._context_managed_teardown_task and old_teardown_task:
-
cls._context_managed_teardown_task.set_upstream(old_teardown_task)
+ teardown_task = cls._context_managed_teardown_task
+ if teardown_task and old_teardown_task:
+ if isinstance(teardown_task, list):
+ for task in teardown_task:
+ task.set_upstream(old_teardown_task)
+ else:
+ teardown_task.set_upstream(old_teardown_task)
else:
cls._context_managed_teardown_task = None
return old_teardown_task
@classmethod
- def get_context_managed_setup_task(cls) -> Operator | None:
+ def get_context_managed_setup_task(cls) -> Operator | list[Operator] |
None:
return cls._context_managed_setup_task
@classmethod
- def get_context_managed_teardown_task(cls) -> Operator | None:
+ def get_context_managed_teardown_task(cls) -> Operator | list[Operator] |
None:
return cls._context_managed_teardown_task
@classmethod
- def push_setup_teardown_task(cls, operator):
- if operator.is_teardown:
+ def push_setup_teardown_task(cls, operator: Operator | list[Operator]):
+ if isinstance(operator, list):
+ first_task: Operator = operator[0]
+ if first_task.is_teardown:
+ if not all(task.is_teardown == first_task.is_teardown for task
in operator):
+ raise ValueError("All tasks in the list must be either
setup or teardown tasks")
+
SetupTeardownContext.push_context_managed_teardown_task(operator)
+ upstream_setup: list[Operator] = [task for task in
first_task.upstream_list if task.is_setup]
+ if upstream_setup:
+
SetupTeardownContext.push_context_managed_setup_task(upstream_setup)
+ elif first_task.is_setup:
+ if not all(task.is_setup == first_task.is_setup for task in
operator):
+ raise ValueError("All tasks in the list must be either
setup or teardown tasks")
+ SetupTeardownContext.push_context_managed_setup_task(operator)
+ downstream_teardown: list[Operator] = [
+ task for task in first_task.downstream_list if
task.is_teardown
+ ]
+ if downstream_teardown:
+
SetupTeardownContext.push_context_managed_teardown_task(downstream_teardown)
+ elif operator.is_teardown:
SetupTeardownContext.push_context_managed_teardown_task(operator)
upstream_setup = [task for task in operator.upstream_list if
task.is_setup]
if upstream_setup:
-
SetupTeardownContext.push_context_managed_setup_task(upstream_setup[-1])
+
SetupTeardownContext.push_context_managed_setup_task(upstream_setup)
elif operator.is_setup:
SetupTeardownContext.push_context_managed_setup_task(operator)
downstream_teardown = [task for task in operator.downstream_list
if task.is_teardown]
if downstream_teardown:
-
SetupTeardownContext.push_context_managed_teardown_task(downstream_teardown[0])
+
SetupTeardownContext.push_context_managed_teardown_task(downstream_teardown)
SetupTeardownContext.active = True
@classmethod
def set_work_task_roots_and_leaves(cls):
- setup_task = cls.get_context_managed_setup_task()
- teardown_task = cls.get_context_managed_teardown_task()
- if setup_task:
+ if setup_task := cls.get_context_managed_setup_task():
+ if isinstance(setup_task, list):
+ setup_task = tuple(setup_task)
tasks_in_context = cls.context_map.get(setup_task, [])
if tasks_in_context:
roots = [task for task in tasks_in_context if not
task.upstream_list]
if not roots:
- setup_task.set_downstream(tasks_in_context[0])
+ setup_task >> tasks_in_context[0]
+ elif isinstance(setup_task, tuple):
+ for task in setup_task:
+ task >> roots
else:
- setup_task.set_downstream(roots)
- if teardown_task:
+ setup_task >> roots
+ if teardown_task := cls.get_context_managed_teardown_task():
+ if isinstance(teardown_task, list):
+ teardown_task = tuple(teardown_task)
tasks_in_context = cls.context_map.get(teardown_task, [])
if tasks_in_context:
leaves = [task for task in tasks_in_context if not
task.downstream_list]
if not leaves:
- teardown_task.set_upstream(tasks_in_context[-1])
+ teardown_task << tasks_in_context[-1]
+ elif isinstance(teardown_task, tuple):
+ for task in teardown_task:
+ task << leaves
else:
- teardown_task.set_upstream(leaves)
+ teardown_task << leaves
setup_task = SetupTeardownContext.pop_context_managed_setup_task()
teardown_task =
SetupTeardownContext.pop_context_managed_teardown_task()
+ if isinstance(setup_task, list):
+ setup_task = tuple(setup_task)
+ if isinstance(teardown_task, list):
+ teardown_task = tuple(teardown_task)
SetupTeardownContext.active = False
SetupTeardownContext.context_map.pop(setup_task, None)
SetupTeardownContext.context_map.pop(teardown_task, None)
diff --git a/tests/decorators/test_setup_teardown.py
b/tests/decorators/test_setup_teardown.py
index 082b5e7813..05ea076dde 100644
--- a/tests/decorators/test_setup_teardown.py
+++ b/tests/decorators/test_setup_teardown.py
@@ -21,6 +21,7 @@ import pytest
from airflow import AirflowException
from airflow.decorators import setup, task, task_group, teardown
+from airflow.decorators.setup_teardown import context_wrapper
from airflow.operators.bash import BashOperator
@@ -847,3 +848,231 @@ class TestSetupTearDownTask:
"mytask2",
}
assert dag.task_group.children["teardowntask2"].downstream_task_ids ==
{"teardowntask"}
+
+ def test_setup_decorator_context_manager_with_list_on_left(self,
dag_maker):
+ @setup
+ def setuptask():
+ print("setup")
+
+ @setup
+ def setuptask2():
+ print("setup")
+
+ @task()
+ def mytask():
+ print("mytask")
+
+ @teardown
+ def teardowntask():
+ print("teardown")
+
+ with dag_maker() as dag:
+ with [setuptask(), setuptask2()] >> teardowntask():
+ mytask()
+
+ assert len(dag.task_group.children) == 4
+ assert not dag.task_group.children["setuptask"].upstream_task_ids
+ assert not dag.task_group.children["setuptask2"].upstream_task_ids
+ assert dag.task_group.children["setuptask"].downstream_task_ids ==
{"mytask", "teardowntask"}
+ assert dag.task_group.children["setuptask2"].downstream_task_ids ==
{"mytask", "teardowntask"}
+ assert dag.task_group.children["mytask"].upstream_task_ids ==
{"setuptask", "setuptask2"}
+ assert dag.task_group.children["mytask"].downstream_task_ids ==
{"teardowntask"}
+ assert dag.task_group.children["teardowntask"].upstream_task_ids == {
+ "setuptask",
+ "setuptask2",
+ "mytask",
+ }
+
+ def test_setup_decorator_context_manager_with_list_on_right(self,
dag_maker):
+ @setup
+ def setuptask():
+ print("setup")
+
+ @setup
+ def setuptask2():
+ print("setup")
+
+ @task()
+ def mytask():
+ print("mytask")
+
+ @teardown
+ def teardowntask():
+ print("teardown")
+
+ with dag_maker() as dag:
+ with teardowntask() << context_wrapper([setuptask(),
setuptask2()]):
+ mytask()
+
+ assert len(dag.task_group.children) == 4
+ assert not dag.task_group.children["setuptask"].upstream_task_ids
+ assert not dag.task_group.children["setuptask2"].upstream_task_ids
+ assert dag.task_group.children["setuptask"].downstream_task_ids ==
{"mytask", "teardowntask"}
+ assert dag.task_group.children["setuptask2"].downstream_task_ids ==
{"mytask", "teardowntask"}
+ assert dag.task_group.children["mytask"].upstream_task_ids ==
{"setuptask", "setuptask2"}
+ assert dag.task_group.children["mytask"].downstream_task_ids ==
{"teardowntask"}
+ assert dag.task_group.children["teardowntask"].upstream_task_ids == {
+ "setuptask",
+ "setuptask2",
+ "mytask",
+ }
+
+ def test_setup_decorator_context_manager_errors_with_mixed_up_tasks(self,
dag_maker):
+ @setup
+ def setuptask():
+ print("setup")
+
+ @setup
+ def setuptask2():
+ print("setup")
+
+ @task()
+ def mytask():
+ print("mytask")
+
+ @teardown
+ def teardowntask():
+ print("teardown")
+
+ with pytest.raises(ValueError, match="All tasks in the list must be
either setup or teardown tasks"):
+ with dag_maker():
+ with setuptask() << context_wrapper([teardowntask(),
setuptask2()]):
+ mytask()
+
+ def test_teardown_decorator_context_manager_with_list_on_left(self,
dag_maker):
+ @setup
+ def setuptask():
+ print("setup")
+
+ @task()
+ def mytask():
+ print("mytask")
+
+ @teardown
+ def teardowntask():
+ print("teardown")
+
+ @teardown
+ def teardowntask2():
+ print("teardown")
+
+ with dag_maker() as dag:
+ with [teardowntask(), teardowntask2()] << setuptask():
+ mytask()
+
+ assert len(dag.task_group.children) == 4
+ assert not dag.task_group.children["setuptask"].upstream_task_ids
+ assert dag.task_group.children["setuptask"].downstream_task_ids == {
+ "mytask",
+ "teardowntask",
+ "teardowntask2",
+ }
+ assert dag.task_group.children["mytask"].upstream_task_ids ==
{"setuptask"}
+ assert dag.task_group.children["mytask"].downstream_task_ids ==
{"teardowntask", "teardowntask2"}
+ assert dag.task_group.children["teardowntask"].upstream_task_ids == {
+ "setuptask",
+ "mytask",
+ }
+ assert dag.task_group.children["teardowntask2"].upstream_task_ids == {
+ "setuptask",
+ "mytask",
+ }
+
+ def test_teardown_decorator_context_manager_with_list_on_right(self,
dag_maker):
+ @setup
+ def setuptask():
+ print("setup")
+
+ @task()
+ def mytask():
+ print("mytask")
+
+ @teardown
+ def teardowntask():
+ print("teardown")
+
+ @teardown
+ def teardowntask2():
+ print("teardown")
+
+ with dag_maker() as dag:
+ with setuptask() >> context_wrapper([teardowntask(),
teardowntask2()]):
+ mytask()
+
+ assert len(dag.task_group.children) == 4
+ assert not dag.task_group.children["setuptask"].upstream_task_ids
+ assert dag.task_group.children["setuptask"].downstream_task_ids == {
+ "mytask",
+ "teardowntask",
+ "teardowntask2",
+ }
+ assert dag.task_group.children["mytask"].upstream_task_ids ==
{"setuptask"}
+ assert dag.task_group.children["mytask"].downstream_task_ids ==
{"teardowntask", "teardowntask2"}
+ assert dag.task_group.children["teardowntask"].upstream_task_ids == {
+ "setuptask",
+ "mytask",
+ }
+ assert dag.task_group.children["teardowntask2"].upstream_task_ids == {
+ "setuptask",
+ "mytask",
+ }
+
+ def test_classic_operator_context_manager_with_list_on_left(self,
dag_maker):
+ @task()
+ def mytask():
+ print("mytask")
+
+ with dag_maker() as dag:
+ teardowntask = BashOperator.as_teardown(task_id="teardowntask",
bash_command="echo 1")
+ teardowntask2 = BashOperator.as_teardown(task_id="teardowntask2",
bash_command="echo 1")
+ setuptask = BashOperator.as_setup(task_id="setuptask",
bash_command="echo 1")
+ with [teardowntask, teardowntask2] << setuptask:
+ mytask()
+
+ assert len(dag.task_group.children) == 4
+ assert not dag.task_group.children["setuptask"].upstream_task_ids
+ assert dag.task_group.children["setuptask"].downstream_task_ids == {
+ "mytask",
+ "teardowntask",
+ "teardowntask2",
+ }
+ assert dag.task_group.children["mytask"].upstream_task_ids ==
{"setuptask"}
+ assert dag.task_group.children["mytask"].downstream_task_ids ==
{"teardowntask", "teardowntask2"}
+ assert dag.task_group.children["teardowntask"].upstream_task_ids == {
+ "setuptask",
+ "mytask",
+ }
+ assert dag.task_group.children["teardowntask2"].upstream_task_ids == {
+ "setuptask",
+ "mytask",
+ }
+
+ def test_classic_operator_context_manager_with_list_on_right(self,
dag_maker):
+ @task()
+ def mytask():
+ print("mytask")
+
+ with dag_maker() as dag:
+ teardowntask = BashOperator.as_teardown(task_id="teardowntask",
bash_command="echo 1")
+ teardowntask2 = BashOperator.as_teardown(task_id="teardowntask2",
bash_command="echo 1")
+ setuptask = BashOperator.as_setup(task_id="setuptask",
bash_command="echo 1")
+ with setuptask >> context_wrapper([teardowntask, teardowntask2]):
+ mytask()
+
+ assert len(dag.task_group.children) == 4
+ assert not dag.task_group.children["setuptask"].upstream_task_ids
+ assert dag.task_group.children["setuptask"].downstream_task_ids == {
+ "mytask",
+ "teardowntask",
+ "teardowntask2",
+ }
+ assert dag.task_group.children["mytask"].upstream_task_ids ==
{"setuptask"}
+ assert dag.task_group.children["mytask"].downstream_task_ids ==
{"teardowntask", "teardowntask2"}
+ assert dag.task_group.children["teardowntask"].upstream_task_ids == {
+ "setuptask",
+ "mytask",
+ }
+ assert dag.task_group.children["teardowntask2"].upstream_task_ids == {
+ "setuptask",
+ "mytask",
+ }