This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new c8a0f51f89d  Move `task_group_to_dict` functions from task-sdk to server-side & reorganize tests (#54857)
c8a0f51f89d is described below

commit c8a0f51f89dc5c75b86c6888154073c08c2a0671
Author: Kaxil Naik <kaxiln...@apache.org>
AuthorDate: Tue Aug 26 19:11:15 2025 +0100

    Move `task_group_to_dict` functions from task-sdk to server-side & reorganize tests (#54857)

    Follow-up of the fix in https://github.com/apache/airflow/pull/54756

    - Move `task_group_to_dict` and `task_group_to_dict_grid` functions from
      `task-sdk` to `airflow-core` API services
    - Update import paths in `grid.py` and `structure.py` to use new server-side location
    - Update deprecation mappings to point to new FastAPI service module locations
    - Remove `AbstractOperator` from isinstance checks in server-side functions
      (now handles only serialized objects)
    - Split mixed tests: moved TaskGroup functionality tests to `task-sdk`,
      keep server function tests in `airflow-core`
    - Add comprehensive TaskGroup tests in `task-sdk` covering creation,
      relationships, decorators, and validation
    - Create clear architectural boundary: client-side TaskGroup authoring vs
      server-side serialized DAG processing
---
 .pre-commit-config.yaml                            |  56 +-
 airflow-core/newsfragments/54857.significant.rst   |  14 +
 .../airflow/api_fastapi/core_api/routes/ui/grid.py |   8 +-
 .../api_fastapi/core_api/routes/ui/structure.py    |   2 +-
 .../api_fastapi/core_api/services/ui/grid.py       |   3 +-
 .../api_fastapi/core_api/services/ui/task_group.py | 115 +++
 airflow-core/src/airflow/utils/__init__.py         |   2 -
 airflow-core/tests/unit/utils/test_task_group.py   | 704 +++---------------
 task-sdk/src/airflow/sdk/definitions/taskgroup.py  |  99 +--
 .../tests/task_sdk/definitions/test_taskgroup.py   | 789 ++++++++++++++++++++-
 10 files changed, 1036 insertions(+), 756 deletions(-)
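For third-party code that imported these helpers from the task SDK, the change amounts to an import-path move; per the newsfragment below, the helpers are now considered internal to the API layer. A minimal migration sketch for code that nevertheless needs them (the try/except fallback is illustrative and not part of this commit; only the two module paths come from the diff):

    # Hypothetical compatibility shim, assuming code that must span both layouts.
    try:
        # New server-side location introduced by this commit
        from airflow.api_fastapi.core_api.services.ui.task_group import (
            get_task_group_children_getter,
            task_group_to_dict,
            task_group_to_dict_grid,
        )
    except ImportError:
        # Old task-sdk location, removed by this commit
        from airflow.sdk.definitions.taskgroup import (
            get_task_group_children_getter,
            task_group_to_dict,
            task_group_to_dict_grid,
        )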
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 6bc1c7f7b19..40cc880d9d9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1552,17 +1552,17 @@ repos:
           # TODO: These files need to be refactored to remove SDK coupling
           ^airflow-core/src/airflow/__init__\.py$|
-          ^airflow-core/src/airflow/models/__init__\.py$|
           ^airflow-core/src/airflow/api/common/mark_tasks\.py$|
           ^airflow-core/src/airflow/api_fastapi/core_api/datamodels/assets\.py$|
           ^airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections\.py$|
-          ^airflow-core/src/airflow/api_fastapi/core_api/services/public/connections\.py$|
           ^airflow-core/src/airflow/api_fastapi/core_api/datamodels/hitl\.py$|
           ^airflow-core/src/airflow/api_fastapi/core_api/datamodels/variables\.py$|
           ^airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid\.py$|
           ^airflow-core/src/airflow/api_fastapi/core_api/routes/ui/structure\.py$|
+          ^airflow-core/src/airflow/api_fastapi/core_api/services/public/connections\.py$|
           ^airflow-core/src/airflow/api_fastapi/core_api/services/ui/connections\.py$|
           ^airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid\.py$|
+          ^airflow-core/src/airflow/api_fastapi/core_api/services/ui/task_group.py$|
           ^airflow-core/src/airflow/api_fastapi/execution_api/routes/hitl\.py$|
           ^airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances\.py$|
           ^airflow-core/src/airflow/api_fastapi/logging/decorators\.py$|
@@ -1570,6 +1570,7 @@
           ^airflow-core/src/airflow/assets/manager\.py$|
           ^airflow-core/src/airflow/cli/commands/connection_command\.py$|
           ^airflow-core/src/airflow/cli/commands/task_command\.py$|
+          ^airflow-core/src/airflow/cli/commands/triggerer_command.py$|
           ^airflow-core/src/airflow/configuration\.py$|
           ^airflow-core/src/airflow/dag_processing/collection\.py$|
           ^airflow-core/src/airflow/dag_processing/manager\.py$|
@@ -1582,58 +1583,57 @@
           ^airflow-core/src/airflow/listeners/spec/asset\.py$|
           ^airflow-core/src/airflow/listeners/spec/taskinstance\.py$|
           ^airflow-core/src/airflow/logging/remote\.py$|
+          ^airflow-core/src/airflow/models/__init__\.py$|
           ^airflow-core/src/airflow/models/asset\.py$|
           ^airflow-core/src/airflow/models/baseoperator\.py$|
           ^airflow-core/src/airflow/models/connection\.py$|
           ^airflow-core/src/airflow/models/dag\.py$|
-          ^airflow-core/src/airflow/models/deadline\.py$|
           ^airflow-core/src/airflow/models/dagbag\.py$|
           ^airflow-core/src/airflow/models/dagrun\.py$|
+          ^airflow-core/src/airflow/models/deadline\.py$|
+          ^airflow-core/src/airflow/models/expandinput\.py$|
           ^airflow-core/src/airflow/models/mappedoperator\.py$|
           ^airflow-core/src/airflow/models/operator\.py$|
           ^airflow-core/src/airflow/models/param\.py$|
+          ^airflow-core/src/airflow/models/renderedtifields\.py$|
           ^airflow-core/src/airflow/models/serialized_dag\.py$|
           ^airflow-core/src/airflow/models/taskinstance\.py$|
           ^airflow-core/src/airflow/models/taskinstancekey\.py$|
           ^airflow-core/src/airflow/models/taskmap\.py$|
+          ^airflow-core/src/airflow/models/taskmixin\.py$|
           ^airflow-core/src/airflow/models/taskreschedule\.py$|
           ^airflow-core/src/airflow/models/variable\.py$|
+          ^airflow-core/src/airflow/models/xcom\.py$|
+          ^airflow-core/src/airflow/models/xcom_arg\.py$|
           ^airflow-core/src/airflow/operators/subdag\.py$|
+          ^airflow-core/src/airflow/plugins_manager\.py$|
+          ^airflow-core/src/airflow/providers_manager\.py$|
           ^airflow-core/src/airflow/serialization/dag\.py$|
           ^airflow-core/src/airflow/serialization/enums\.py$|
+          ^airflow-core/src/airflow/serialization/helpers\.py$|
           ^airflow-core/src/airflow/serialization/serialized_objects\.py$|
+          ^airflow-core/src/airflow/settings\.py$|
           ^airflow-core/src/airflow/task/task_runner/bash_task_runner\.py$|
           ^airflow-core/src/airflow/task/task_runner/standard_task_runner\.py$|
+          ^airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep\.py$|
+          ^airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep\.py$|
+          ^airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep\.py$|
+          ^airflow-core/src/airflow/timetables/assets\.py$|
+          ^airflow-core/src/airflow/timetables/base\.py$|
+          ^airflow-core/src/airflow/timetables/simple\.py$|
+          ^airflow-core/src/airflow/utils/cli\.py$|
+          ^airflow-core/src/airflow/utils/context\.py$|
           ^airflow-core/src/airflow/utils/dag_cycle_tester\.py$|
+          ^airflow-core/src/airflow/utils/dag_edges\.py$|
           ^airflow-core/src/airflow/utils/dag_parsing_context\.py$|
           ^airflow-core/src/airflow/utils/decorators\.py$|
+          ^airflow-core/src/airflow/utils/dot_renderer\.py$|
+          ^airflow-core/src/airflow/utils/edgemodifier\.py$|
+          ^airflow-core/src/airflow/utils/email\.py$|
+          ^airflow-core/src/airflow/utils/helpers\.py$|
           ^airflow-core/src/airflow/utils/operator_helpers\.py$|
           ^airflow-core/src/airflow/utils/session\.py$|
           ^airflow-core/src/airflow/utils/task_group\.py$|
           ^airflow-core/src/airflow/utils/trigger_rule\.py$|
-          ^airflow-core/src/airflow/utils/xcom\.py$|
-          ^airflow-core/src/airflow/providers_manager\.py$|
-          ^airflow-core/src/airflow/timetables/assets\.py$|
-          ^airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep\.py$|
-          ^airflow-core/src/airflow/utils/context\.py$|
-          ^airflow-core/src/airflow/models/taskmixin\.py$|
-          ^airflow-core/src/airflow/utils/edgemodifier\.py$|
-          ^airflow-core/src/airflow/utils/email\.py$|
-          ^airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep\.py$|
-          ^airflow-core/src/airflow/utils/helpers\.py$|
-          ^airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep\.py$|
-          ^airflow-core/src/airflow/utils/types\.py$|
-          ^airflow-core/src/airflow/utils/dag_edges\.py$|
-          ^airflow-core/src/airflow/utils/cli\.py$|
-          ^airflow-core/src/airflow/timetables/base\.py$|
-          ^airflow-core/src/airflow/utils/dot_renderer\.py$|
-          ^airflow-core/src/airflow/models/xcom_arg\.py$|
-          ^airflow-core/src/airflow/plugins_manager\.py$|
-          ^airflow-core/src/airflow/models/xcom\.py$|
-          ^airflow-core/src/airflow/timetables/simple\.py$|
-          ^airflow-core/src/airflow/settings\.py$|
-          ^airflow-core/src/airflow/models/renderedtifields\.py$|
-          ^airflow-core/src/airflow/serialization/helpers\.py$|
-          ^airflow-core/src/airflow/models/expandinput\.py$|
-          ^airflow-core/src/airflow/cli/commands/triggerer_command.py$
+          ^airflow-core/src/airflow/utils/types\.py$

       ## ONLY ADD PREK HOOKS HERE THAT REQUIRE CI IMAGE
diff --git a/airflow-core/newsfragments/54857.significant.rst b/airflow-core/newsfragments/54857.significant.rst
new file mode 100644
index 00000000000..84ba878db59
--- /dev/null
+++ b/airflow-core/newsfragments/54857.significant.rst
@@ -0,0 +1,14 @@
+Remove ``get_task_group_children_getter`` and ``task_group_to_dict`` from task-sdk
+
+The ``get_task_group_children_getter`` and ``task_group_to_dict`` functions have been removed from the task-sdk (``airflow.sdk.definitions.taskgroup``) and moved to server-side API services. These functions are now internal to Airflow's API layer and should not be imported directly by users.
+
+* Types of change
+
+  * [ ] Dag changes
+  * [ ] Config changes
+  * [ ] API changes
+  * [ ] CLI changes
+  * [ ] Behaviour changes
+  * [ ] Plugin changes
+  * [ ] Dependency changes
+  * [x] Code interface changes
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py
index fd217f3fbdf..0f9f4023cbc 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py
@@ -47,14 +47,14 @@ from airflow.api_fastapi.core_api.services.ui.grid import (
     _find_aggregates,
     _merge_node_dicts,
 )
+from airflow.api_fastapi.core_api.services.ui.task_group import (
+    get_task_group_children_getter,
+    task_group_to_dict_grid,
+)
 from airflow.models.dag_version import DagVersion
 from airflow.models.dagrun import DagRun
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskinstance import TaskInstance
-from airflow.sdk.definitions.taskgroup import (
-    get_task_group_children_getter,
-    task_group_to_dict_grid,
-)

 log = structlog.get_logger(logger_name=__name__)

 grid_router = AirflowRouter(prefix="/grid", tags=["Grid"])
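The route-level edits above and below are pure import moves; the call sites are unchanged. A hedged sketch of how the grid endpoint consumes the two relocated helpers (only the import path and the helper names come from this diff; the surrounding function is illustrative):

    from airflow.api_fastapi.core_api.services.ui.task_group import (
        get_task_group_children_getter,
        task_group_to_dict_grid,
    )

    def build_grid_nodes(serialized_dag):
        # Illustrative only: walk the serialized DAG's task-group tree in the
        # configured sort order and emit the nested node dicts the grid renders.
        sort_children = get_task_group_children_getter()
        return [task_group_to_dict_grid(child) for child in sort_children(serialized_dag.task_group)]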
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/structure.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/structure.py
index c308ae21432..1fba54ce543 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/structure.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/structure.py
@@ -30,9 +30,9 @@ from airflow.api_fastapi.core_api.services.ui.structure import (
     bind_output_assets_to_tasks,
     get_upstream_assets,
 )
+from airflow.api_fastapi.core_api.services.ui.task_group import task_group_to_dict
 from airflow.models.dag_version import DagVersion
 from airflow.models.serialized_dag import SerializedDagModel
-from airflow.sdk.definitions.taskgroup import task_group_to_dict
 from airflow.utils.dag_edges import dag_edges

 structure_router = AirflowRouter(tags=["Structure"], prefix="/structure")
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
index 8b209fb324c..124f526cd05 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
@@ -23,9 +23,10 @@ from collections.abc import Iterable

 import structlog

 from airflow.api_fastapi.common.parameters import state_priority
+from airflow.api_fastapi.core_api.services.ui.task_group import get_task_group_children_getter
 from airflow.models.mappedoperator import MappedOperator
 from airflow.models.taskmap import TaskMap
-from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup, get_task_group_children_getter
+from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
 from airflow.serialization.serialized_objects import SerializedBaseOperator

 log = structlog.get_logger(logger_name=__name__)
+"""Task group utilities for UI API services.""" + +from __future__ import annotations + +from collections.abc import Callable +from functools import cache +from operator import methodcaller + +from airflow.configuration import conf +from airflow.models.mappedoperator import MappedOperator +from airflow.sdk.definitions.taskgroup import MappedTaskGroup +from airflow.serialization.serialized_objects import SerializedBaseOperator + + +@cache +def get_task_group_children_getter() -> Callable: + """Get the Task Group Children Getter for the DAG.""" + sort_order = conf.get("api", "grid_view_sorting_order") + if sort_order == "topological": + return methodcaller("topological_sort") + return methodcaller("hierarchical_alphabetical_sort") + + +def task_group_to_dict(task_item_or_group, parent_group_is_mapped=False): + """Create a nested dict representation of this TaskGroup and its children used to construct the Graph.""" + if isinstance(task := task_item_or_group, (SerializedBaseOperator, MappedOperator)): + node_operator = { + "id": task.task_id, + "label": task.label, + "operator": task.operator_name, + "type": "task", + } + if task.is_setup: + node_operator["setup_teardown_type"] = "setup" + elif task.is_teardown: + node_operator["setup_teardown_type"] = "teardown" + if isinstance(task, MappedOperator) or parent_group_is_mapped: + node_operator["is_mapped"] = True + return node_operator + + task_group = task_item_or_group + is_mapped = isinstance(task_group, MappedTaskGroup) + children = [ + task_group_to_dict(child, parent_group_is_mapped=parent_group_is_mapped or is_mapped) + for child in get_task_group_children_getter()(task_group) + ] + + if task_group.upstream_group_ids or task_group.upstream_task_ids: + # This is the join node used to reduce the number of edges between two TaskGroup. + children.append({"id": task_group.upstream_join_id, "label": "", "type": "join"}) + + if task_group.downstream_group_ids or task_group.downstream_task_ids: + # This is the join node used to reduce the number of edges between two TaskGroup. 
+ children.append({"id": task_group.downstream_join_id, "label": "", "type": "join"}) + + return { + "id": task_group.group_id, + "label": task_group.label, + "tooltip": task_group.tooltip, + "is_mapped": is_mapped, + "children": children, + "type": "task", + } + + +def task_group_to_dict_grid(task_item_or_group, parent_group_is_mapped=False): + """Create a nested dict representation of this TaskGroup and its children used to construct the Grid.""" + if isinstance(task := task_item_or_group, (MappedOperator, SerializedBaseOperator)): + is_mapped = None + if task.is_mapped or parent_group_is_mapped: + is_mapped = True + setup_teardown_type = None + if task.is_setup is True: + setup_teardown_type = "setup" + elif task.is_teardown is True: + setup_teardown_type = "teardown" + return { + "id": task.task_id, + "label": task.label, + "is_mapped": is_mapped, + "children": None, + "setup_teardown_type": setup_teardown_type, + } + + task_group = task_item_or_group + task_group_sort = get_task_group_children_getter() + is_mapped_group = isinstance(task_group, MappedTaskGroup) + children = [ + task_group_to_dict_grid(x, parent_group_is_mapped=parent_group_is_mapped or is_mapped_group) + for x in task_group_sort(task_group) + ] + + return { + "id": task_group.group_id, + "label": task_group.label, + "is_mapped": is_mapped_group or None, + "children": children or None, + } diff --git a/airflow-core/src/airflow/utils/__init__.py b/airflow-core/src/airflow/utils/__init__.py index 84e999a7c83..4b0ad3419a1 100644 --- a/airflow-core/src/airflow/utils/__init__.py +++ b/airflow-core/src/airflow/utils/__init__.py @@ -30,8 +30,6 @@ __deprecated_classes = { }, "task_group": { "TaskGroup": "airflow.sdk.TaskGroup", - "get_task_group_children_getter": "airflow.sdk.definitions.taskgroup.get_task_group_children_getter", - "task_group_to_dict": "airflow.sdk.definitions.taskgroup.task_group_to_dict", }, "timezone": { # Since we have corrected all uses inside core to use the internal version, anything hitting this diff --git a/airflow-core/tests/unit/utils/test_task_group.py b/airflow-core/tests/unit/utils/test_task_group.py index 7af3b3326f4..447578329aa 100644 --- a/airflow-core/tests/unit/utils/test_task_group.py +++ b/airflow-core/tests/unit/utils/test_task_group.py @@ -17,39 +17,27 @@ # under the License. 
diff --git a/airflow-core/tests/unit/utils/test_task_group.py b/airflow-core/tests/unit/utils/test_task_group.py
index 7af3b3326f4..447578329aa 100644
--- a/airflow-core/tests/unit/utils/test_task_group.py
+++ b/airflow-core/tests/unit/utils/test_task_group.py
@@ -17,39 +17,27 @@
 # under the License.
 from __future__ import annotations

-from datetime import timedelta
-
 import pendulum
 import pytest

-from airflow.exceptions import TaskAlreadyInTaskGroup
+from airflow.api_fastapi.core_api.services.ui.task_group import task_group_to_dict
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.dag import DAG
-from airflow.models.xcom_arg import XComArg
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.sdk import (
-    dag,
     setup,
     task as task_decorator,
     task_group as task_group_decorator,
     teardown,
 )
-from airflow.sdk.definitions.taskgroup import TaskGroup, task_group_to_dict
+from airflow.sdk.definitions.taskgroup import TaskGroup
+from airflow.serialization.serialized_objects import SerializedDAG
 from airflow.utils.dag_edges import dag_edges

 from tests_common.test_utils.compat import BashOperator, PythonOperator
 from unit.models import DEFAULT_DATE

-
-def make_task(name, type_="classic"):
-    if type_ == "classic":
-        return BashOperator(task_id=name, bash_command="echo 1")
-
-    @task_decorator
-    def my_task():
-        pass
-
-    return my_task.override(task_id=name)()
+pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag]


 EXPECTED_JSON_LEGACY = {
@@ -167,24 +155,11 @@
 }

 EXPECTED_JSON = {
-    "id": None,
-    "label": None,
-    "tooltip": "",
-    "is_mapped": False,
     "children": [
         {"id": "task1", "label": "task1", "operator": "EmptyOperator", "type": "task"},
         {
-            "id": "group234",
-            "label": "group234",
-            "tooltip": "",
-            "is_mapped": False,
             "children": [
-                {"id": "group234.task2", "label": "task2", "operator": "EmptyOperator", "type": "task"},
                 {
-                    "id": "group234.group34",
-                    "label": "group34",
-                    "tooltip": "",
-                    "is_mapped": False,
                     "children": [
                         {
                             "id": "group234.group34.task3",
@@ -200,21 +175,39 @@
                         },
                         {"id": "group234.group34.downstream_join_id", "label": "", "type": "join"},
                     ],
+                    "id": "group234.group34",
+                    "is_mapped": False,
+                    "label": "group34",
+                    "tooltip": "",
+                    "type": "task",
+                },
+                {
+                    "id": "group234.task2",
+                    "label": "task2",
+                    "operator": "EmptyOperator",
                     "type": "task",
                 },
                 {"id": "group234.upstream_join_id", "label": "", "type": "join"},
             ],
+            "id": "group234",
+            "is_mapped": False,
+            "label": "group234",
+            "tooltip": "",
             "type": "task",
         },
         {"id": "task5", "label": "task5", "operator": "EmptyOperator", "type": "task"},
     ],
+    "id": None,
+    "is_mapped": False,
+    "label": None,
+    "tooltip": "",
     "type": "task",
 }


-def test_build_task_group_context_manager():
+def test_task_group_to_dict_serialized_dag(dag_maker):
     logical_date = pendulum.parse("20200101")
-    with DAG("test_build_task_group_context_manager", schedule=None, start_date=logical_date) as dag:
+    with dag_maker("test_task_group_to_dict", schedule=None, start_date=logical_date) as dag:
         task1 = EmptyOperator(task_id="task1")
         with TaskGroup("group234") as group234:
             _ = EmptyOperator(task_id="task2")
@@ -227,31 +220,12 @@
         task1 >> group234
         group34 >> task5

-    assert task1.get_direct_relative_ids(upstream=False) == {
-        "group234.group34.task4",
-        "group234.group34.task3",
-        "group234.task2",
-    }
-    assert task5.get_direct_relative_ids(upstream=True) == {
-        "group234.group34.task4",
-        "group234.group34.task3",
-    }
-
-    assert dag.task_group.group_id is None
-    assert dag.task_group.is_root
-    assert set(dag.task_group.children.keys()) == {"task1", "group234", "task5"}
-    assert group34.group_id == "group234.group34"
-
     assert task_group_to_dict(dag.task_group) == EXPECTED_JSON


-def test_build_task_group():
-    """
-    This is an alternative syntax to use TaskGroup. It should result in the same TaskGroup
-    as using context manager.
-    """
+def test_task_group_to_dict_alternative_syntax():
     logical_date = pendulum.parse("20200101")
-    dag = DAG("test_build_task_group", schedule=None, start_date=logical_date)
+    dag = DAG("test_task_group_to_dict_alt", schedule=None, start_date=logical_date)
     task1 = EmptyOperator(task_id="task1", dag=dag)
     group234 = TaskGroup("group234", dag=dag)
     _ = EmptyOperator(task_id="task2", dag=dag, task_group=group234)
@@ -263,7 +237,9 @@
     task1 >> group234
     group34 >> task5

-    assert task_group_to_dict(dag.task_group) == EXPECTED_JSON
+    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+    assert task_group_to_dict(serialized_dag.task_group) == EXPECTED_JSON


 def extract_node_id(node, include_label=False):
@@ -280,73 +256,57 @@
     return ret


-def test_build_task_group_with_prefix():
-    """
-    Tests that prefix_group_id turns on/off prefixing of task_id with group_id.
-    """
+def test_task_group_to_dict_with_prefix(dag_maker):
     logical_date = pendulum.parse("20200101")
-    with DAG("test_build_task_group_with_prefix", schedule=None, start_date=logical_date) as dag:
+    with dag_maker("test_task_group_to_dict_prefix", start_date=logical_date) as dag:
         task1 = EmptyOperator(task_id="task1")
         with TaskGroup("group234", prefix_group_id=False) as group234:
-            task2 = EmptyOperator(task_id="task2")
+            EmptyOperator(task_id="task2")

             with TaskGroup("group34") as group34:
-                task3 = EmptyOperator(task_id="task3")
+                EmptyOperator(task_id="task3")

-                with TaskGroup("group4", prefix_group_id=False) as group4:
-                    task4 = EmptyOperator(task_id="task4")
+                with TaskGroup("group4", prefix_group_id=False):
+                    EmptyOperator(task_id="task4")

         task5 = EmptyOperator(task_id="task5")
         task1 >> group234
         group34 >> task5

-    assert task2.task_id == "task2"
-    assert group34.group_id == "group34"
-    assert task3.task_id == "group34.task3"
-    assert group4.group_id == "group34.group4"
-    assert task4.task_id == "task4"
-    assert task5.task_id == "task5"
-    assert group234.get_child_by_label("task2") == task2
-    assert group234.get_child_by_label("group34") == group34
-    assert group4.get_child_by_label("task4") == task4
-
     expected_node_id = {
-        "id": None,
-        "label": None,
         "children": [
             {"id": "task1", "label": "task1"},
             {
-                "id": "group234",
-                "label": "group234",
                 "children": [
-                    {"id": "task2", "label": "task2"},
                     {
-                        "id": "group34",
-                        "label": "group34",
                         "children": [
-                            {"id": "group34.task3", "label": "task3"},
                             {
+                                "children": [{"id": "task4", "label": "task4"}],
                                 "id": "group34.group4",
                                 "label": "group4",
-                                "children": [{"id": "task4", "label": "task4"}],
                             },
+                            {"id": "group34.task3", "label": "task3"},
                             {"id": "group34.downstream_join_id", "label": ""},
                         ],
+                        "id": "group34",
+                        "label": "group34",
                     },
+                    {"id": "task2", "label": "task2"},
                     {"id": "group234.upstream_join_id", "label": ""},
                 ],
+                "id": "group234",
+                "label": "group234",
             },
             {"id": "task5", "label": "task5"},
         ],
+        "id": None,
+        "label": None,
     }

     assert extract_node_id(task_group_to_dict(dag.task_group), include_label=True) == expected_node_id


-def test_build_task_group_with_task_decorator():
-    """
-    Test that TaskGroup can be used with the @task decorator.
-    """
+def test_task_group_to_dict_with_task_decorator(dag_maker):
     from airflow.sdk import task

     @task
     def task_1():
         print("task_1")

     @task
     def task_2():
         return "task_2"

     @task
     def task_3():
         return "task_3"

     @task
     def task_4(task_2_output, task_3_output):
         print(task_2_output, task_3_output)

     @task
     def task_5():
         print("task_5")

     logical_date = pendulum.parse("20200101")
-    with DAG("test_build_task_group_with_task_decorator", schedule=None, start_date=logical_date) as dag:
+    with dag_maker("test_build_task_group_with_task_decorator", start_date=logical_date) as dag:
         tsk_1 = task_1()

         with TaskGroup("group234") as group234:
             tsk_2 = task_2()
             tsk_3 = task_3()
-            tsk_4 = task_4(tsk_2, tsk_3)
+            task_4(tsk_2, tsk_3)

         tsk_5 = task_5()

         tsk_1 >> group234 >> tsk_5

-    assert tsk_1.operator in tsk_2.operator.upstream_list
-    assert tsk_1.operator in tsk_3.operator.upstream_list
-    assert tsk_5.operator in tsk_4.operator.downstream_list
-
     expected_node_id = {
         "id": None,
         "children": [
@@ -418,12 +374,9 @@
     ]


-def test_sub_dag_task_group():
-    """
-    Tests dag.partial_subset() updates task_group correctly.
-    """
+def test_task_group_to_dict_sub_dag(dag_maker):
     logical_date = pendulum.parse("20200101")
-    with DAG("test_test_task_group_sub_dag", schedule=None, start_date=logical_date) as dag:
+    with dag_maker("test_test_task_group_sub_dag", schedule=None, start_date=logical_date) as dag:
         task1 = EmptyOperator(task_id="task1")
         with TaskGroup("group234") as group234:
             _ = EmptyOperator(task_id="task2")
@@ -479,26 +432,10 @@
         ("task1", "group234.upstream_join_id"),
     ]

-    groups = subset.task_group.get_task_group_dict()
-    assert groups.keys() == {None, "group234", "group234.group34"}
-
-    included_group_ids = {"group234", "group234.group34"}
-    included_task_ids = {"group234.group34.task3", "group234.group34.task4", "task1", "task5"}
-    for task_group in groups.values():
-        assert task_group.upstream_group_ids.issubset(included_group_ids)
-        assert task_group.downstream_group_ids.issubset(included_group_ids)
-        assert task_group.upstream_task_ids.issubset(included_task_ids)
-        assert task_group.downstream_task_ids.issubset(included_task_ids)
-
-    for task in subset.task_group:
-        assert task.upstream_task_ids.issubset(included_task_ids)
-        assert task.downstream_task_ids.issubset(included_task_ids)
-

-def test_dag_edges():
+def test_task_group_to_dict_and_dag_edges(dag_maker):
     logical_date = pendulum.parse("20200101")
-    with DAG("test_dag_edges", schedule=None, start_date=logical_date) as dag:
+    with dag_maker("test_dag_edges", schedule=None, start_date=logical_date) as dag:
         task1 = EmptyOperator(task_id="task1")
         with TaskGroup("group_a") as group_a:
             with TaskGroup("group_b") as group_b:
@@ -540,6 +477,14 @@
     expected_node_id = {
         "id": None,
         "children": [
+            {
+                "id": "group_d",
+                "children": [
+                    {"id": "group_d.task11"},
+                    {"id": "group_d.task12"},
+                    {"id": "group_d.upstream_join_id"},
+                ],
+            },
             {"id": "task1"},
             {
                 "id": "group_a",
@@ -568,16 +513,8 @@
                     {"id": "group_c.downstream_join_id"},
                 ],
             },
-            {"id": "task9"},
             {"id": "task10"},
-            {
-                "id": "group_d",
-                "children": [
-                    {"id": "group_d.task11"},
-                    {"id": "group_d.task12"},
-                    {"id": "group_d.upstream_join_id"},
-                ],
-            },
+            {"id": "task9"},
         ],
     }

@@ -607,9 +544,9 @@
     ]


-def test_dag_edges_setup_teardown():
+def test_dag_edges_setup_teardown(dag_maker):
     logical_date = pendulum.parse("20200101")
-    with DAG("test_dag_edges", schedule=None, start_date=logical_date) as dag:
+    with dag_maker("test_dag_edges", schedule=None, start_date=logical_date) as dag:
         setup1 = EmptyOperator(task_id="setup1").as_setup()
         teardown1 = EmptyOperator(task_id="teardown1").as_teardown()
@@ -635,14 +572,13 @@
     ]


-def test_dag_edges_setup_teardown_nested():
-    from airflow.models.dag import DAG
+def test_dag_edges_setup_teardown_nested(dag_maker):
     from airflow.providers.standard.operators.empty import EmptyOperator
     from airflow.sdk import task, task_group

     logical_date = pendulum.parse("20200101")

-    with DAG(dag_id="s_t_dag", schedule=None, start_date=logical_date) as dag:
+    with dag_maker(dag_id="s_t_dag", schedule=None, start_date=logical_date) as dag:

         @task
         def test_task():
@@ -682,61 +618,10 @@
     ]


-def test_duplicate_group_id():
-    from airflow.exceptions import DuplicateTaskIdFound
-
-    logical_date = pendulum.parse("20200101")
-
-    with DAG("test_duplicate_group_id", schedule=None, start_date=logical_date):
-        _ = EmptyOperator(task_id="task1")
-        with pytest.raises(DuplicateTaskIdFound, match=r".* 'task1' .*"), TaskGroup("task1"):
-            pass
-
-    with DAG("test_duplicate_group_id", schedule=None, start_date=logical_date):
-        _ = EmptyOperator(task_id="task1")
-        with TaskGroup("group1", prefix_group_id=False):
-            with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1' .*"), TaskGroup("group1"):
-                pass
-
-    with DAG("test_duplicate_group_id", schedule=None, start_date=logical_date):
-        with TaskGroup("group1", prefix_group_id=False):
-            with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1' .*"):
-                _ = EmptyOperator(task_id="group1")
-
-    with DAG("test_duplicate_group_id", schedule=None, start_date=logical_date):
-        _ = EmptyOperator(task_id="task1")
-        with TaskGroup("group1"):
-            with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1.downstream_join_id' .*"):
-                _ = EmptyOperator(task_id="downstream_join_id")
-
-    with DAG("test_duplicate_group_id", schedule=None, start_date=logical_date):
-        _ = EmptyOperator(task_id="task1")
-        with TaskGroup("group1"):
-            with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1.upstream_join_id' .*"):
-                _ = EmptyOperator(task_id="upstream_join_id")
-
-
-def test_task_without_dag():
-    """
-    Test that if a task doesn't have a DAG when it's being set as the relative of another task which
-    has a DAG, the task should be added to the root TaskGroup of the other task's DAG.
-    """
-    dag = DAG(dag_id="test_task_without_dag", schedule=None, start_date=pendulum.parse("20200101"))
-    op1 = EmptyOperator(task_id="op1", dag=dag)
-    op2 = EmptyOperator(task_id="op2")
-    op3 = EmptyOperator(task_id="op3")
-    op1 >> op2
-    op3 >> op2
-
-    assert op1.dag == op2.dag == op3.dag
-    assert dag.task_group.children.keys() == {"op1", "op2", "op3"}
-    assert dag.task_group.children.keys() == dag.task_dict.keys()
-
-
 # taskgroup decorator tests


-def test_build_task_group_deco_context_manager():
+def test_build_task_group_deco_context_manager(dag_maker):
     """
     Tests Following :
     1. Nested TaskGroup creation using taskgroup decorator should create same TaskGroup which can be
@@ -791,7 +676,7 @@
         return section_2(op1)

     logical_date = pendulum.parse("20201109")
-    with DAG(
+    with dag_maker(
         dag_id="example_nested_task_group_decorator",
         schedule=None,
         start_date=logical_date,
@@ -818,12 +703,9 @@
     node_ids = {
         "id": None,
         "children": [
-            {"id": "task_start"},
             {
                 "id": "section_1",
                 "children": [
-                    {"id": "section_1.task_1"},
-                    {"id": "section_1.task_2"},
                     {
                         "id": "section_1.section_2",
                         "children": [
@@ -831,110 +713,19 @@
                             {"id": "section_1.section_2.task_4"},
                         ],
                     },
+                    {"id": "section_1.task_1"},
+                    {"id": "section_1.task_2"},
                 ],
             },
             {"id": "task_end"},
+            {"id": "task_start"},
         ],
     }

     assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids


-def test_build_task_group_depended_by_task():
-    """A decorator-based task group should be able to be used as a relative to operators."""
-    from airflow.sdk import dag as dag_decorator, task
-
-    @dag_decorator(schedule=None, start_date=pendulum.now())
-    def build_task_group_depended_by_task():
-        @task
-        def task_start():
-            return "[Task_start]"
-
-        @task
-        def task_end():
-            return "[Task_end]"
-
-        @task
-        def task_thing(value):
-            return f"[Task_thing {value}]"
-
-        @task_group_decorator
-        def section_1():
-            task_thing(1)
-            task_thing(2)
-
-        task_start() >> section_1() >> task_end()
-
-    dag = build_task_group_depended_by_task()
-    task_thing_1 = dag.task_dict["section_1.task_thing"]
-    task_thing_2 = dag.task_dict["section_1.task_thing__1"]
-
-    # Tasks in the task group don't depend on each other; they both become
-    # downstreams to task_start, and upstreams to task_end.
-    assert task_thing_1.upstream_task_ids == task_thing_2.upstream_task_ids == {"task_start"}
-    assert task_thing_1.downstream_task_ids == task_thing_2.downstream_task_ids == {"task_end"}
-
-
-def test_build_task_group_with_operators():
-    """Tests DAG with Tasks created with *Operators and TaskGroup created with taskgroup decorator"""
-    from airflow.sdk import task
-
-    def task_start():
-        """Dummy Task which is First Task of Dag"""
-        return "[Task_start]"
-
-    def task_end():
-        """Dummy Task which is Last Task of Dag"""
-        print("[ Task_End ]")
-
-    # Creating Tasks
-    @task
-    def task_1(value):
-        """Dummy Task1"""
-        return f"[ Task1 {value} ]"
-
-    @task
-    def task_2(value):
-        """Dummy Task2"""
-        return f"[ Task2 {value} ]"
-
-    @task
-    def task_3(value):
-        """Dummy Task3"""
-        print(f"[ Task3 {value} ]")
-
-    # Creating TaskGroups
-    @task_group_decorator(group_id="section_1")
-    def section_a(value):
-        """TaskGroup for grouping related Tasks"""
-        return task_3(task_2(task_1(value)))
-
-    logical_date = pendulum.parse("20201109")
-    with DAG(
-        dag_id="example_task_group_decorator_mix",
-        schedule=None,
-        start_date=logical_date,
-        tags=["example"],
-    ) as dag:
-        t_start = PythonOperator(task_id="task_start", python_callable=task_start, dag=dag)
-        sec_1 = section_a(t_start.output)
-        t_end = PythonOperator(task_id="task_end", python_callable=task_end, dag=dag)
-        sec_1.set_downstream(t_end)
-
-    # Testing Tasks in DAG
-    assert set(dag.task_group.children.keys()) == {"section_1", "task_start", "task_end"}
-    assert set(dag.task_group.children["section_1"].children.keys()) == {
-        "section_1.task_2",
-        "section_1.task_3",
-        "section_1.task_1",
-    }
-
-    # Testing Tasks downstream
-    assert dag.task_dict["task_start"].downstream_task_ids == {"section_1.task_1"}
-    assert dag.task_dict["section_1.task_3"].downstream_task_ids == {"task_end"}
-
-
-def test_task_group_context_mix():
+def test_task_group_context_mix(dag_maker):
     """Test cases to check nested TaskGroup context manager with taskgroup decorator"""
     from airflow.sdk import task
@@ -969,13 +760,13 @@
         return task_3(task_2(task_1(value)))

     logical_date = pendulum.parse("20201109")
-    with DAG(
+    with dag_maker(
         dag_id="example_task_group_decorator_mix",
         schedule=None,
         start_date=logical_date,
         tags=["example"],
     ) as dag:
-        t_start = PythonOperator(task_id="task_start", python_callable=task_start, dag=dag)
+        t_start = PythonOperator(task_id="task_start", python_callable=task_start)

         with TaskGroup("section_1", tooltip="section_1") as section_1:
             sec_2 = section_2(t_start.output)
@@ -986,7 +777,7 @@
             sec_2.set_downstream(task_s1)
             task_s1 >> [task_s2, task_s3]

-        t_end = PythonOperator(task_id="task_end", python_callable=task_end, dag=dag)
+        t_end = PythonOperator(task_id="task_end", python_callable=task_end)
         t_start >> section_1 >> t_end

     node_ids = {
@@ -1018,26 +809,7 @@
     assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids


-def test_default_args():
-    """Testing TaskGroup with default_args"""
-    logical_date = pendulum.parse("20201109")
-    with DAG(
-        dag_id="example_task_group_default_args",
-        schedule=None,
-        start_date=logical_date,
-        default_args={"owner": "dag"},
-    ):
-        with TaskGroup("group1", default_args={"owner": "group"}):
-            task_1 = EmptyOperator(task_id="task_1")
-            task_2 = EmptyOperator(task_id="task_2", owner="task")
-            task_3 = EmptyOperator(task_id="task_3", default_args={"owner": "task"})
-
-            assert task_1.owner == "group"
-            assert task_2.owner == "task"
-            assert task_3.owner == "task"
-
-
-def test_duplicate_task_group_id():
+def test_duplicate_task_group_id(dag_maker):
     """Testing automatic suffix assignment for duplicate group_id"""
     from airflow.sdk import task
@@ -1082,7 +854,7 @@
         task_end()

     logical_date = pendulum.parse("20201109")
-    with DAG(
+    with dag_maker(
         dag_id="example_duplicate_task_group_id",
         schedule=None,
         start_date=logical_date,
@@ -1110,8 +882,7 @@
     assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids


-def test_call_taskgroup_twice():
-    """Test for using same taskgroup decorated function twice"""
+def test_call_taskgroup_twice(dag_maker):
     from airflow.sdk import task

     @task(task_id="start_task")
@@ -1138,7 +909,7 @@
         task_end()

     logical_date = pendulum.parse("20201109")
-    with DAG(
+    with dag_maker(
         dag_id="example_multi_call_task_groups",
         schedule=None,
         start_date=logical_date,
@@ -1153,17 +924,17 @@
             {
                 "id": "task_group1",
                 "children": [
+                    {"id": "task_group1.end_task"},
                     {"id": "task_group1.start_task"},
                     {"id": "task_group1.task"},
-                    {"id": "task_group1.end_task"},
                 ],
             },
             {
                 "id": "task_group1__1",
                 "children": [
+                    {"id": "task_group1__1.end_task"},
                     {"id": "task_group1__1.start_task"},
                     {"id": "task_group1__1.task"},
-                    {"id": "task_group1__1.end_task"},
                 ],
             },
         ],
@@ -1172,72 +943,6 @@
     assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids


-def test_pass_taskgroup_output_to_task():
-    """Test that the output of a task group can be passed to a task."""
-    from airflow.sdk import task
-
-    @task
-    def one():
-        return 1
-
-    @task_group_decorator
-    def addition_task_group(num):
-        @task
-        def add_one(i):
-            return i + 1
-
-        return add_one(num)
-
-    @task
-    def increment(num):
-        return num + 1
-
-    @dag(schedule=None, start_date=pendulum.DateTime(2022, 1, 1), default_args={"owner": "airflow"})
-    def wrap():
-        total_1 = one()
-        assert isinstance(total_1, XComArg)
-        total_2 = addition_task_group(total_1)
-        assert isinstance(total_2, XComArg)
-        total_3 = increment(total_2)
-        assert isinstance(total_3, XComArg)
-
-    wrap()
-
-
-def test_decorator_unknown_args():
-    """Test that unknown args passed to the decorator cause an error at parse time"""
-    with pytest.raises(TypeError):
-
-        @task_group_decorator(b=2)
-        def tg(): ...
-
-
-def test_decorator_multiple_use_task():
-    from airflow.sdk import task
-
-    @dag("test-dag", schedule=None, start_date=DEFAULT_DATE)
-    def _test_dag():
-        @task
-        def t():
-            pass
-
-        @task_group_decorator
-        def tg():
-            for _ in range(3):
-                t()
-
-        t() >> tg() >> t()
-
-    test_dag = _test_dag()
-    assert test_dag.task_ids == [
-        "t",  # Start end.
-        "tg.t",
-        "tg.t__1",
-        "tg.t__2",
-        "t__1",  # End node.
-    ]
-
-
 def test_topological_sort1():
     dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE, default_args={"owner": "owner1"})

@@ -1411,197 +1116,8 @@
     ]


-def test_add_to_sub_group():
-    with DAG("test_dag", schedule=None, start_date=pendulum.parse("20200101")):
-        tg = TaskGroup("section")
-        task = EmptyOperator(task_id="task")
-        with pytest.raises(TaskAlreadyInTaskGroup) as ctx:
-            tg.add(task)
-
-    assert str(ctx.value) == "cannot add 'task' to 'section' (already in the DAG's root group)"
-
-
-def test_add_to_another_group():
-    with DAG("test_dag", schedule=None, start_date=pendulum.parse("20200101")):
-        tg = TaskGroup("section_1")
-        with TaskGroup("section_2"):
-            task = EmptyOperator(task_id="task")
-        with pytest.raises(TaskAlreadyInTaskGroup) as ctx:
-            tg.add(task)
-
-    assert str(ctx.value) == "cannot add 'section_2.task' to 'section_1' (already in group 'section_2')"
-
-
-def test_task_group_edge_modifier_chain():
-    from airflow.sdk import Label, chain
-
-    with DAG(dag_id="test", schedule=None, start_date=pendulum.DateTime(2022, 5, 20)) as dag:
-        start = EmptyOperator(task_id="sleep_3_seconds")
-
-        with TaskGroup(group_id="group1") as tg:
-            t1 = EmptyOperator(task_id="dummy1")
-            t2 = EmptyOperator(task_id="dummy2")
-
-        t3 = EmptyOperator(task_id="echo_done")
-
-        # The case we are testing for is when a Label is inside a list -- meaning that we do tg.set_upstream
-        # instead of label.set_downstream
-        chain(start, [Label("branch three")], tg, t3)
-
-    assert start.downstream_task_ids == {t1.node_id, t2.node_id}
-    assert t3.upstream_task_ids == {t1.node_id, t2.node_id}
-    assert tg.upstream_task_ids == set()
-    assert tg.downstream_task_ids == {t3.node_id}
-    # Check that we can perform a topological_sort
-    dag.topological_sort()
-
-
-def test_mapped_task_group_id_prefix_task_id():
-    from tests_common.test_utils.mock_operators import MockOperator
-
-    with DAG(dag_id="d", schedule=None, start_date=DEFAULT_DATE) as dag:
-        t1 = MockOperator.partial(task_id="t1").expand(arg1=[])
-        with TaskGroup("g"):
-            t2 = MockOperator.partial(task_id="t2").expand(arg1=[])
-
-    assert t1.task_id == "t1"
-    assert t2.task_id == "g.t2"
-
-    dag.get_task("t1") == t1
-    dag.get_task("g.t2") == t2
-
-
-def test_iter_tasks():
-    with DAG("test_dag", schedule=None, start_date=pendulum.parse("20200101")) as dag:
-        with TaskGroup("section_1") as tg1:
-            EmptyOperator(task_id="task1")
-
-        with TaskGroup("section_2") as tg2:
-            task2 = EmptyOperator(task_id="task2")
-            task3 = EmptyOperator(task_id="task3")
-            mapped_bash_operator = BashOperator.partial(task_id="bash_task").expand(
-                bash_command=[
-                    "echo hello 1",
-                    "echo hello 2",
-                    "echo hello 3",
-                ]
-            )
-            task2 >> task3 >> mapped_bash_operator
-
-        tg1 >> tg2
-
-    root_group = dag.task_group
-    assert [t.task_id for t in root_group.iter_tasks()] == [
-        "section_1.task1",
-        "section_2.task2",
-        "section_2.task3",
-        "section_2.bash_task",
-    ]
-    assert [t.task_id for t in tg1.iter_tasks()] == [
-        "section_1.task1",
-    ]
-    assert [t.task_id for t in tg2.iter_tasks()] == [
-        "section_2.task2",
-        "section_2.task3",
-        "section_2.bash_task",
-    ]
-
-
-def test_override_dag_default_args():
-    with DAG(
-        dag_id="test_dag",
-        schedule=None,
-        start_date=pendulum.parse("20200101"),
-        default_args={
-            "retries": 1,
-            "owner": "x",
-        },
-    ):
-        with TaskGroup(
-            group_id="task_group",
-            default_args={
-                "owner": "y",
-                "execution_timeout": timedelta(seconds=10),
-            },
-        ):
-            task = EmptyOperator(task_id="task")
-
-    assert task.retries == 1
-    assert task.owner == "y"
-    assert task.execution_timeout == timedelta(seconds=10)
-
-
-def test_override_dag_default_args_in_nested_tg():
-    with DAG(
-        dag_id="test_dag",
-        schedule=None,
-        start_date=pendulum.parse("20200101"),
-        default_args={
-            "retries": 1,
-            "owner": "x",
-        },
-    ):
-        with TaskGroup(
-            group_id="task_group",
-            default_args={
-                "owner": "y",
-                "execution_timeout": timedelta(seconds=10),
-            },
-        ):
-            with TaskGroup(group_id="nested_task_group"):
-                task = EmptyOperator(task_id="task")
-
-    assert task.retries == 1
-    assert task.owner == "y"
-    assert task.execution_timeout == timedelta(seconds=10)
-
-
-def test_override_dag_default_args_in_multi_level_nested_tg():
-    with DAG(
-        dag_id="test_dag",
-        schedule=None,
-        start_date=pendulum.parse("20200101"),
-        default_args={
-            "retries": 1,
-            "owner": "x",
-        },
-    ):
-        with TaskGroup(
-            group_id="task_group",
-            default_args={
-                "owner": "y",
-                "execution_timeout": timedelta(seconds=10),
-            },
-        ):
-            with TaskGroup(
-                group_id="first_nested_task_group",
-                default_args={
-                    "owner": "z",
-                },
-            ):
-                with TaskGroup(group_id="second_nested_task_group"):
-                    with TaskGroup(group_id="third_nested_task_group"):
-                        task = EmptyOperator(task_id="task")
-
-    assert task.retries == 1
-    assert task.owner == "z"
-    assert task.execution_timeout == timedelta(seconds=10)
-
-
-def test_task_group_arrow_with_setups_teardowns():
-    with DAG(dag_id="hi", schedule=None, start_date=pendulum.datetime(2022, 1, 1)):
-        with TaskGroup(group_id="tg1") as tg1:
-            s1 = BaseOperator(task_id="s1")
-            w1 = BaseOperator(task_id="w1")
-            t1 = BaseOperator(task_id="t1")
-            s1 >> w1 >> t1.as_teardown(setups=s1)
-        w2 = BaseOperator(task_id="w2")
-        tg1 >> w2
-    assert t1.downstream_task_ids == set()
-    assert w1.downstream_task_ids == {"tg1.t1", "w2"}
-
-
 def test_task_group_arrow_with_setup_group():
-    with DAG(dag_id="setup_group_teardown_group", schedule=None, start_date=pendulum.now()):
+    with DAG(dag_id="setup_group_teardown_group") as dag:
         with TaskGroup("group_1") as g1:

             @setup
@@ -1638,6 +1154,8 @@
     assert set(t2.operator.downstream_task_ids) == set()

     def get_nodes(group):
+        serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+        group = serialized_dag.task_group_dict[g1.group_id]
         d = task_group_to_dict(group)
         new_d = {}
         new_d["id"] = d["id"]
@@ -1654,63 +1172,9 @@
     }


-def test_task_group_arrow_with_setup_group_deeper_setup():
-    """
-    When recursing upstream for a non-teardown leaf, we should ignore setups that
-    are direct upstream of a teardown.
-    """
-    with DAG(dag_id="setup_group_teardown_group_2", schedule=None, start_date=pendulum.now()):
-        with TaskGroup("group_1") as g1:
-
-            @setup
-            def setup_1(): ...
-
-            @setup
-            def setup_2(): ...
-
-            @teardown
-            def teardown_0(): ...
-
-            s1 = setup_1()
-            s2 = setup_2()
-            t0 = teardown_0()
-            s2 >> t0
-
-        with TaskGroup("group_2") as g2:
-
-            @teardown
-            def teardown_1(): ...
-
-            @teardown
-            def teardown_2(): ...
-
-            t1 = teardown_1()
-            t2 = teardown_2()
-
-        @task_decorator
-        def work(): ...
-
-        w1 = work()
-        g1 >> w1 >> g2
-        t1.as_teardown(setups=s1)
-        t2.as_teardown(setups=s2)
-    assert set(s1.operator.downstream_task_ids) == {"work", "group_2.teardown_1"}
-    assert set(s2.operator.downstream_task_ids) == {"group_1.teardown_0", "group_2.teardown_2"}
-    assert set(w1.operator.downstream_task_ids) == {"group_2.teardown_1", "group_2.teardown_2"}
-    assert set(t1.operator.downstream_task_ids) == set()
-    assert set(t2.operator.downstream_task_ids) == set()
-
-
-def test_task_group_with_invalid_arg_type_raises_error():
-    error_msg = r"'ui_color' must be <class 'str'> \(got 123 that is a <class 'int'>\)\."
-    with DAG(dag_id="dag_with_tg_invalid_arg_type", schedule=None):
-        with pytest.raises(TypeError, match=error_msg):
-            _ = TaskGroup("group_1", ui_color=123)
-
-
-def test_task_group_display_name_used_as_label():
+def test_task_group_display_name_used_as_label(dag_maker):
     """Test that the group_display_name for TaskGroup is used as the label for display on the UI."""
-    with DAG(dag_id="display_name", schedule=None, start_date=pendulum.datetime(2022, 1, 1)) as dag:
+    with dag_maker(dag_id="display_name", schedule=None, start_date=pendulum.datetime(2022, 1, 1)) as dag:
         with TaskGroup(group_id="tg", group_display_name="my_custom_name") as tg:
             task1 = BaseOperator(task_id="task1")
             task2 = BaseOperator(task_id="task2")
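A pattern worth noting in the reworked tests above: where a test DAG is built inline rather than through the dag_maker fixture, it is round-tripped through serialization before the server-side helper is applied, matching what the API layer sees in production. The core shape, lifted from the tests:

    from airflow.serialization.serialized_objects import SerializedDAG

    # Round-trip an authored DAG through serialization so task_group_to_dict
    # operates on SerializedBaseOperator instances, as it does in the API.
    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
    nodes = task_group_to_dict(serialized_dag.task_group)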
diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
index ac4816b4dcf..29da84bb891 100644
--- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -24,15 +24,12 @@ import functools
 import operator
 import re
 import weakref
-from collections.abc import Callable, Generator, Iterator, Sequence
-from functools import cache
-from operator import methodcaller
+from collections.abc import Generator, Iterator, Sequence
 from typing import TYPE_CHECKING, Any

 import attrs
 import methodtools

-from airflow.configuration import conf
 from airflow.exceptions import (
     AirflowException,
     DuplicateTaskIdFound,
@@ -673,97 +670,3 @@ class MappedTaskGroup(TaskGroup):

         for op, _ in XComArg.iter_xcom_references(self._expand_input):
             yield op
-
-
-@cache
-def get_task_group_children_getter() -> Callable:
-    """Get the Task Group Children Getter for the DAG."""
-    sort_order = conf.get("api", "grid_view_sorting_order")
-    if sort_order == "topological":
-        return methodcaller("topological_sort")
-    return methodcaller("hierarchical_alphabetical_sort")
-
-
-def task_group_to_dict(task_item_or_group, parent_group_is_mapped=False):
-    """Create a nested dict representation of this TaskGroup and its children used to construct the Graph."""
-    from airflow.models.mappedoperator import MappedOperator
-    from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
-    from airflow.serialization.serialized_objects import SerializedBaseOperator
-
-    if isinstance(task := task_item_or_group, (AbstractOperator, SerializedBaseOperator, MappedOperator)):
-        node_operator = {
-            "id": task.task_id,
-            "label": task.label,
-            "operator": task.operator_name,
-            "type": "task",
-        }
-        if task.is_setup:
-            node_operator["setup_teardown_type"] = "setup"
-        elif task.is_teardown:
-            node_operator["setup_teardown_type"] = "teardown"
-        if isinstance(task, MappedOperator) or parent_group_is_mapped:
-            node_operator["is_mapped"] = True
-        return node_operator
-
-    task_group = task_item_or_group
-    is_mapped = isinstance(task_group, MappedTaskGroup)
-    children = [
-        task_group_to_dict(child, parent_group_is_mapped=parent_group_is_mapped or is_mapped)
-        for child in get_task_group_children_getter()(task_group)
-    ]
-
-    if task_group.upstream_group_ids or task_group.upstream_task_ids:
-        # This is the join node used to reduce the number of edges between two TaskGroup.
-        children.append({"id": task_group.upstream_join_id, "label": "", "type": "join"})
-
-    if task_group.downstream_group_ids or task_group.downstream_task_ids:
-        # This is the join node used to reduce the number of edges between two TaskGroup.
-        children.append({"id": task_group.downstream_join_id, "label": "", "type": "join"})
-
-    return {
-        "id": task_group.group_id,
-        "label": task_group.label,
-        "tooltip": task_group.tooltip,
-        "is_mapped": is_mapped,
-        "children": children,
-        "type": "task",
-    }
-
-
-def task_group_to_dict_grid(task_item_or_group, parent_group_is_mapped=False):
-    """Create a nested dict representation of this TaskGroup and its children used to construct the Graph."""
-    from airflow.models.mappedoperator import MappedOperator
-    from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
-    from airflow.serialization.serialized_objects import SerializedBaseOperator
-
-    if isinstance(task := task_item_or_group, (AbstractOperator, MappedOperator, SerializedBaseOperator)):
-        is_mapped = None
-        if task.is_mapped or parent_group_is_mapped:
-            is_mapped = True
-        setup_teardown_type = None
-        if task.is_setup is True:
-            setup_teardown_type = "setup"
-        elif task.is_teardown is True:
-            setup_teardown_type = "teardown"
-        return {
-            "id": task.task_id,
-            "label": task.label,
-            "is_mapped": is_mapped,
-            "children": None,
-            "setup_teardown_type": setup_teardown_type,
-        }
-
-    task_group = task_item_or_group
-    task_group_sort = get_task_group_children_getter()
-    is_mapped_group = isinstance(task_group, MappedTaskGroup)
-    children = [
-        task_group_to_dict_grid(x, parent_group_is_mapped=parent_group_is_mapped or is_mapped_group)
-        for x in task_group_sort(task_group)
-    ]
-
-    return {
-        "id": task_group.group_id,
-        "label": task_group.label,
-        "is_mapped": is_mapped_group or None,
-        "children": children or None,
-    }
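After the removal above, airflow.sdk.definitions.taskgroup is purely an authoring API: TaskGroup/MappedTaskGroup construction, wiring, and structure queries stay client-side, while dict rendering lives in the API services. A brief sketch of the SDK-only surface the relocated tests below exercise (DAG ids and task ids here are illustrative):

    import pendulum

    from airflow.providers.standard.operators.empty import EmptyOperator
    from airflow.sdk import DAG
    from airflow.sdk.definitions.taskgroup import TaskGroup

    with DAG("authoring_only", schedule=None, start_date=pendulum.parse("20200101")) as dag:
        start = EmptyOperator(task_id="start")
        with TaskGroup("group") as group:
            EmptyOperator(task_id="inner")
        start >> group

    # Structure queries remain available client-side...
    assert set(dag.task_group.children.keys()) == {"start", "group"}
    # ...but task_group_to_dict is no longer importable from this SDK module.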
EmptyOperator(task_id="task1") + with TaskGroup("group234") as group234: + _ = EmptyOperator(task_id="task2") + + with TaskGroup("group34") as group34: + _ = EmptyOperator(task_id="task3") + _ = EmptyOperator(task_id="task4") + + task5 = EmptyOperator(task_id="task5") + task1 >> group234 + group34 >> task5 + + assert task1.get_direct_relative_ids(upstream=False) == { + "group234.group34.task4", + "group234.group34.task3", + "group234.task2", + } + assert task5.get_direct_relative_ids(upstream=True) == { + "group234.group34.task4", + "group234.group34.task3", + } + + assert dag.task_group.group_id is None + assert dag.task_group.is_root + assert set(dag.task_group.children.keys()) == {"task1", "group234", "task5"} + assert group34.group_id == "group234.group34" + + +def test_build_task_group(): + """ + Test alternative syntax to use TaskGroup. It should result in the same TaskGroup + as using context manager. + """ + logical_date = pendulum.parse("20200101") + dag = DAG("test_build_task_group", schedule=None, start_date=logical_date) + task1 = EmptyOperator(task_id="task1", dag=dag) + group234 = TaskGroup("group234", dag=dag) + _ = EmptyOperator(task_id="task2", dag=dag, task_group=group234) + group34 = TaskGroup("group34", dag=dag, parent_group=group234) + _ = EmptyOperator(task_id="task3", dag=dag, task_group=group34) + _ = EmptyOperator(task_id="task4", dag=dag, task_group=group34) + task5 = EmptyOperator(task_id="task5", dag=dag) + + task1 >> group234 + group34 >> task5 + + # Test basic TaskGroup structure + assert dag.task_group.group_id is None + assert dag.task_group.is_root + assert set(dag.task_group.children.keys()) == {"task1", "group234", "task5"} + assert group34.group_id == "group234.group34" + + +def test_build_task_group_with_prefix(): + """ + Tests that prefix_group_id turns on/off prefixing of task_id with group_id. + """ + logical_date = pendulum.parse("20200101") + with DAG("test_build_task_group_with_prefix", start_date=logical_date): + task1 = EmptyOperator(task_id="task1") + with TaskGroup("group234", prefix_group_id=False) as group234: + task2 = EmptyOperator(task_id="task2") + + with TaskGroup("group34") as group34: + task3 = EmptyOperator(task_id="task3") + + with TaskGroup("group4", prefix_group_id=False) as group4: + task4 = EmptyOperator(task_id="task4") + + task5 = EmptyOperator(task_id="task5") + task1 >> group234 + group34 >> task5 + + assert task2.task_id == "task2" + assert group34.group_id == "group34" + assert task3.task_id == "group34.task3" + assert group4.group_id == "group34.group4" + assert task4.task_id == "task4" + assert task5.task_id == "task5" + assert group234.get_child_by_label("task2") == task2 + assert group234.get_child_by_label("group34") == group34 + assert group4.get_child_by_label("task4") == task4 + + +def test_build_task_group_with_prefix_functionality(): + """ + Tests TaskGroup prefix_group_id functionality - additional test for comprehensive coverage. 
+ """ + logical_date = pendulum.parse("20200101") + with DAG("test_prefix_functionality", start_date=logical_date): + task1 = EmptyOperator(task_id="task1") + with TaskGroup("group234", prefix_group_id=False) as group234: + task2 = EmptyOperator(task_id="task2") + + with TaskGroup("group34") as group34: + task3 = EmptyOperator(task_id="task3") + + with TaskGroup("group4", prefix_group_id=False) as group4: + task4 = EmptyOperator(task_id="task4") + + task5 = EmptyOperator(task_id="task5") + task1 >> group234 + group34 >> task5 + + # Test prefix_group_id behavior + assert task2.task_id == "task2" # prefix_group_id=False, so no prefix + assert group34.group_id == "group34" # nested group gets prefixed + assert task3.task_id == "group34.task3" # task in nested group gets full prefix + assert group4.group_id == "group34.group4" # nested group gets parent prefix + assert task4.task_id == "task4" # prefix_group_id=False, so no prefix + assert task5.task_id == "task5" # root level task, no prefix + + # Test group hierarchy and child access + assert group234.get_child_by_label("task2") == task2 + assert group234.get_child_by_label("group34") == group34 + assert group4.get_child_by_label("task4") == task4 + + +def test_build_task_group_with_task_decorator(): + """ + Test that TaskGroup can be used with the @task decorator. + """ + from airflow.sdk import task + + @task + def task_1(): + print("task_1") + + @task + def task_2(): + return "task_2" + + @task + def task_3(): + return "task_3" + + @task + def task_4(task_2_output, task_3_output): + print(task_2_output, task_3_output) + + @task + def task_5(): + print("task_5") + + logical_date = pendulum.parse("20200101") + with DAG("test_build_task_group_with_task_decorator", start_date=logical_date): + tsk_1 = task_1() + + with TaskGroup("group234") as group234: + tsk_2 = task_2() + tsk_3 = task_3() + tsk_4 = task_4(tsk_2, tsk_3) + + tsk_5 = task_5() + + tsk_1 >> group234 >> tsk_5 + + # Test TaskGroup functionality with @task decorator + assert tsk_1.operator in tsk_2.operator.upstream_list + assert tsk_1.operator in tsk_3.operator.upstream_list + assert tsk_5.operator in tsk_4.operator.downstream_list + + # Test TaskGroup structure + assert group234.group_id == "group234" + assert len(group234.children) == 3 # task_2, task_3, task_4 + assert "group234.task_2" in group234.children + assert "group234.task_3" in group234.children + assert "group234.task_4" in group234.children + + +def test_sub_dag_task_group(): + """ + Tests dag.partial_subset() updates task_group correctly. 
+ """ + logical_date = pendulum.parse("20200101") + with DAG("test_test_task_group_sub_dag", schedule=None, start_date=logical_date) as dag: + task1 = EmptyOperator(task_id="task1") + with TaskGroup("group234") as group234: + _ = EmptyOperator(task_id="task2") + + with TaskGroup("group34") as group34: + _ = EmptyOperator(task_id="task3") + _ = EmptyOperator(task_id="task4") + + with TaskGroup("group6") as group6: + _ = EmptyOperator(task_id="task6") + + task7 = EmptyOperator(task_id="task7") + task5 = EmptyOperator(task_id="task5") + + task1 >> group234 + group34 >> task5 + group234 >> group6 + group234 >> task7 + + subset = dag.partial_subset(task_ids="task5", include_upstream=True, include_downstream=False) + + # Test that partial_subset correctly updates task_group structure + groups = subset.task_group.get_task_group_dict() + assert groups.keys() == {None, "group234", "group234.group34"} + + included_group_ids = {"group234", "group234.group34"} + included_task_ids = {"group234.group34.task3", "group234.group34.task4", "task1", "task5"} + + # Test that subset maintains correct group relationships + for task_group in groups.values(): + assert task_group.upstream_group_ids.issubset(included_group_ids) + assert task_group.downstream_group_ids.issubset(included_group_ids) + assert task_group.upstream_task_ids.issubset(included_task_ids) + assert task_group.downstream_task_ids.issubset(included_task_ids) + + # Test that subset maintains correct task relationships + for task in subset.task_group: + assert task.upstream_task_ids.issubset(included_task_ids) + assert task.downstream_task_ids.issubset(included_task_ids) + + # Test basic subset properties + assert len(subset.tasks) == 4 # task1, task3, task4, task5 + assert subset.task_dict["task1"].task_id == "task1" + assert subset.task_dict["group234.group34.task3"].task_id == "group234.group34.task3" + assert subset.task_dict["group234.group34.task4"].task_id == "group234.group34.task4" + assert subset.task_dict["task5"].task_id == "task5" + + +def test_dag_edges_task_group_structure(): + logical_date = pendulum.parse("20200101") + with DAG("test_dag_edges", schedule=None, start_date=logical_date): + task1 = EmptyOperator(task_id="task1") + with TaskGroup("group_a") as group_a: + with TaskGroup("group_b") as group_b: + task2 = EmptyOperator(task_id="task2") + task3 = EmptyOperator(task_id="task3") + task4 = EmptyOperator(task_id="task4") + task2 >> [task3, task4] + + task5 = EmptyOperator(task_id="task5") + + task5 << group_b + + task1 >> group_a + + with TaskGroup("group_c") as group_c: + task6 = EmptyOperator(task_id="task6") + task7 = EmptyOperator(task_id="task7") + task8 = EmptyOperator(task_id="task8") + [task6, task7] >> task8 + group_a >> group_c + + task5 >> task8 + + task9 = EmptyOperator(task_id="task9") + task10 = EmptyOperator(task_id="task10") + + group_c >> [task9, task10] + + with TaskGroup("group_d") as group_d: + task11 = EmptyOperator(task_id="task11") + task12 = EmptyOperator(task_id="task12") + task11 >> task12 + + group_d << group_c + + # Test TaskGroup structure and relationships + assert group_a.group_id == "group_a" + assert group_b.group_id == "group_a.group_b" + assert group_c.group_id == "group_c" + assert group_d.group_id == "group_d" + + # Test task relationships within groups + assert task2.downstream_task_ids == {"group_a.group_b.task3", "group_a.group_b.task4"} + assert task3.upstream_task_ids == {"group_a.group_b.task2"} + assert task4.upstream_task_ids == {"group_a.group_b.task2"} + assert 
+
+
+def test_dag_edges_task_group_structure():
+    logical_date = pendulum.parse("20200101")
+    with DAG("test_dag_edges", schedule=None, start_date=logical_date):
+        task1 = EmptyOperator(task_id="task1")
+        with TaskGroup("group_a") as group_a:
+            with TaskGroup("group_b") as group_b:
+                task2 = EmptyOperator(task_id="task2")
+                task3 = EmptyOperator(task_id="task3")
+                task4 = EmptyOperator(task_id="task4")
+                task2 >> [task3, task4]
+
+            task5 = EmptyOperator(task_id="task5")
+
+            task5 << group_b
+
+        task1 >> group_a
+
+        with TaskGroup("group_c") as group_c:
+            task6 = EmptyOperator(task_id="task6")
+            task7 = EmptyOperator(task_id="task7")
+            task8 = EmptyOperator(task_id="task8")
+            [task6, task7] >> task8
+        group_a >> group_c
+
+        task5 >> task8
+
+        task9 = EmptyOperator(task_id="task9")
+        task10 = EmptyOperator(task_id="task10")
+
+        group_c >> [task9, task10]
+
+        with TaskGroup("group_d") as group_d:
+            task11 = EmptyOperator(task_id="task11")
+            task12 = EmptyOperator(task_id="task12")
+            task11 >> task12
+
+        group_d << group_c
+
+    # Test TaskGroup structure and relationships
+    assert group_a.group_id == "group_a"
+    assert group_b.group_id == "group_a.group_b"
+    assert group_c.group_id == "group_c"
+    assert group_d.group_id == "group_d"
+
+    # Test task relationships within groups
+    assert task2.downstream_task_ids == {"group_a.group_b.task3", "group_a.group_b.task4"}
+    assert task3.upstream_task_ids == {"group_a.group_b.task2"}
+    assert task4.upstream_task_ids == {"group_a.group_b.task2"}
+    assert task5.upstream_task_ids == {"group_a.group_b.task3", "group_a.group_b.task4"}
+
+    # Test cross-group relationships
+    assert task1.downstream_task_ids == {"group_a.group_b.task2"}
+    assert task6.upstream_task_ids == {"group_a.task5"}
+    assert task7.upstream_task_ids == {"group_a.task5"}
+    assert task8.upstream_task_ids == {"group_c.task6", "group_c.task7", "group_a.task5"}
+
+    # Test group-to-task relationships
+    assert task9.upstream_task_ids == {"group_c.task8"}
+    assert task10.upstream_task_ids == {"group_c.task8"}
+    assert task11.upstream_task_ids == {"group_c.task8"}
+
+
+def test_duplicate_group_id():
+    from airflow.exceptions import DuplicateTaskIdFound
+
+    logical_date = pendulum.parse("20200101")
+
+    with DAG("test_duplicate_group_id", schedule=None, start_date=logical_date):
+        _ = EmptyOperator(task_id="task1")
+        with pytest.raises(DuplicateTaskIdFound, match=r".* 'task1' .*"), TaskGroup("task1"):
+            pass
+
+    with DAG("test_duplicate_group_id", schedule=None, start_date=logical_date):
+        _ = EmptyOperator(task_id="task1")
+        with TaskGroup("group1", prefix_group_id=False):
+            with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1' .*"), TaskGroup("group1"):
+                pass
+
+    with DAG("test_duplicate_group_id", schedule=None, start_date=logical_date):
+        with TaskGroup("group1", prefix_group_id=False):
+            with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1' .*"):
+                _ = EmptyOperator(task_id="group1")
+
+    with DAG("test_duplicate_group_id", schedule=None, start_date=logical_date):
+        _ = EmptyOperator(task_id="task1")
+        with TaskGroup("group1"):
+            with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1.downstream_join_id' .*"):
+                _ = EmptyOperator(task_id="downstream_join_id")
+
+    with DAG("test_duplicate_group_id", schedule=None, start_date=logical_date):
+        _ = EmptyOperator(task_id="task1")
+        with TaskGroup("group1"):
+            with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1.upstream_join_id' .*"):
+                _ = EmptyOperator(task_id="upstream_join_id")
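+
+
+# NOTE (descriptive, added for clarity): the last two cases above are reserved-id
+# collisions -- every non-root TaskGroup claims '<group_id>.upstream_join_id' and
+# '<group_id>.downstream_join_id' for the implicit join nodes used when edges are
+# drawn to or from the group as a whole, so user task_ids may not reuse them.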
+ """ + dag = DAG(dag_id="test_task_without_dag", schedule=None, start_date=pendulum.parse("20200101")) + op1 = EmptyOperator(task_id="op1", dag=dag) + op2 = EmptyOperator(task_id="op2") + op3 = EmptyOperator(task_id="op3") + op1 >> op2 + op3 >> op2 + + assert op1.dag == op2.dag == op3.dag + assert dag.task_group.children.keys() == {"op1", "op2", "op3"} + assert dag.task_group.children.keys() == dag.task_dict.keys() + + +def test_default_args(): + """Testing TaskGroup with default_args""" + logical_date = pendulum.parse("20201109") + with DAG( + dag_id="example_task_group_default_args", + schedule=None, + start_date=logical_date, + default_args={"owner": "dag"}, + ): + with TaskGroup("group1", default_args={"owner": "group"}): + task_1 = EmptyOperator(task_id="task_1") + task_2 = EmptyOperator(task_id="task_2", owner="task") + task_3 = EmptyOperator(task_id="task_3", default_args={"owner": "task"}) + + assert task_1.owner == "group" + assert task_2.owner == "task" + assert task_3.owner == "task" + + +def test_iter_tasks(): + with DAG("test_dag", schedule=None, start_date=pendulum.parse("20200101")) as dag: + with TaskGroup("section_1") as tg1: + EmptyOperator(task_id="task1") + + with TaskGroup("section_2") as tg2: + task2 = EmptyOperator(task_id="task2") + task3 = EmptyOperator(task_id="task3") + mapped_bash_operator = BashOperator.partial(task_id="bash_task").expand( + bash_command=[ + "echo hello 1", + "echo hello 2", + "echo hello 3", + ] + ) + task2 >> task3 >> mapped_bash_operator + + tg1 >> tg2 + root_group = dag.task_group + assert [t.task_id for t in root_group.iter_tasks()] == [ + "section_1.task1", + "section_2.task2", + "section_2.task3", + "section_2.bash_task", + ] + assert [t.task_id for t in tg1.iter_tasks()] == [ + "section_1.task1", + ] + assert [t.task_id for t in tg2.iter_tasks()] == [ + "section_2.task2", + "section_2.task3", + "section_2.bash_task", + ] + + +def test_override_dag_default_args(): + logical_date = pendulum.parse("20201109") + with DAG( + dag_id="example_task_group_default_args", + schedule=None, + start_date=logical_date, + default_args={"owner": "dag"}, + ): + with TaskGroup("group1", default_args={"owner": "group"}): + task_1 = EmptyOperator(task_id="task_1") + task_2 = EmptyOperator(task_id="task_2", owner="task") + task_3 = EmptyOperator(task_id="task_3", default_args={"owner": "task"}) + + assert task_1.owner == "group" + assert task_2.owner == "task" + assert task_3.owner == "task" + + +def test_override_dag_default_args_in_nested_tg(): + logical_date = pendulum.parse("20201109") + with DAG( + dag_id="example_task_group_default_args", + schedule=None, + start_date=logical_date, + default_args={"owner": "dag"}, + ): + with TaskGroup("group1", default_args={"owner": "group1"}): + task_1 = EmptyOperator(task_id="task_1") + with TaskGroup("group2", default_args={"owner": "group2"}): + task_2 = EmptyOperator(task_id="task_2") + task_3 = EmptyOperator(task_id="task_3", owner="task") + + assert task_1.owner == "group1" + assert task_2.owner == "group2" + assert task_3.owner == "task" + + +def test_override_dag_default_args_in_multi_level_nested_tg(): + logical_date = pendulum.parse("20201109") + with DAG( + dag_id="example_task_group_default_args", + schedule=None, + start_date=logical_date, + default_args={"owner": "dag"}, + ): + with TaskGroup("group1", default_args={"owner": "group1"}): + task_1 = EmptyOperator(task_id="task_1") + with TaskGroup("group2"): + task_2 = EmptyOperator(task_id="task_2") + with TaskGroup("group3", default_args={"owner": 
"group3"}): + task_3 = EmptyOperator(task_id="task_3") + task_4 = EmptyOperator(task_id="task_4", owner="task") + + assert task_1.owner == "group1" + assert task_2.owner == "group1" # inherits from group1 + assert task_3.owner == "group3" + assert task_4.owner == "task" + + +def test_task_group_arrow_with_setups_teardowns(): + with DAG(dag_id="hi", schedule=None, start_date=pendulum.datetime(2022, 1, 1)): + with TaskGroup(group_id="tg1") as tg1: + s1 = EmptyOperator(task_id="s1") + w1 = EmptyOperator(task_id="w1") + t1 = EmptyOperator(task_id="t1") + s1 >> w1 >> t1.as_teardown(setups=s1) + w2 = EmptyOperator(task_id="w2") + tg1 >> w2 + + assert t1.downstream_task_ids == set() + assert w1.downstream_task_ids == {"tg1.t1", "w2"} + assert s1.downstream_task_ids == {"tg1.t1", "tg1.w1"} + assert t1.upstream_task_ids == {"tg1.s1", "tg1.w1"} + + +def test_task_group_arrow_basic(): + with DAG(dag_id="basic_group_test"): + with TaskGroup("group_1") as g1: + task_1 = EmptyOperator(task_id="task_1") + + with TaskGroup("group_2") as g2: + task_2 = EmptyOperator(task_id="task_2") + + g1 >> g2 + + # Test basic TaskGroup relationships + assert task_1.downstream_task_ids == {"group_2.task_2"} + assert task_2.upstream_task_ids == {"group_1.task_1"} + + +def test_task_group_nested_structure(): + with DAG(dag_id="nested_group_test"): + with TaskGroup("group_1") as g1: + with TaskGroup("group_1_1") as g1_1: + task_1_1 = EmptyOperator(task_id="task_1_1") + + with TaskGroup("group_2") as g2: + task_2 = EmptyOperator(task_id="task_2") + + g1 >> g2 + + # Test nested TaskGroup structure + assert g1_1.group_id == "group_1.group_1_1" + assert task_1_1.task_id == "group_1.group_1_1.task_1_1" + assert task_1_1.downstream_task_ids == {"group_2.task_2"} + assert task_2.upstream_task_ids == {"group_1.group_1_1.task_1_1"} + + +def test_task_group_with_invalid_arg_type_raises_error(): + error_msg = r"'ui_color' must be <class 'str'> \(got 123 that is a <class 'int'>\)\." + with DAG(dag_id="dag_with_tg_invalid_arg_type", schedule=None): + with pytest.raises(TypeError, match=error_msg): + _ = TaskGroup("group_1", ui_color=123) + + +def test_task_group_arrow_with_setup_group_deeper_setup(): + """ + When recursing upstream for a non-teardown leaf, we should ignore setups that + are direct upstream of a teardown. + """ + with DAG(dag_id="setup_group_teardown_group_2", schedule=None, start_date=pendulum.now()): + with TaskGroup("group_1") as g1: + + @setup + def setup_1(): ... + + @setup + def setup_2(): ... + + @teardown + def teardown_0(): ... + + s1 = setup_1() + s2 = setup_2() + t0 = teardown_0() + s2 >> t0 + + with TaskGroup("group_2") as g2: + + @teardown + def teardown_1(): ... + + @teardown + def teardown_2(): ... + + t1 = teardown_1() + t2 = teardown_2() + + @task_decorator + def work(): ... 
+
+
+def test_task_group_arrow_with_setup_group_deeper_setup():
+    """
+    When recursing upstream for a non-teardown leaf, we should ignore setups that
+    are direct upstream of a teardown.
+    """
+    with DAG(dag_id="setup_group_teardown_group_2", schedule=None, start_date=pendulum.now()):
+        with TaskGroup("group_1") as g1:
+
+            @setup
+            def setup_1(): ...
+
+            @setup
+            def setup_2(): ...
+
+            @teardown
+            def teardown_0(): ...
+
+            s1 = setup_1()
+            s2 = setup_2()
+            t0 = teardown_0()
+            s2 >> t0
+
+        with TaskGroup("group_2") as g2:
+
+            @teardown
+            def teardown_1(): ...
+
+            @teardown
+            def teardown_2(): ...
+
+            t1 = teardown_1()
+            t2 = teardown_2()
+
+        @task_decorator
+        def work(): ...
+
+        w1 = work()
+        g1 >> w1 >> g2
+        t1.as_teardown(setups=s1)
+        t2.as_teardown(setups=s2)
+    assert set(s1.operator.downstream_task_ids) == {"work", "group_2.teardown_1"}
+    assert set(s2.operator.downstream_task_ids) == {"group_1.teardown_0", "group_2.teardown_2"}
+    assert set(w1.operator.downstream_task_ids) == {"group_2.teardown_1", "group_2.teardown_2"}
+    assert set(t1.operator.downstream_task_ids) == set()
+    assert set(t2.operator.downstream_task_ids) == set()
+
+
+def test_add_to_sub_group():
+    with DAG("test_dag", schedule=None, start_date=pendulum.parse("20200101")):
+        tg = TaskGroup("section")
+        task = EmptyOperator(task_id="task")
+        with pytest.raises(TaskAlreadyInTaskGroup) as ctx:
+            tg.add(task)
+
+    assert str(ctx.value) == "cannot add 'task' to 'section' (already in the DAG's root group)"
+
+
+def test_add_to_another_group():
+    with DAG("test_dag", schedule=None, start_date=pendulum.parse("20200101")):
+        tg = TaskGroup("section_1")
+        with TaskGroup("section_2"):
+            task = EmptyOperator(task_id="task")
+        with pytest.raises(TaskAlreadyInTaskGroup) as ctx:
+            tg.add(task)
+
+    assert str(ctx.value) == "cannot add 'section_2.task' to 'section_1' (already in group 'section_2')"
+
+
+def test_task_group_edge_modifier_chain():
+    from airflow.sdk import Label, chain
+
+    with DAG(dag_id="test", schedule=None, start_date=pendulum.DateTime(2022, 5, 20)) as dag:
+        start = EmptyOperator(task_id="sleep_3_seconds")
+
+        with TaskGroup(group_id="group1") as tg:
+            t1 = EmptyOperator(task_id="dummy1")
+            t2 = EmptyOperator(task_id="dummy2")
+
+        t3 = EmptyOperator(task_id="echo_done")
+
+        # The case we are testing for is when a Label is inside a list -- meaning that we do tg.set_upstream
+        # instead of label.set_downstream
+        chain(start, [Label("branch three")], tg, t3)
+
+    assert start.downstream_task_ids == {t1.node_id, t2.node_id}
+    assert t3.upstream_task_ids == {t1.node_id, t2.node_id}
+    assert tg.upstream_task_ids == set()
+    assert tg.downstream_task_ids == {t3.node_id}
+    # Check that we can perform a topological_sort
+    dag.topological_sort()
+
+
+def test_mapped_task_group_id_prefix_task_id():
+    from tests_common.test_utils.mock_operators import MockOperator
+
+    with DAG(dag_id="d", schedule=None, start_date=DEFAULT_DATE) as dag:
+        t1 = MockOperator.partial(task_id="t1").expand(arg1=[])
+        with TaskGroup("g"):
+            t2 = MockOperator.partial(task_id="t2").expand(arg1=[])
+
+    assert t1.task_id == "t1"
+    assert t2.task_id == "g.t2"
+
+    assert dag.get_task("t1") == t1
+    assert dag.get_task("g.t2") == t2
+
+
+def test_pass_taskgroup_output_to_task():
+    """Test that the output of a task group can be passed to a task."""
+    from airflow.sdk import task
+
+    @task
+    def one():
+        return 1
+
+    @task_group_decorator
+    def addition_task_group(num):
+        @task
+        def add_one(i):
+            return i + 1
+
+        return add_one(num)
+
+    @task
+    def increment(num):
+        return num + 1
+
+    @dag(schedule=None, start_date=pendulum.DateTime(2022, 1, 1), default_args={"owner": "airflow"})
+    def wrap():
+        total_1 = one()
+        assert isinstance(total_1, XComArg)
+        total_2 = addition_task_group(total_1)
+        assert isinstance(total_2, XComArg)
+        total_3 = increment(total_2)
+        assert isinstance(total_3, XComArg)
+
+    wrap()
+
+
+def test_decorator_unknown_args():
+    """Test that unknown args passed to the decorator cause an error at parse time"""
+    with pytest.raises(TypeError):
+
+        @task_group_decorator(b=2)
+        def tg(): ...
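+
+
+# NOTE (descriptive, added for clarity): the next test leans on task-id
+# de-duplication -- instantiating the same @task callable more than once in a
+# scope suffixes later task_ids with __1, __2, ..., inside and outside the group.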
+
+
+def test_decorator_multiple_use_task():
+    from airflow.sdk import task
+
+    @dag("test-dag", schedule=None, start_date=DEFAULT_DATE)
+    def _test_dag():
+        @task
+        def t():
+            pass
+
+        @task_group_decorator
+        def tg():
+            for _ in range(3):
+                t()
+
+        t() >> tg() >> t()
+
+    test_dag = _test_dag()
+    assert test_dag.task_ids == [
+        "t",  # Start node.
+        "tg.t",
+        "tg.t__1",
+        "tg.t__2",
+        "t__1",  # End node.
+    ]
+
+
+def test_build_task_group_depended_by_task():
+    """A decorator-based task group should be able to be used as a relative to operators."""
+    from airflow.sdk import dag as dag_decorator, task
+
+    @dag_decorator(schedule=None, start_date=pendulum.now())
+    def build_task_group_depended_by_task():
+        @task
+        def task_start():
+            return "[Task_start]"
+
+        @task
+        def task_end():
+            return "[Task_end]"
+
+        @task
+        def task_thing(value):
+            return f"[Task_thing {value}]"
+
+        @task_group_decorator
+        def section_1():
+            task_thing(1)
+            task_thing(2)
+
+        task_start() >> section_1() >> task_end()
+
+    dag = build_task_group_depended_by_task()
+    task_thing_1 = dag.task_dict["section_1.task_thing"]
+    task_thing_2 = dag.task_dict["section_1.task_thing__1"]
+
+    # Tasks in the task group don't depend on each other; they both become
+    # downstreams to task_start, and upstreams to task_end.
+    assert task_thing_1.upstream_task_ids == task_thing_2.upstream_task_ids == {"task_start"}
+    assert task_thing_1.downstream_task_ids == task_thing_2.downstream_task_ids == {"task_end"}
+
+
+def test_build_task_group_with_operators():
+    """Tests a DAG whose tasks are created with *Operators alongside a TaskGroup created via the taskgroup decorator"""
+    from airflow.sdk import task
+
+    def task_start():
+        return "[Task_start]"
+
+    def task_end():
+        print("[ Task_End ]")
+
+    # Creating Tasks
+    @task
+    def task_1(value):
+        return f"[ Task1 {value} ]"
+
+    @task
+    def task_2(value):
+        return f"[ Task2 {value} ]"
+
+    @task
+    def task_3(value):
+        print(f"[ Task3 {value} ]")
+
+    # Creating TaskGroups
+    @task_group_decorator(group_id="section_1")
+    def section_a(value):
+        """TaskGroup for grouping related Tasks"""
+        return task_3(task_2(task_1(value)))
+
+    logical_date = pendulum.parse("20201109")
+    with DAG(
+        dag_id="example_task_group_decorator_mix",
+        schedule=None,
+        start_date=logical_date,
+        tags=["example"],
+    ) as dag:
+        t_start = PythonOperator(task_id="task_start", python_callable=task_start, dag=dag)
+        sec_1 = section_a(t_start.output)
+        t_end = PythonOperator(task_id="task_end", python_callable=task_end, dag=dag)
+        sec_1.set_downstream(t_end)
+
+    # Testing Tasks in DAG
+    assert set(dag.task_group.children.keys()) == {"section_1", "task_start", "task_end"}
+    assert set(dag.task_group.children["section_1"].children.keys()) == {
+        "section_1.task_2",
+        "section_1.task_3",
+        "section_1.task_1",
+    }
+
+    # Testing Tasks downstream
+    assert dag.task_dict["task_start"].downstream_task_ids == {"section_1.task_1"}
+    assert dag.task_dict["section_1.task_3"].downstream_task_ids == {"task_end"}
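+
+
+# NOTE (illustrative, not part of this commit): the Operator/decorator mix in the
+# last test hinges on t_start.output, the PythonOperator's XComArg; feeding it to
+# the decorated group is what creates the task_start -> section_1.task_1 edge:
+#
+#     sec_1 = section_a(t_start.output)  # XComArg in, implicit upstream edge
+#     sec_1.set_downstream(t_end)        # group leaves -> task_end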