This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-3-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 7a4869c737e46acb1fe794acef588b7251d2c4a9 Author: Ash Berlin-Taylor <a...@apache.org> AuthorDate: Tue Jul 5 16:40:00 2022 +0100 Fix cycle bug with attaching label to task group (#24847) The problem was specific to EdgeModifiers as they try to be "transparent" to upstream/downstream. The fix is to track the upstream/downstream for the task group before making any changes to the EdgeModifiers' relations -- otherwise the roots of the TG were added as dependencies to themselves! (cherry picked from commit efc05a5f0b3d261293c2efaf6771e4af9a2f324c) --- airflow/utils/task_group.py | 12 ++++++------ tests/utils/test_task_group.py | 25 +++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index ed8d380ff0..64c11f79db 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -283,6 +283,12 @@ class TaskGroup(DAGNode): Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup. Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids. 
""" + if not isinstance(task_or_task_list, Sequence): + task_or_task_list = [task_or_task_list] + + for task_like in task_or_task_list: + self.update_relative(task_like, upstream) + if upstream: for task in self.get_roots(): task.set_upstream(task_or_task_list) @@ -290,12 +296,6 @@ class TaskGroup(DAGNode): for task in self.get_leaves(): task.set_downstream(task_or_task_list) - if not isinstance(task_or_task_list, Sequence): - task_or_task_list = [task_or_task_list] - - for task_like in task_or_task_list: - self.update_relative(task_like, upstream) - def __enter__(self) -> "TaskGroup": TaskGroupContext.push_context_managed_task_group(self) return self diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py index 9aacc96b82..864b2fb68a 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -1222,3 +1222,28 @@ def test_add_to_another_group(): tg.add(task) assert str(ctx.value) == "cannot add 'section_2.task' to 'section_1' (already in group 'section_2')" + + +def test_task_group_edge_modifier_chain(): + from airflow.models.baseoperator import chain + from airflow.utils.edgemodifier import Label + + with DAG(dag_id="test", start_date=pendulum.DateTime(2022, 5, 20)) as dag: + start = EmptyOperator(task_id="sleep_3_seconds") + + with TaskGroup(group_id="group1") as tg: + t1 = EmptyOperator(task_id="dummy1") + t2 = EmptyOperator(task_id="dummy2") + + t3 = EmptyOperator(task_id="echo_done") + + # The case we are testing for is when a Label is inside a list -- meaning that we do tg.set_upstream + # instead of label.set_downstream + chain(start, [Label("branch three")], tg, t3) + + assert start.downstream_task_ids == {t1.node_id, t2.node_id} + assert t3.upstream_task_ids == {t1.node_id, t2.node_id} + assert tg.upstream_task_ids == set() + assert tg.downstream_task_ids == {t3.node_id} + # Check that we can perform a topological_sort + dag.topological_sort()