This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c160ab70a00 Introduce serialized task groups; use them in core (#55169)
c160ab70a00 is described below
commit c160ab70a00faa41ef5996c4ba4a6580a64194ad
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu Sep 4 09:42:51 2025 +0800
Introduce serialized task groups; use them in core (#55169)
---
.pre-commit-config.yaml | 2 +-
.../api_fastapi/core_api/services/ui/grid.py | 8 +-
.../api_fastapi/core_api/services/ui/task_group.py | 25 +-
.../airflow/example_dags/example_setup_teardown.py | 3 +-
.../src/airflow/example_dags/example_task_group.py | 3 +-
airflow-core/src/airflow/models/dagrun.py | 4 +-
airflow-core/src/airflow/models/mappedoperator.py | 31 ++-
airflow-core/src/airflow/models/taskinstance.py | 22 +-
.../airflow/serialization/definitions/__init__.py | 17 ++
.../airflow/serialization/definitions/taskgroup.py | 284 +++++++++++++++++++++
.../airflow/serialization/serialized_objects.py | 200 ++++++++++++---
.../ti_deps/deps/mapped_task_upstream_dep.py | 6 +-
.../src/airflow/ti_deps/deps/trigger_rule_dep.py | 21 +-
airflow-core/src/airflow/utils/dot_renderer.py | 15 +-
airflow-core/tests/unit/models/test_dagrun.py | 161 ++++--------
.../tests/unit/models/test_taskinstance.py | 57 ++---
.../unit/serialization/test_dag_serialization.py | 4 +-
.../unit/ti_deps/deps/test_trigger_rule_dep.py | 17 +-
airflow-core/tests/unit/utils/test_task_group.py | 53 ++--
devel-common/src/tests_common/pytest_plugin.py | 28 +-
task-sdk/src/airflow/sdk/definitions/dag.py | 14 +-
task-sdk/src/airflow/sdk/definitions/taskgroup.py | 38 ++-
22 files changed, 692 insertions(+), 321 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3c454e6fb8c..623ab537ab9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1662,7 +1662,7 @@ repos:
^airflow-core/src/airflow/operators/subdag\.py$|
^airflow-core/src/airflow/plugins_manager\.py$|
^airflow-core/src/airflow/providers_manager\.py$|
- ^airflow-core/src/airflow/serialization/dag\.py$|
+ ^airflow-core/src/airflow/serialization/definitions/[_a-z]+\.py$|
^airflow-core/src/airflow/serialization/enums\.py$|
^airflow-core/src/airflow/serialization/helpers\.py$|
^airflow-core/src/airflow/serialization/serialized_objects\.py$|
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
index 124f526cd05..1f64ffcefa8 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
@@ -26,7 +26,7 @@ from airflow.api_fastapi.common.parameters import
state_priority
from airflow.api_fastapi.core_api.services.ui.task_group import
get_task_group_children_getter
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskmap import TaskMap
-from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
+from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
from airflow.serialization.serialized_objects import SerializedBaseOperator
log = structlog.get_logger(logger_name=__name__)
@@ -78,8 +78,8 @@ def _get_aggs_for_node(detail):
def _find_aggregates(
- node: TaskGroup | MappedTaskGroup | SerializedBaseOperator | TaskMap,
- parent_node: TaskGroup | MappedTaskGroup | SerializedBaseOperator |
TaskMap | None,
+ node: SerializedTaskGroup | SerializedBaseOperator | TaskMap,
+ parent_node: SerializedTaskGroup | SerializedBaseOperator | TaskMap | None,
ti_details: dict[str, list],
) -> Iterable[dict]:
"""Recursively fill the Task Group Map."""
@@ -98,7 +98,7 @@ def _find_aggregates(
}
return
- if isinstance(node, TaskGroup):
+ if isinstance(node, SerializedTaskGroup):
children = []
for child in get_task_group_children_getter()(node):
for child_node in _find_aggregates(node=child, parent_node=node,
ti_details=ti_details):
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/task_group.py
b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/task_group.py
index f88dca353c4..ed9a96718e9 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/task_group.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/task_group.py
@@ -24,8 +24,7 @@ from functools import cache
from operator import methodcaller
from airflow.configuration import conf
-from airflow.models.mappedoperator import MappedOperator
-from airflow.sdk.definitions.taskgroup import MappedTaskGroup
+from airflow.models.mappedoperator import MappedOperator, is_mapped
from airflow.serialization.serialized_objects import SerializedBaseOperator
@@ -51,14 +50,14 @@ def task_group_to_dict(task_item_or_group,
parent_group_is_mapped=False):
node_operator["setup_teardown_type"] = "setup"
elif task.is_teardown:
node_operator["setup_teardown_type"] = "teardown"
- if isinstance(task, MappedOperator) or parent_group_is_mapped:
+ if is_mapped(task) or parent_group_is_mapped:
node_operator["is_mapped"] = True
return node_operator
task_group = task_item_or_group
- is_mapped = isinstance(task_group, MappedTaskGroup)
+ mapped = is_mapped(task_group)
children = [
- task_group_to_dict(child,
parent_group_is_mapped=parent_group_is_mapped or is_mapped)
+ task_group_to_dict(child,
parent_group_is_mapped=parent_group_is_mapped or mapped)
for child in get_task_group_children_getter()(task_group)
]
@@ -74,7 +73,7 @@ def task_group_to_dict(task_item_or_group,
parent_group_is_mapped=False):
"id": task_group.group_id,
"label": task_group.label,
"tooltip": task_group.tooltip,
- "is_mapped": is_mapped,
+ "is_mapped": mapped,
"children": children,
"type": "task",
}
@@ -83,9 +82,9 @@ def task_group_to_dict(task_item_or_group,
parent_group_is_mapped=False):
def task_group_to_dict_grid(task_item_or_group, parent_group_is_mapped=False):
"""Create a nested dict representation of this TaskGroup and its children
used to construct the Grid."""
if isinstance(task := task_item_or_group, (MappedOperator,
SerializedBaseOperator)):
- is_mapped = None
- if task.is_mapped or parent_group_is_mapped:
- is_mapped = True
+ mapped = None
+ if parent_group_is_mapped or is_mapped(task):
+ mapped = True
setup_teardown_type = None
if task.is_setup is True:
setup_teardown_type = "setup"
@@ -94,22 +93,22 @@ def task_group_to_dict_grid(task_item_or_group,
parent_group_is_mapped=False):
return {
"id": task.task_id,
"label": task.label,
- "is_mapped": is_mapped,
+ "is_mapped": mapped,
"children": None,
"setup_teardown_type": setup_teardown_type,
}
task_group = task_item_or_group
task_group_sort = get_task_group_children_getter()
- is_mapped_group = isinstance(task_group, MappedTaskGroup)
+ mapped = is_mapped(task_group)
children = [
- task_group_to_dict_grid(x,
parent_group_is_mapped=parent_group_is_mapped or is_mapped_group)
+ task_group_to_dict_grid(x,
parent_group_is_mapped=parent_group_is_mapped or mapped)
for x in task_group_sort(task_group)
]
return {
"id": task_group.group_id,
"label": task_group.label,
- "is_mapped": is_mapped_group or None,
+ "is_mapped": mapped or None,
"children": children or None,
}
diff --git a/airflow-core/src/airflow/example_dags/example_setup_teardown.py
b/airflow-core/src/airflow/example_dags/example_setup_teardown.py
index 052377736ea..cefa3b31463 100644
--- a/airflow-core/src/airflow/example_dags/example_setup_teardown.py
+++ b/airflow-core/src/airflow/example_dags/example_setup_teardown.py
@@ -22,8 +22,7 @@ from __future__ import annotations
import pendulum
from airflow.providers.standard.operators.bash import BashOperator
-from airflow.sdk import DAG
-from airflow.sdk.definitions.taskgroup import TaskGroup
+from airflow.sdk import DAG, TaskGroup
with DAG(
dag_id="example_setup_teardown",
diff --git a/airflow-core/src/airflow/example_dags/example_task_group.py
b/airflow-core/src/airflow/example_dags/example_task_group.py
index c882c269c47..39010441d86 100644
--- a/airflow-core/src/airflow/example_dags/example_task_group.py
+++ b/airflow-core/src/airflow/example_dags/example_task_group.py
@@ -23,8 +23,7 @@ import pendulum
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.operators.empty import EmptyOperator
-from airflow.sdk import DAG
-from airflow.sdk.definitions.taskgroup import TaskGroup
+from airflow.sdk import DAG, TaskGroup
# [START howto_task_group]
with DAG(
diff --git a/airflow-core/src/airflow/models/dagrun.py
b/airflow-core/src/airflow/models/dagrun.py
index c8749685824..26d1d42e524 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -1676,9 +1676,7 @@ class DagRun(Base, LoggingMixin):
# Create the missing tasks, including mapped tasks
tis_to_create = self._create_tasks(
- # TODO (GH-52141): task_dict in scheduler should contain scheduler
- # types instead, but currently it inherits SDK's DAG.
- (task for task in cast("Iterable[Operator]",
dag.task_dict.values()) if task_filter(task)),
+ (task for task in dag.task_dict.values() if task_filter(task)),
task_creator,
session=session,
)
diff --git a/airflow-core/src/airflow/models/mappedoperator.py
b/airflow-core/src/airflow/models/mappedoperator.py
index 9a4a66a9fe8..310573985da 100644
--- a/airflow-core/src/airflow/models/mappedoperator.py
+++ b/airflow-core/src/airflow/models/mappedoperator.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import functools
import operator
-from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, TypeGuard
+from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, TypeGuard, overload
import attrs
import methodtools
@@ -31,7 +31,7 @@ from airflow.exceptions import AirflowException, NotMapped
from airflow.sdk import BaseOperator as TaskSDKBaseOperator
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.definitions.mappedoperator import MappedOperator as
TaskSDKMappedOperator
-from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
+from airflow.serialization.definitions.taskgroup import
SerializedMappedTaskGroup, SerializedTaskGroup
from airflow.serialization.enums import DagAttributeTypes
from airflow.serialization.serialized_objects import DEFAULT_OPERATOR_DEPS,
SerializedBaseOperator
from airflow.task.priority_strategy import PriorityWeightStrategy,
validate_and_load_priority_weight_strategy
@@ -57,8 +57,16 @@ if TYPE_CHECKING:
log = structlog.get_logger(__name__)
-def is_mapped(task: Operator) -> TypeGuard[MappedOperator]:
- return task.is_mapped
+@overload
+def is_mapped(obj: Operator) -> TypeGuard[MappedOperator]: ...
+
+
+@overload
+def is_mapped(obj: SerializedTaskGroup) ->
TypeGuard[SerializedMappedTaskGroup]: ...
+
+
+def is_mapped(obj: Operator | SerializedTaskGroup) -> TypeGuard[MappedOperator
| SerializedMappedTaskGroup]:
+ return obj.is_mapped
@attrs.define(
@@ -100,8 +108,11 @@ class MappedOperator(DAGNode):
start_from_trigger: bool = False
_needs_expansion: bool = True
- dag: SerializedDAG = attrs.field(init=False)
- task_group: TaskGroup = attrs.field(init=False)
+ # TODO (GH-52141): These should contain serialized containers, but
currently
+ # this class inherits from an SDK one.
+ dag: SerializedDAG = attrs.field(init=False) # type: ignore[assignment]
+ task_group: SerializedTaskGroup = attrs.field(init=False) # type:
ignore[assignment]
+
start_date: pendulum.DateTime | None = attrs.field(init=False,
default=None)
end_date: pendulum.DateTime | None = attrs.field(init=False, default=None)
upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
@@ -388,7 +399,7 @@ class MappedOperator(DAGNode):
return getattr(self, self._expand_input_attr)
# TODO (GH-52141): Copied from sdk. Find a better place for this to live
in.
- def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
+ def iter_mapped_task_groups(self) -> Iterator[SerializedMappedTaskGroup]:
"""
Return mapped task groups this task belongs to.
@@ -401,7 +412,7 @@ class MappedOperator(DAGNode):
yield from group.iter_mapped_task_groups()
# TODO (GH-52141): Copied from sdk. Find a better place for this to live
in.
- def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
+ def get_closest_mapped_task_group(self) -> SerializedMappedTaskGroup |
None:
"""
Get the mapped task group "closest" to this task in the DAG.
@@ -504,7 +515,7 @@ def _(task: MappedOperator | TaskSDKMappedOperator, run_id:
str, *, session: Ses
@get_mapped_ti_count.register
-def _(group: TaskGroup, run_id: str, *, session: Session) -> int:
+def _(group: SerializedTaskGroup, run_id: str, *, session: Session) -> int:
"""
Return the number of instances a task in this group should be mapped to at
run time.
@@ -523,7 +534,7 @@ def _(group: TaskGroup, run_id: str, *, session: Session)
-> int:
def iter_mapped_task_group_lengths(group) -> Iterator[int]:
while group is not None:
- if isinstance(group, MappedTaskGroup):
+ if isinstance(group, SerializedMappedTaskGroup):
exp_input = group._expand_input
# TODO (GH-52141): 'group' here should be scheduler-bound and
returns scheduler expand input.
if not hasattr(exp_input, "get_total_map_length"):
diff --git a/airflow-core/src/airflow/models/taskinstance.py
b/airflow-core/src/airflow/models/taskinstance.py
index 682787141d2..05cd9ce357e 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -124,8 +124,8 @@ if TYPE_CHECKING:
from airflow.sdk import DAG
from airflow.sdk.api.datamodels._generated import AssetProfile
from airflow.sdk.definitions.asset import AssetNameRef, AssetUniqueKey,
AssetUriRef
- from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
from airflow.sdk.types import RuntimeTaskInstanceProtocol
+ from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.utils.context import Context
@@ -1534,12 +1534,9 @@ class TaskInstance(Base, LoggingMixin):
assert original_task is not None
assert original_task.dag is not None
- serialized_task = SerializedDAG.deserialize_dag(
- SerializedDAG.serialize_dag(original_task.dag)
- ).task_dict[original_task.task_id]
- # TODO (GH-52141): task_dict in scheduler should contain scheduler
- # types instead, but currently it inherits SDK's DAG.
- self.task = cast("Operator", serialized_task)
+ self.task =
SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(original_task.dag)).task_dict[
+ original_task.task_id
+ ]
res = self.check_and_change_state_before_execution(
verbose=verbose,
ignore_all_deps=ignore_all_deps,
@@ -2286,7 +2283,7 @@ class TaskInstance(Base, LoggingMixin):
)
-def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) ->
MappedTaskGroup | None:
+def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) ->
SerializedTaskGroup | None:
"""Given two operators, find their innermost common mapped task group."""
if node1.dag is None or node2.dag is None or node1.dag_id != node2.dag_id:
return None
@@ -2295,16 +2292,15 @@ def _find_common_ancestor_mapped_group(node1: Operator,
node2: Operator) -> Mapp
return next(common_groups, None)
-def _is_further_mapped_inside(operator: Operator, container: TaskGroup) ->
bool:
+def _is_further_mapped_inside(operator: Operator, container:
SerializedTaskGroup) -> bool:
"""Whether given operator is *further* mapped inside a task group."""
- from airflow.models.mappedoperator import MappedOperator
- from airflow.sdk.definitions.taskgroup import MappedTaskGroup
+ from airflow.models.mappedoperator import is_mapped
- if isinstance(operator, MappedOperator):
+ if is_mapped(operator):
return True
task_group = operator.task_group
while task_group is not None and task_group.group_id != container.group_id:
- if isinstance(task_group, MappedTaskGroup):
+ if is_mapped(task_group):
return True
task_group = task_group.parent_group
return False
diff --git a/airflow-core/src/airflow/serialization/definitions/__init__.py
b/airflow-core/src/airflow/serialization/definitions/__init__.py
new file mode 100644
index 00000000000..217e5db9607
--- /dev/null
+++ b/airflow-core/src/airflow/serialization/definitions/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow-core/src/airflow/serialization/definitions/taskgroup.py
b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
new file mode 100644
index 00000000000..e26c6cfb4ae
--- /dev/null
+++ b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
@@ -0,0 +1,284 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import copy
+import functools
+import operator
+import weakref
+from typing import TYPE_CHECKING
+
+import attrs
+import methodtools
+
+from airflow.sdk.definitions._internal.node import DAGNode
+
+if TYPE_CHECKING:
+ from collections.abc import Generator, Iterator
+ from typing import Any, ClassVar
+
+ from airflow.models.expandinput import SchedulerExpandInput
+ from airflow.serialization.serialized_objects import SerializedDAG,
SerializedOperator
+
+
[email protected](kw_only=True, repr=False)
+class SerializedTaskGroup(DAGNode):
+ """Serialized representation of a TaskGroup used in protected processes."""
+
+ _group_id: str | None = attrs.field(alias="group_id")
+ group_display_name: str | None = attrs.field()
+ prefix_group_id: bool = attrs.field()
+ parent_group: SerializedTaskGroup | None = attrs.field()
+ dag: SerializedDAG = attrs.field()
+ tooltip: str = attrs.field()
+ default_args: dict[str, Any] = attrs.field(factory=dict)
+
+ # TODO: Are these actually useful?
+ ui_color: str = attrs.field(default="CornflowerBlue")
+ ui_fgcolor: str = attrs.field(default="#000")
+
+ children: dict[str, DAGNode] = attrs.field(factory=dict, init=False)
+ upstream_group_ids: set[str | None] = attrs.field(factory=set, init=False)
+ downstream_group_ids: set[str | None] = attrs.field(factory=set,
init=False)
+ upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
+ downstream_task_ids: set[str] = attrs.field(factory=set, init=False)
+
+ is_mapped: ClassVar[bool] = False
+
+ @staticmethod
+ def _iter_child(child):
+ """Iterate over the children of this TaskGroup."""
+ if isinstance(child, SerializedTaskGroup):
+ yield from child
+ else:
+ yield child
+
+ def __iter__(self):
+ for child in self.children.values():
+ yield from self._iter_child(child)
+
+ @property
+ def group_id(self) -> str | None:
+ if (
+ self._group_id
+ and self.parent_group
+ and self.parent_group.prefix_group_id
+ and self.parent_group._group_id
+ ):
+ return self.parent_group.child_id(self._group_id)
+ return self._group_id
+
+ @property
+ def label(self) -> str:
+ """group_id excluding parent's group_id used as the node label in
UI."""
+ return self.group_display_name or self._group_id or ""
+
+ @property
+ def node_id(self) -> str:
+ return self.group_id or ""
+
+ @property
+ def is_root(self) -> bool:
+ return not self._group_id
+
+ # TODO (GH-52141): This shouldn't need to be writable after serialization,
+ # but DAGNode defines the property as writable.
+ @property
+ def task_group(self) -> SerializedTaskGroup | None: # type:
ignore[override]
+ return self.parent_group
+
+ def child_id(self, label: str) -> str:
+ if self.prefix_group_id and (group_id := self.group_id):
+ return f"{group_id}.{label}"
+ return label
+
+ @property
+ def upstream_join_id(self) -> str:
+ return f"{self.group_id}.upstream_join_id"
+
+ @property
+ def downstream_join_id(self) -> str:
+ return f"{self.group_id}.downstream_join_id"
+
+ @property
+ def roots(self) -> list[DAGNode]:
+ return list(self.get_roots())
+
+ @property
+ def leaves(self) -> list[DAGNode]:
+ return list(self.get_leaves())
+
+ def get_roots(self) -> Generator[SerializedOperator, None, None]:
+ """Return a generator of tasks with no upstream dependencies within
the TaskGroup."""
+ tasks = list(self)
+ ids = {x.task_id for x in tasks}
+ for task in tasks:
+ if task.upstream_task_ids.isdisjoint(ids):
+ yield task
+
+ def get_leaves(self) -> Generator[SerializedOperator, None, None]:
+ """Return a generator of tasks with no downstream dependencies within
the TaskGroup."""
+ tasks = list(self)
+ ids = {x.task_id for x in tasks}
+
+ def has_non_teardown_downstream(task, exclude: str):
+ for down_task in task.downstream_list:
+ if down_task.task_id == exclude:
+ continue
+ if down_task.task_id not in ids:
+ continue
+ if not down_task.is_teardown:
+ return True
+ return False
+
+ def recurse_for_first_non_teardown(task):
+ for upstream_task in task.upstream_list:
+ if upstream_task.task_id not in ids:
+ # upstream task is not in task group
+ continue
+ elif upstream_task.is_teardown:
+ yield from recurse_for_first_non_teardown(upstream_task)
+ elif task.is_teardown and upstream_task.is_setup:
+ # don't go through the teardown-to-setup path
+ continue
+ # return unless upstream task already has non-teardown
downstream in group
+ elif not has_non_teardown_downstream(upstream_task,
exclude=task.task_id):
+ yield upstream_task
+
+ for task in tasks:
+ if task.downstream_task_ids.isdisjoint(ids):
+ if not task.is_teardown:
+ yield task
+ else:
+ yield from recurse_for_first_non_teardown(task)
+
+ def get_task_group_dict(self) -> dict[str | None, SerializedTaskGroup]:
+ """Create a flat dict of group_id: TaskGroup."""
+
+ def build_map(node: DAGNode) -> Generator[tuple[str | None,
SerializedTaskGroup]]:
+ if not isinstance(node, SerializedTaskGroup):
+ return
+ yield node.group_id, node
+ for child in node.children.values():
+ yield from build_map(child)
+
+ return dict(build_map(self))
+
+ def iter_tasks(self) -> Iterator[SerializedOperator]:
+ """Return an iterator of the child tasks."""
+ from airflow.models.mappedoperator import MappedOperator
+ from airflow.serialization.serialized_objects import
SerializedBaseOperator
+
+ groups_to_visit = [self]
+ while groups_to_visit:
+ for child in groups_to_visit.pop(0).children.values():
+ if isinstance(child, (MappedOperator, SerializedBaseOperator)):
+ yield child
+ elif isinstance(child, SerializedTaskGroup):
+ groups_to_visit.append(child)
+ else:
+ raise ValueError(
+ f"Encountered a DAGNode that is not a task or task "
+ f"group: {type(child).__module__}.{type(child)}"
+ )
+
+ def iter_mapped_task_groups(self) -> Iterator[SerializedMappedTaskGroup]:
+ """
+ Find mapped task groups in the hierarchy.
+
+ Groups are returned from the closest to the outermost. If *self* is a
+ mapped task group, it is returned first.
+ """
+ group: SerializedTaskGroup | None = self
+ while group is not None:
+ if isinstance(group, SerializedMappedTaskGroup):
+ yield group
+ group = group.parent_group
+
+ def topological_sort(self) -> list[DAGNode]:
+ """
+ Sort children in topological order.
+
+ A task in the result would come after any of its upstream dependencies.
+ """
+ # This uses a modified version of Kahn's Topological Sort algorithm to
+ # not have to pre-compute the "in-degree" of the nodes.
+ graph_unsorted = copy.copy(self.children)
+ graph_sorted: list[DAGNode] = []
+ if not self.children:
+ return graph_sorted
+ while graph_unsorted:
+ for node in list(graph_unsorted.values()):
+ for edge in node.upstream_list:
+ if edge.node_id in graph_unsorted:
+ break
+ # Check for task's group is a child (or grand child) of
this TG,
+ tg = edge.task_group
+ while tg:
+ if tg.node_id in graph_unsorted:
+ break
+ tg = tg.parent_group
+ else:
+ del graph_unsorted[node.node_id]
+ graph_sorted.append(node)
+ return graph_sorted
+
+ def add(self, node: DAGNode) -> DAGNode:
+ # Set the TG first, as setting it might change the return value of
node_id!
+ node.task_group = weakref.proxy(self)
+ if isinstance(node, SerializedTaskGroup):
+ if self.dag:
+ node.dag = self.dag
+ self.children[node.node_id] = node
+ return node
+
+
[email protected](kw_only=True, repr=False)
+class SerializedMappedTaskGroup(SerializedTaskGroup):
+ """Serialized representation of a MappedTaskGroup used in protected
processes."""
+
+ _expand_input: SchedulerExpandInput = attrs.field(alias="expand_input")
+
+ is_mapped: ClassVar[bool] = True
+
+ @methodtools.lru_cache(maxsize=None)
+ def get_parse_time_mapped_ti_count(self) -> int:
+ """
+ Return the number of instances a task in this group should be mapped
to.
+
+ This only considers literal mapped arguments, and would return *None*
+ when any non-literal values are used for mapping.
+
+ If this group is inside mapped task groups, all the nested counts are
+ multiplied and accounted.
+
+ :raise NotFullyPopulated: If any non-literal mapped arguments are
encountered.
+ :return: The total number of mapped instances each task should have.
+ """
+ return functools.reduce(
+ operator.mul,
+ (g._expand_input.get_parse_time_mapped_ti_count() for g in
self.iter_mapped_task_groups()),
+ )
+
+ def iter_mapped_dependencies(self) -> Iterator[SerializedOperator]:
+ """Upstream dependencies that provide XComs used by this mapped task
group."""
+ from airflow.models.xcom_arg import SchedulerXComArg
+
+ for op, _ in SchedulerXComArg.iter_xcom_references(self._expand_input):
+ yield op
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py
b/airflow-core/src/airflow/serialization/serialized_objects.py
index 24ab2e106f1..090f5c21432 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -33,7 +33,18 @@ from collections.abc import Collection, Iterable, Iterator,
Mapping, Sequence
from functools import cached_property, lru_cache
from inspect import signature
from textwrap import dedent
-from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple,
TypeAlias, TypeVar, cast, overload
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ ClassVar,
+ Literal,
+ NamedTuple,
+ TypeAlias,
+ TypeGuard,
+ TypeVar,
+ cast,
+ overload,
+)
import attrs
import lazy_object_proxy
@@ -76,6 +87,7 @@ from airflow.sdk.definitions.taskgroup import
MappedTaskGroup, TaskGroup
from airflow.sdk.definitions.xcom_arg import serialize_xcom_arg
from airflow.sdk.execution_time.context import OutletEventAccessor,
OutletEventAccessors
from airflow.serialization.dag_dependency import DagDependency
+from airflow.serialization.definitions.taskgroup import
SerializedMappedTaskGroup, SerializedTaskGroup
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.helpers import serialize_template_field
from airflow.serialization.json_schema import load_dag_schema
@@ -1235,8 +1247,10 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
_task_display_name: str | None
_weight_rule: str | PriorityWeightStrategy = "downstream"
- dag: SerializedDAG | None = None
- task_group: TaskGroup | None = None
+ # TODO (GH-52141): These should contain serialized containers, but
currently
+ # this class inherits from an SDK one.
+ dag: SerializedDAG | None = None # type: ignore[assignment]
+ task_group: SerializedTaskGroup | None = None # type: ignore[assignment]
allow_nested_operators: bool = True
depends_on_past: bool = False
@@ -1664,7 +1678,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
setattr(op, "start_from_trigger",
bool(encoded_op.get("start_from_trigger", False)))
@staticmethod
- def set_task_dag_references(task: SerializedOperator, dag: SerializedDAG)
-> None:
+ def set_task_dag_references(task: SerializedOperator | MappedOperator,
dag: SerializedDAG) -> None:
"""
Handle DAG references on an operator.
@@ -2147,7 +2161,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
return result
- def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator |
MappedTaskGroup]:
+ def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator |
SerializedMappedTaskGroup]:
"""
Return mapped nodes that are direct dependencies of the current task.
@@ -2164,7 +2178,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
:meth:`iter_mapped_dependants` instead.
"""
- def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]:
+ def _walk_group(group: SerializedTaskGroup) -> Iterable[tuple[str,
DAGNode]]:
"""
Recursively walk children in a task group.
@@ -2173,7 +2187,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
"""
for key, child in group.children.items():
yield key, child
- if isinstance(child, TaskGroup):
+ if isinstance(child, SerializedTaskGroup):
yield from _walk_group(child)
if not (dag := self.dag):
@@ -2181,12 +2195,12 @@ class SerializedBaseOperator(DAGNode,
BaseSerialization):
for key, child in _walk_group(dag.task_group):
if key == self.node_id:
continue
- if not isinstance(child, MappedOperator | MappedTaskGroup):
+ if not isinstance(child, MappedOperator |
SerializedMappedTaskGroup):
continue
if self.node_id in child.upstream_task_ids:
yield child
- def iter_mapped_dependants(self) -> Iterator[MappedOperator |
MappedTaskGroup]:
+ def iter_mapped_dependants(self) -> Iterator[MappedOperator |
SerializedMappedTaskGroup]:
"""
Return mapped nodes that depend on the current task the expansion.
@@ -2202,7 +2216,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
)
# TODO (GH-52141): Copied from sdk. Find a better place for this to live
in.
- def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
+ def iter_mapped_task_groups(self) -> Iterator[SerializedMappedTaskGroup]:
"""
Return mapped task groups this task belongs to.
@@ -2215,7 +2229,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
yield from group.iter_mapped_task_groups()
# TODO (GH-52141): Copied from sdk. Find a better place for this to live
in.
- def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
+ def get_closest_mapped_task_group(self) -> SerializedMappedTaskGroup |
None:
"""
Get the mapped task group "closest" to this task in the DAG.
@@ -2310,7 +2324,6 @@ def _create_orm_dagrun(
return run
[email protected](hash=False, repr=False, eq=False, slots=False)
class SerializedDAG(DAG, BaseSerialization):
"""
A JSON serializable representation of DAG.
@@ -2322,10 +2335,15 @@ class SerializedDAG(DAG, BaseSerialization):
_decorated_fields: ClassVar[set[str]] = {"default_args", "access_control"}
- last_loaded: datetime.datetime | None = attrs.field(init=False,
factory=utcnow)
+ # TODO (GH-52141): These should contain serialized containers, but
currently
+ # this class inherits from an SDK one.
+ task_group: SerializedTaskGroup # type: ignore[assignment]
+ task_dict: dict[str, SerializedBaseOperator | SerializedMappedOperator] #
type: ignore[assignment]
+
+ last_loaded: datetime.datetime
# this will only be set at serialization time
# it's only use is for determining the relative fileloc based only on the
serialize dag
- _processor_dags_folder: str = attrs.field(init=False)
+ _processor_dags_folder: str
@staticmethod
def __get_constructor_defaults():
@@ -2404,6 +2422,7 @@ class SerializedDAG(DAG, BaseSerialization):
) -> SerializedDAG:
"""Handle the main Dag deserialization logic."""
dag = SerializedDAG(dag_id=encoded_dag["dag_id"], schedule=None)
+ dag.last_loaded = utcnow()
# Note: Context is passed explicitly through method parameters, no
class attributes needed
@@ -2449,18 +2468,24 @@ class SerializedDAG(DAG, BaseSerialization):
tg = TaskGroupSerialization.deserialize_task_group(
encoded_dag["task_group"],
None,
- # TODO (GH-52141): SerializedDAG's task_dict should contain
- # scheduler types instead, but currently it inherits SDK's DAG.
- cast("dict[str, SerializedOperator]", dag.task_dict),
+ dag.task_dict,
dag,
)
object.__setattr__(dag, "task_group", tg)
else:
- # This must be old data that had no task_group. Create a root
TaskGroup and add
- # all tasks to it.
- object.__setattr__(dag, "task_group", TaskGroup.create_root(dag))
+ # This must be old data that had no task_group. Create a root
+ # task group and add all tasks to it.
+ tg = SerializedTaskGroup(
+ group_id=None,
+ group_display_name=None,
+ prefix_group_id=True,
+ parent_group=None,
+ dag=dag,
+ tooltip="",
+ )
+ object.__setattr__(dag, "task_group", tg)
for task in dag.tasks:
- dag.task_group.add(task)
+ tg.add(task)
# Set has_on_*_callbacks to True if they exist in Serialized blob as
False is the default
if "has_on_success_callback" in encoded_dag:
@@ -2475,10 +2500,8 @@ class SerializedDAG(DAG, BaseSerialization):
for k in keys_to_set_none:
setattr(dag, k, None)
- # TODO (GH-52141): SerializedDAG's task_dict should contain scheduler
- # types instead, but currently it inherits SDK's DAG.
- for task in dag.task_dict.values():
-
SerializedBaseOperator.set_task_dag_references(cast("SerializedOperator",
task), dag)
+ for t in dag.task_dict.values():
+ SerializedBaseOperator.set_task_dag_references(t, dag)
return dag
@@ -2705,6 +2728,125 @@ class SerializedDAG(DAG, BaseSerialization):
dag_op.update_dag_asset_expression(orm_dags=orm_dags,
orm_assets=orm_assets)
session.flush()
+ # TODO (GH-52141): This needs to take scheduler types, but currently it
inherits SDK's DAG.
+ # TODO (GH-52141): This shouldn't need to be writable, but SDK's DAG
defines it as such.
+ @property # type: ignore[misc]
+ def tasks(self) -> Sequence[SerializedOperator]: # type: ignore[override]
+ return list(self.task_dict.values())
+
+ def partial_subset(
+ self,
+ task_ids: str | Iterable[str],
+ include_downstream: bool = False,
+ include_upstream: bool = True,
+ include_direct_upstream: bool = False,
+ ):
+ from airflow.models.mappedoperator import MappedOperator as
SerializedMappedOperator
+
+ def is_task(obj) -> TypeGuard[SerializedOperator]:
+ return isinstance(obj, (SerializedMappedOperator,
SerializedBaseOperator))
+
+ # deep-copying self.task_dict and self.task_group takes a long time,
and we don't want all
+ # the tasks anyway, so we copy the tasks manually later
+ memo = {id(self.task_dict): None, id(self.task_group): None}
+ dag = copy.deepcopy(self, memo)
+
+ if isinstance(task_ids, str):
+ matched_tasks = [t for t in self.tasks if task_ids in t.task_id]
+ else:
+ matched_tasks = [t for t in self.tasks if t.task_id in task_ids]
+
+ also_include_ids: set[str] = set()
+ for t in matched_tasks:
+ if include_downstream:
+ for rel in t.get_flat_relatives(upstream=False):
+ also_include_ids.add(rel.task_id)
+ if rel not in matched_tasks: # if it's in there, we're
already processing it
+ # need to include setups and teardowns for tasks that
are in multiple
+ # non-collinear setup/teardown paths
+ if not rel.is_setup and not rel.is_teardown:
+ also_include_ids.update(
+ x.task_id for x in
rel.get_upstreams_only_setups_and_teardowns()
+ )
+ if include_upstream:
+ also_include_ids.update(x.task_id for x in
t.get_upstreams_follow_setups())
+ else:
+ if not t.is_setup and not t.is_teardown:
+ also_include_ids.update(x.task_id for x in
t.get_upstreams_only_setups_and_teardowns())
+ if t.is_setup and not include_downstream:
+ also_include_ids.update(x.task_id for x in t.downstream_list
if x.is_teardown)
+
+ also_include: list[SerializedOperator] = [self.task_dict[x] for x in
also_include_ids]
+ direct_upstreams: list[SerializedOperator] = []
+ if include_direct_upstream:
+ for t in itertools.chain(matched_tasks, also_include):
+ upstream = (u for u in t.upstream_list if is_task(u))
+ direct_upstreams.extend(upstream)
+
+ # Make sure to not recursively deepcopy the dag or task_group while
copying the task.
+ # task_group is reset later
+ def _deepcopy_task(t) -> SerializedOperator:
+ memo.setdefault(id(t.task_group), None)
+ return copy.deepcopy(t, memo)
+
+ # Compiling the unique list of tasks that made the cut
+ dag.task_dict = {
+ t.task_id: _deepcopy_task(t)
+ for t in itertools.chain(matched_tasks, also_include,
direct_upstreams)
+ }
+
+ def filter_task_group(group, parent_group):
+ """Exclude tasks not included in the partial dag from the given
TaskGroup."""
+ # We want to deepcopy _most but not all_ attributes of the task
group, so we create a shallow copy
+ # and then manually deep copy the instances. (memo argument to
deepcopy only works for instances
+ # of classes, not "native" properties of an instance)
+ copied = copy.copy(group)
+
+ memo[id(group.children)] = {}
+ if parent_group:
+ memo[id(group.parent_group)] = parent_group
+ for attr in type(group).__slots__:
+ value = getattr(group, attr)
+ value = copy.deepcopy(value, memo)
+ object.__setattr__(copied, attr, value)
+
+ proxy = weakref.proxy(copied)
+
+ for child in group.children.values():
+ if is_task(child):
+ if child.task_id in dag.task_dict:
+ task = copied.children[child.task_id] =
dag.task_dict[child.task_id]
+ task.task_group = proxy
+ else:
+ filtered_child = filter_task_group(child, proxy)
+
+ # Only include this child TaskGroup if it is non-empty.
+ if filtered_child.children:
+ copied.children[child.group_id] = filtered_child
+
+ return copied
+
+ object.__setattr__(dag, "task_group",
filter_task_group(self.task_group, None))
+
+ # Removing upstream/downstream references to tasks and TaskGroups that
did not make
+ # the cut.
+ groups = dag.task_group.get_task_group_dict()
+ for g in groups.values():
+ g.upstream_group_ids.intersection_update(groups)
+ g.downstream_group_ids.intersection_update(groups)
+ g.upstream_task_ids.intersection_update(dag.task_dict)
+ g.downstream_task_ids.intersection_update(dag.task_dict)
+
+ for t in dag.tasks:
+ # Removing upstream/downstream references to tasks that did not
+ # make the cut
+ t.upstream_task_ids.intersection_update(dag.task_dict)
+ t.downstream_task_ids.intersection_update(dag.task_dict)
+
+ dag.partial = len(dag.tasks) < len(self.tasks)
+
+ return dag
+
@cached_property
def _time_restriction(self) -> TimeRestriction:
start_dates = [t.start_date for t in self.tasks if t.start_date]
@@ -3416,10 +3558,10 @@ class TaskGroupSerialization(BaseSerialization):
def deserialize_task_group(
cls,
encoded_group: dict[str, Any],
- parent_group: TaskGroup | None,
+ parent_group: SerializedTaskGroup | None,
task_dict: dict[str, SerializedOperator],
dag: SerializedDAG,
- ) -> TaskGroup:
+ ) -> SerializedTaskGroup:
"""Deserializes a TaskGroup from a JSON object."""
group_id = cls.deserialize(encoded_group["_group_id"])
kwargs = {
@@ -3429,10 +3571,10 @@ class TaskGroupSerialization(BaseSerialization):
kwargs["group_display_name"] =
cls.deserialize(encoded_group.get("group_display_name", ""))
if not encoded_group.get("is_mapped"):
- group = TaskGroup(group_id=group_id, parent_group=parent_group,
dag=dag, **kwargs)
+ group = SerializedTaskGroup(group_id=group_id,
parent_group=parent_group, dag=dag, **kwargs)
else:
xi = encoded_group["expand_input"]
- group = MappedTaskGroup(
+ group = SerializedMappedTaskGroup(
group_id=group_id,
parent_group=parent_group,
dag=dag,
diff --git a/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py
b/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py
index d5922074ef5..5e00d4b7b1b 100644
--- a/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py
@@ -18,7 +18,7 @@
from __future__ import annotations
from collections.abc import Iterator
-from typing import TYPE_CHECKING, TypeAlias, cast
+from typing import TYPE_CHECKING, TypeAlias
from sqlalchemy import select
@@ -63,9 +63,7 @@ class MappedTaskUpstreamDep(BaseTIDep):
elif is_mapped(ti.task):
mapped_dependencies = ti.task.iter_mapped_dependencies()
elif (task_group := ti.task.get_closest_mapped_task_group()) is not
None:
- # TODO (GH-52141): Task group in scheduler needs to return
scheduler
- # types instead, but currently the scheduler uses SDK's TaskGroup.
- mapped_dependencies = cast("Iterator[Operator]",
task_group.iter_mapped_dependencies())
+ mapped_dependencies = task_group.iter_mapped_dependencies()
else:
return
diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
index 1298518b96e..971b156a067 100644
--- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -21,12 +21,11 @@ import collections.abc
import functools
from collections import Counter
from collections.abc import Iterator, KeysView
-from typing import TYPE_CHECKING, NamedTuple, cast
+from typing import TYPE_CHECKING, NamedTuple
from sqlalchemy import and_, func, or_, select
from airflow.models.taskinstance import PAST_DEPENDS_MET
-from airflow.sdk.definitions.taskgroup import MappedTaskGroup
from airflow.task.trigger_rule import TriggerRule as TR
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.state import TaskInstanceState
@@ -35,10 +34,8 @@ if TYPE_CHECKING:
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnOperators
- from airflow import DAG
- from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance
- from airflow.serialization.serialized_objects import SerializedBaseOperator
+ from airflow.serialization.definitions.taskgroup import
SerializedMappedTaskGroup
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.deps.base_ti_dep import TIDepStatus
@@ -131,6 +128,7 @@ class TriggerRuleDep(BaseTIDep):
"""
from airflow.exceptions import NotMapped
from airflow.models.expandinput import NotFullyPopulated
+ from airflow.models.mappedoperator import is_mapped
from airflow.models.taskinstance import TaskInstance
@functools.lru_cache
@@ -148,9 +146,7 @@ class TriggerRuleDep(BaseTIDep):
return get_mapped_ti_count(ti.task, ti.run_id, session=session)
- def _iter_expansion_dependencies(task_group: MappedTaskGroup) ->
Iterator[str]:
- from airflow.models.mappedoperator import is_mapped
-
+ def _iter_expansion_dependencies(task_group:
SerializedMappedTaskGroup) -> Iterator[str]:
if (task := ti.task) is not None and is_mapped(task):
for op in task.iter_mapped_dependencies():
yield op.task_id
@@ -172,9 +168,10 @@ class TriggerRuleDep(BaseTIDep):
"""
if TYPE_CHECKING:
assert ti.task
- assert isinstance(ti.task.dag, DAG)
+ assert ti.task.dag
+ assert ti.task.task_group
- if isinstance(ti.task.task_group, MappedTaskGroup):
+ if is_mapped(ti.task.task_group):
is_fast_triggered = ti.task.trigger_rule in (TR.ONE_SUCCESS,
TR.ONE_FAILED, TR.ONE_DONE)
if is_fast_triggered and upstream_id not in set(
_iter_expansion_dependencies(task_group=ti.task.task_group)
@@ -186,9 +183,7 @@ class TriggerRuleDep(BaseTIDep):
except (NotFullyPopulated, NotMapped):
return None
return ti.get_relevant_upstream_map_indexes(
- # TODO (GH-52141): task_dict in scheduler should contain
- # scheduler types instead, but currently it inherits SDK's DAG.
- upstream=cast("MappedOperator | SerializedBaseOperator",
ti.task.dag.task_dict[upstream_id]),
+ upstream=ti.task.dag.task_dict[upstream_id],
ti_count=expanded_ti_count,
session=session,
)
diff --git a/airflow-core/src/airflow/utils/dot_renderer.py
b/airflow-core/src/airflow/utils/dot_renderer.py
index 66b83492269..259b4ced252 100644
--- a/airflow-core/src/airflow/utils/dot_renderer.py
+++ b/airflow-core/src/airflow/utils/dot_renderer.py
@@ -24,9 +24,10 @@ import warnings
from typing import TYPE_CHECKING, Any
from airflow.exceptions import AirflowException
-from airflow.sdk import DAG, BaseOperator
+from airflow.models.mappedoperator import MappedOperator as
SerializedMappedOperator
+from airflow.sdk import DAG, BaseOperator, TaskGroup
from airflow.sdk.definitions.mappedoperator import MappedOperator
-from airflow.sdk.definitions.taskgroup import TaskGroup
+from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.utils.dag_edges import dag_edges
from airflow.utils.state import State
@@ -69,7 +70,7 @@ def _refine_color(color: str):
def _draw_task(
- task: BaseOperator | MappedOperator | SerializedBaseOperator,
+ task: BaseOperator | MappedOperator | SerializedBaseOperator |
SerializedMappedOperator,
parent_graph: graphviz.Digraph,
states_by_task_id: dict[Any, Any] | None,
) -> None:
@@ -95,7 +96,9 @@ def _draw_task(
def _draw_task_group(
- task_group: TaskGroup, parent_graph: graphviz.Digraph, states_by_task_id:
dict[str, str] | None
+ task_group: TaskGroup | SerializedTaskGroup,
+ parent_graph: graphviz.Digraph,
+ states_by_task_id: dict[str, str] | None,
) -> None:
"""Draw the given task_group and its children on the given parent_graph."""
# Draw joins
@@ -136,10 +139,10 @@ def _draw_nodes(
node: DependencyMixin, parent_graph: graphviz.Digraph, states_by_task_id:
dict[str, str] | None
) -> None:
"""Draw the node and its children on the given parent_graph recursively."""
- if isinstance(node, (BaseOperator, MappedOperator,
SerializedBaseOperator)):
+ if isinstance(node, (BaseOperator, MappedOperator, SerializedBaseOperator,
SerializedMappedOperator)):
_draw_task(node, parent_graph, states_by_task_id)
else:
- if not isinstance(node, TaskGroup):
+ if not isinstance(node, (SerializedTaskGroup, TaskGroup)):
raise AirflowException(f"The node {node} should be TaskGroup and
is not")
# Draw TaskGroup
if node.is_root:
diff --git a/airflow-core/tests/unit/models/test_dagrun.py
b/airflow-core/tests/unit/models/test_dagrun.py
index 8d92eab5aa0..369f530077b 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -240,14 +240,11 @@ class TestDagRun:
schedule=datetime.timedelta(days=1),
start_date=timezone.datetime(2017, 1, 1),
) as dag:
- ...
- dag_task1 = ShortCircuitOperator(
- task_id="test_short_circuit_false", dag=dag,
python_callable=lambda: False
- )
- dag_task2 = EmptyOperator(task_id="test_state_skipped1", dag=dag)
- dag_task3 = EmptyOperator(task_id="test_state_skipped2", dag=dag)
- dag_task1.set_downstream(dag_task2)
- dag_task2.set_downstream(dag_task3)
+ dag_task1 =
ShortCircuitOperator(task_id="test_short_circuit_false", python_callable=bool)
+ dag_task2 = EmptyOperator(task_id="test_state_skipped1")
+ dag_task3 = EmptyOperator(task_id="test_state_skipped2")
+ dag_task1.set_downstream(dag_task2)
+ dag_task2.set_downstream(dag_task3)
initial_task_states = {
"test_short_circuit_false": TaskInstanceState.SUCCESS,
@@ -268,14 +265,11 @@ class TestDagRun:
schedule=datetime.timedelta(days=1),
start_date=timezone.datetime(2017, 1, 1),
) as dag:
- ...
- dag_task1 = ShortCircuitOperator(
- task_id="test_short_circuit_false", dag=dag,
python_callable=lambda: False
- )
- dag_task2 = EmptyOperator(task_id="test_state_skipped1", dag=dag)
- dag_task3 = EmptyOperator(task_id="test_state_skipped2", dag=dag)
- dag_task1.set_downstream(dag_task2)
- dag_task2.set_downstream(dag_task3)
+ dag_task1 =
ShortCircuitOperator(task_id="test_short_circuit_false", python_callable=bool)
+ dag_task2 = EmptyOperator(task_id="test_state_skipped1")
+ dag_task3 = EmptyOperator(task_id="test_state_skipped2")
+ dag_task1.set_downstream(dag_task2)
+ dag_task2.set_downstream(dag_task3)
initial_task_states = {
"test_short_circuit_false": TaskInstanceState.REMOVED,
@@ -397,19 +391,15 @@ class TestDagRun:
start_date=datetime.datetime(2017, 1, 1),
on_success_callback=on_success_callable,
) as dag:
- ...
- dag_task1 = EmptyOperator(task_id="test_state_succeeded1", dag=dag)
- dag_task2 = EmptyOperator(task_id="test_state_succeeded2", dag=dag)
- dag_task1.set_downstream(dag_task2)
+ dag_task1 = EmptyOperator(task_id="test_state_succeeded1")
+ dag_task2 = EmptyOperator(task_id="test_state_succeeded2")
+ dag_task1.set_downstream(dag_task2)
initial_task_states = {
"test_state_succeeded1": TaskInstanceState.SUCCESS,
"test_state_succeeded2": TaskInstanceState.SUCCESS,
}
- # Scheduler uses Serialized DAG -- so use that instead of the Actual
DAG
- dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
dag_run = self.create_dag_run(dag=dag,
task_states=initial_task_states, session=session)
_, callback = dag_run.update_state()
assert dag_run.state == DagRunState.SUCCESS
@@ -426,9 +416,8 @@ class TestDagRun:
start_date=datetime.datetime(2017, 1, 1),
on_failure_callback=on_failure_callable,
) as dag:
- ...
- dag_task1 = EmptyOperator(task_id="test_state_succeeded1", dag=dag)
- dag_task2 = EmptyOperator(task_id="test_state_failed2", dag=dag)
+ dag_task1 = EmptyOperator(task_id="test_state_succeeded1")
+ dag_task2 = EmptyOperator(task_id="test_state_failed2")
initial_task_states = {
"test_state_succeeded1": TaskInstanceState.SUCCESS,
@@ -436,9 +425,6 @@ class TestDagRun:
}
dag_task1.set_downstream(dag_task2)
- # Scheduler uses Serialized DAG -- so use that instead of the Actual
DAG
- dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
dag_run = self.create_dag_run(dag=dag,
task_states=initial_task_states, session=session)
_, callback = dag_run.update_state()
assert dag_run.state == DagRunState.FAILED
@@ -481,27 +467,21 @@ class TestDagRun:
assert dag_run.state == DagRunState.SUCCESS
mock_on_success.assert_called_once()
- def test_start_dr_spans_if_needed_new_span(self, testing_dag_bundle,
dag_maker, session):
+ def test_start_dr_spans_if_needed_new_span(self, dag_maker, session):
with dag_maker(
dag_id="test_start_dr_spans_if_needed_new_span",
schedule=datetime.timedelta(days=1),
start_date=datetime.datetime(2017, 1, 1),
) as dag:
- ...
- SerializedDAG.bulk_write_to_db("testing", None, dags=[dag],
session=session)
-
- dag_task1 = EmptyOperator(task_id="test_task1", dag=dag)
- dag_task2 = EmptyOperator(task_id="test_task2", dag=dag)
- dag_task1.set_downstream(dag_task2)
+ dag_task1 = EmptyOperator(task_id="test_task1")
+ dag_task2 = EmptyOperator(task_id="test_task2")
+ dag_task1.set_downstream(dag_task2)
initial_task_states = {
"test_task1": TaskInstanceState.QUEUED,
"test_task2": TaskInstanceState.QUEUED,
}
- # Scheduler uses Serialized DAG -- so use that instead of the Actual
DAG
- dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
dag_run = self.create_dag_run(dag=dag,
task_states=initial_task_states, session=session)
active_spans = ThreadSafeDict()
@@ -518,27 +498,21 @@ class TestDagRun:
assert dag_run.span_status == SpanStatus.ACTIVE
assert dag_run.active_spans.get("dr:" + str(dag_run.id)) is not None
- def test_start_dr_spans_if_needed_span_with_continuance(self,
testing_dag_bundle, dag_maker, session):
+ def test_start_dr_spans_if_needed_span_with_continuance(self, dag_maker,
session):
with dag_maker(
dag_id="test_start_dr_spans_if_needed_span_with_continuance",
schedule=datetime.timedelta(days=1),
start_date=datetime.datetime(2017, 1, 1),
) as dag:
- ...
- SerializedDAG.bulk_write_to_db("testing", None, dags=[dag],
session=session)
-
- dag_task1 = EmptyOperator(task_id="test_task1", dag=dag)
- dag_task2 = EmptyOperator(task_id="test_task2", dag=dag)
- dag_task1.set_downstream(dag_task2)
+ dag_task1 = EmptyOperator(task_id="test_task1")
+ dag_task2 = EmptyOperator(task_id="test_task2")
+ dag_task1.set_downstream(dag_task2)
initial_task_states = {
"test_task1": TaskInstanceState.RUNNING,
"test_task2": TaskInstanceState.QUEUED,
}
- # Scheduler uses Serialized DAG -- so use that instead of the Actual
DAG
- dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
dag_run = self.create_dag_run(dag=dag,
task_states=initial_task_states, session=session)
active_spans = ThreadSafeDict()
@@ -570,21 +544,15 @@ class TestDagRun:
schedule=datetime.timedelta(days=1),
start_date=datetime.datetime(2017, 1, 1),
) as dag:
- ...
- SerializedDAG.bulk_write_to_db("testing", None, dags=[dag],
session=session)
-
- dag_task1 = EmptyOperator(task_id="test_task1", dag=dag)
- dag_task2 = EmptyOperator(task_id="test_task2", dag=dag)
- dag_task1.set_downstream(dag_task2)
+ dag_task1 = EmptyOperator(task_id="test_task1")
+ dag_task2 = EmptyOperator(task_id="test_task2")
+ dag_task1.set_downstream(dag_task2)
initial_task_states = {
"test_task1": TaskInstanceState.SUCCESS,
"test_task2": TaskInstanceState.SUCCESS,
}
- # Scheduler uses Serialized DAG -- so use that instead of the Actual
DAG
- dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
dag_run = self.create_dag_run(dag=dag,
task_states=initial_task_states, session=session)
active_spans = ThreadSafeDict()
@@ -612,21 +580,15 @@ class TestDagRun:
schedule=datetime.timedelta(days=1),
start_date=datetime.datetime(2017, 1, 1),
) as dag:
- ...
- SerializedDAG.bulk_write_to_db("testing", None, dags=[dag],
session=session)
-
- dag_task1 = EmptyOperator(task_id="test_task1", dag=dag)
- dag_task2 = EmptyOperator(task_id="test_task2", dag=dag)
- dag_task1.set_downstream(dag_task2)
+ dag_task1 = EmptyOperator(task_id="test_task1")
+ dag_task2 = EmptyOperator(task_id="test_task2")
+ dag_task1.set_downstream(dag_task2)
initial_task_states = {
"test_task1": TaskInstanceState.SUCCESS,
"test_task2": TaskInstanceState.SUCCESS,
}
- # Scheduler uses Serialized DAG -- so use that instead of the Actual
DAG
- dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
-
dag_run = self.create_dag_run(dag=dag,
task_states=initial_task_states, session=session)
active_spans = ThreadSafeDict()
@@ -652,23 +614,18 @@ class TestDagRun:
start_date=datetime.datetime(2017, 1, 1),
on_success_callback=on_success_callable,
) as dag:
- ...
+ dag_task1 = EmptyOperator(task_id="test_state_succeeded1")
+ dag_task2 = EmptyOperator(task_id="test_state_succeeded2")
+ dag_task1.set_downstream(dag_task2)
dm = DagModel.get_dagmodel(dag.dag_id, session=session)
dm.relative_fileloc = relative_fileloc
session.merge(dm)
session.commit()
- dag_task1 = EmptyOperator(task_id="test_state_succeeded1", dag=dag)
- dag_task2 = EmptyOperator(task_id="test_state_succeeded2", dag=dag)
- dag_task1.set_downstream(dag_task2)
-
initial_task_states = {
"test_state_succeeded1": TaskInstanceState.SUCCESS,
"test_state_succeeded2": TaskInstanceState.SUCCESS,
}
-
- # Scheduler uses Serialized DAG -- so use that instead of the Actual
DAG
- dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
dag.relative_fileloc = relative_fileloc
SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag),
bundle_name="dag_maker")
session.commit()
@@ -704,23 +661,18 @@ class TestDagRun:
start_date=datetime.datetime(2017, 1, 1),
on_failure_callback=on_failure_callable,
) as dag:
- ...
+ dag_task1 = EmptyOperator(task_id="test_state_succeeded1")
+ dag_task2 = EmptyOperator(task_id="test_state_failed2")
+ dag_task1.set_downstream(dag_task2)
dm = DagModel.get_dagmodel(dag.dag_id, session=session)
dm.relative_fileloc = relative_fileloc
session.merge(dm)
session.commit()
- dag_task1 = EmptyOperator(task_id="test_state_succeeded1", dag=dag)
- dag_task2 = EmptyOperator(task_id="test_state_failed2", dag=dag)
- dag_task1.set_downstream(dag_task2)
-
initial_task_states = {
"test_state_succeeded1": TaskInstanceState.SUCCESS,
"test_state_failed2": TaskInstanceState.FAILED,
}
-
- # Scheduler uses Serialized DAG -- so use that instead of the Actual
DAG
- dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
dag.relative_fileloc = relative_fileloc
SerializedDagModel.write_dag(LazyDeserializedDAG.from_dag(dag),
bundle_name="dag_maker")
session.commit()
@@ -873,34 +825,32 @@ class TestDagRun:
assert dagrun.logical_date == timezone.datetime(2015, 1, 2)
def test_removed_task_instances_can_be_restored(self, dag_maker, session):
- def with_all_tasks_removed(dag):
- with dag_maker(
- dag_id=dag.dag_id,
+ def create_dag():
+ return dag_maker(
+ dag_id="test_task_restoration",
schedule=datetime.timedelta(days=1),
- start_date=dag.start_date,
- ) as dag:
- pass
- return dag
+ start_date=DEFAULT_DATE,
+ )
- with dag_maker(
- "test_task_restoration",
- schedule=datetime.timedelta(days=1),
- start_date=DEFAULT_DATE,
- ) as ori_dag:
+ with create_dag() as dag:
EmptyOperator(task_id="flaky_task", owner="test")
- dagrun = self.create_dag_run(ori_dag, session=session)
+ dagrun = self.create_dag_run(dag, session=session)
flaky_ti = dagrun.get_task_instances()[0]
assert flaky_ti.task_id == "flaky_task"
assert flaky_ti.state is None
- dagrun.dag = with_all_tasks_removed(ori_dag)
- dag_version_id = DagVersion.get_latest_version(ori_dag.dag_id,
session=session).id
+ with create_dag() as dag:
+ pass
+
+ dagrun.dag = dag
+ dag_version_id = DagVersion.get_latest_version(dag.dag_id,
session=session).id
dagrun.verify_integrity(dag_version_id=dag_version_id)
flaky_ti.refresh_from_db()
assert flaky_ti.state is None
- dagrun.dag.add_task(ori_dag.task_dict["flaky_task"])
+ with create_dag() as dag:
+ EmptyOperator(task_id="flaky_task", owner="test")
dagrun.verify_integrity(dag_version_id=dag_version_id)
flaky_ti.refresh_from_db()
@@ -1211,9 +1161,8 @@ class TestDagRun:
with dag_maker(
dag_id="test_dagrun_states", schedule=datetime.timedelta(days=1),
start_date=DEFAULT_DATE
) as dag:
- ...
- dag_task_success = EmptyOperator(task_id="dummy", dag=dag)
- dag_task_failed = EmptyOperator(task_id="dummy2", dag=dag)
+ dag_task_success = EmptyOperator(task_id="dummy")
+ dag_task_failed = EmptyOperator(task_id="dummy2")
initial_task_states = {
dag_task_success.task_id: TaskInstanceState.SUCCESS,
@@ -1319,10 +1268,9 @@ class TestDagRun:
callback=AsyncCallback(empty_callback_for_deadline),
),
) as dag:
- ...
- dag_task1 = EmptyOperator(task_id="test_state_succeeded1", dag=dag)
- dag_task2 = EmptyOperator(task_id="test_state_succeeded2", dag=dag)
- dag_task1.set_downstream(dag_task2)
+ dag_task1 = EmptyOperator(task_id="test_state_succeeded1")
+ dag_task2 = EmptyOperator(task_id="test_state_succeeded2")
+ dag_task1.set_downstream(dag_task2)
initial_task_states = {
"test_state_succeeded1": TaskInstanceState.SUCCESS,
@@ -1330,7 +1278,6 @@ class TestDagRun:
}
# Scheduler uses Serialized DAG -- so use that instead of the Actual
DAG.
- dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
dag_run = self.create_dag_run(dag=dag,
task_states=initial_task_states, session=session)
dag_run = session.merge(dag_run)
dag_run.dag = dag
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py
b/airflow-core/tests/unit/models/test_taskinstance.py
index 094b8f362e8..0df775ca41e 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -151,6 +151,7 @@ class TestTaskInstance:
def teardown_method(self):
self.clean_db()
+ @pytest.mark.need_serialized_dag(False)
def test_set_task_dates(self, dag_maker):
"""
Test that tasks properly take start/end dates from DAGs
@@ -159,7 +160,6 @@ class TestTaskInstance:
pass
op1 = EmptyOperator(task_id="op_1")
-
assert op1.start_date is None
assert op1.end_date is None
@@ -190,6 +190,7 @@ class TestTaskInstance:
assert op3.start_date == DEFAULT_DATE + datetime.timedelta(days=1)
assert op3.end_date == DEFAULT_DATE + datetime.timedelta(days=9)
+ @pytest.mark.need_serialized_dag(False)
def test_set_dag(self, dag_maker):
"""
Test assigning Operators to Dags, including deferred assignment
@@ -2417,23 +2418,25 @@ class TestTaskInstance:
def test_handle_failure_fail_fast(self, dag_maker, session):
start_date = timezone.datetime(2016, 6, 1)
- clear_db_runs()
class CustomOp(BaseOperator):
def execute(self, context): ...
+ reg_states = [State.RUNNING, State.FAILED, State.QUEUED,
State.SCHEDULED, State.DEFERRED]
+
with dag_maker(
dag_id="test_handle_failure_fail_fast",
start_date=start_date,
schedule=None,
fail_fast=True,
- ) as dag:
- task1 = CustomOp(task_id="task1", trigger_rule="all_success")
-
- dag_maker.create_dagrun(run_type=DagRunType.MANUAL,
start_date=start_date)
+ ):
+ CustomOp(task_id="task1", trigger_rule="all_success")
+ for i, _ in enumerate(reg_states):
+ CustomOp(task_id=f"reg_Task{i}")
+ CustomOp(task_id="fail_Task")
logical_date = timezone.utcnow()
- dr = dag.create_dagrun(
+ dr = dag_maker.create_dagrun(
run_id="test_ff",
run_type=DagRunType.MANUAL,
logical_date=logical_date,
@@ -2445,31 +2448,23 @@ class TestTaskInstance:
)
dr.set_state(DagRunState.SUCCESS)
- ti1 = dr.get_task_instance(task1.task_id, session=session)
- ti1.task = task1
- ti1.state = State.SUCCESS
-
- states = [State.RUNNING, State.FAILED, State.QUEUED, State.SCHEDULED,
State.DEFERRED]
- tasks = []
- for i, state in enumerate(states):
- op = CustomOp(task_id=f"reg_Task{i}", dag=dag)
- ti = TI(task=op, run_id=dr.run_id,
dag_version_id=ti1.dag_version_id)
- ti.state = state
- session.add(ti)
- tasks.append(ti)
-
- fail_task = CustomOp(task_id="fail_Task", dag=dag)
- ti_ff = TI(task=fail_task, run_id=dr.run_id,
dag_version_id=ti1.dag_version_id)
- ti_ff.state = State.FAILED
- session.add(ti_ff)
- session.commit()
- ti_ff.handle_failure("test retry handling")
+ tis = {ti.task_id: ti for ti in dr.task_instances}
+ tis["task1"].state = State.SUCCESS
+ for i, state in enumerate(reg_states):
+ tis[f"reg_Task{i}"].state = state
+ tis["fail_Task"].state = State.FAILED
+ session.flush()
- assert ti1.state == State.SUCCESS
- assert ti_ff.state == State.FAILED
- exp_states = [State.FAILED, State.FAILED, State.SKIPPED,
State.SKIPPED, State.SKIPPED]
- for i in range(len(states)):
- assert tasks[i].state == exp_states[i]
+ tis["fail_Task"].handle_failure("test retry handling")
+ assert {task_id: ti.state for task_id, ti in tis.items()} == {
+ "task1": State.SUCCESS,
+ "fail_Task": State.FAILED,
+ "reg_Task0": State.FAILED,
+ "reg_Task1": State.FAILED,
+ "reg_Task2": State.SKIPPED,
+ "reg_Task3": State.SKIPPED,
+ "reg_Task4": State.SKIPPED,
+ }
def test_does_not_retry_on_airflow_fail_exception(self, dag_maker):
def fail():
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py
b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index 6d71851dc87..96b26f69335 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -2914,7 +2914,7 @@ def test_taskflow_expand_kwargs_serde(strict):
def test_mapped_task_group_serde():
from airflow.models.expandinput import SchedulerDictOfListsExpandInput
from airflow.sdk.definitions.decorators.task_group import task_group
- from airflow.sdk.definitions.taskgroup import MappedTaskGroup
+ from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as
dag:
@@ -2955,7 +2955,7 @@ def test_mapped_task_group_serde():
serde_dag = SerializedDAG.deserialize_dag(ser_dag[Encoding.VAR])
serde_tg = serde_dag.task_group.children["tg"]
- assert isinstance(serde_tg, MappedTaskGroup)
+ assert isinstance(serde_tg, SerializedTaskGroup)
assert serde_tg._expand_input == SchedulerDictOfListsExpandInput({"a":
[".", ".."]})
diff --git a/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py
b/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py
index 75c6888509a..5d2eda6dfac 100644
--- a/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py
+++ b/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py
@@ -1419,12 +1419,13 @@ def
test_upstream_in_mapped_group_when_mapped_tasks_list_is_empty(dag_maker, ses
@pytest.mark.parametrize("flag_upstream_failed", [True, False])
[email protected]_serialized_dag
def test_mapped_task_check_before_expand(dag_maker, session,
flag_upstream_failed):
"""
t3 depends on t2, which depends on t1 for expansion. Since t1 has not yet
run, t2 has not expanded yet,
and we need to guarantee this lack of expansion does not fail the
dependency-checking logic.
"""
- with dag_maker(session=session):
+ with dag_maker(session=session) as dag:
@task
def t(x):
@@ -1439,9 +1440,11 @@ def test_mapped_task_check_before_expand(dag_maker,
session, flag_upstream_faile
tg.expand(a=t([1, 2, 3]))
dr: DagRun = dag_maker.create_dagrun()
+ ti = next(ti for ti in dr.task_instances if ti.task_id == "tg.t3" and
ti.map_index == -1)
+ ti.refresh_from_task(dag.get_task(ti.task_id))
_test_trigger_rule(
- ti=next(ti for ti in dr.task_instances if ti.task_id == "tg.t3" and
ti.map_index == -1),
+ ti=ti,
session=session,
flag_upstream_failed=flag_upstream_failed,
expected_reason="requires all upstream tasks to have succeeded, but
found 1",
@@ -1449,6 +1452,7 @@ def test_mapped_task_check_before_expand(dag_maker,
session, flag_upstream_faile
@pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", [(True,
SKIPPED), (False, None)])
[email protected]_serialized_dag
def test_mapped_task_group_finished_upstream_before_expand(
dag_maker, session, flag_upstream_failed, expected_ti_state
):
@@ -1456,7 +1460,7 @@ def
test_mapped_task_group_finished_upstream_before_expand(
t3 depends on t2, which was skipped before it was expanded. We need to
guarantee this lack of expansion
does not fail the dependency-checking logic.
"""
- with dag_maker(session=session):
+ with dag_maker(session=session) as dag:
@task
def t(x):
@@ -1472,6 +1476,8 @@ def
test_mapped_task_group_finished_upstream_before_expand(
tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)}
tis["t2"].set_state(SKIPPED, session=session)
session.flush()
+
+ tis["tg.t3"].refresh_from_task(dag.get_task("tg.t3"))
_test_trigger_rule(
ti=tis["tg.t3"],
session=session,
@@ -1734,6 +1740,7 @@ def
test_setup_constraint_wait_for_past_depends_before_skipping(
@pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", [(True,
SKIPPED), (False, None)])
[email protected]_serialized_dag
def test_setup_mapped_task_group_finished_upstream_before_expand(
dag_maker, session, flag_upstream_failed, expected_ti_state
):
@@ -1741,7 +1748,7 @@ def
test_setup_mapped_task_group_finished_upstream_before_expand(
t3 indirectly depends on t1, which was skipped before it was expanded. We
need to guarantee this lack of
expansion does not fail the dependency-checking logic.
"""
- with dag_maker(session=session):
+ with dag_maker(session=session) as dag:
@task(trigger_rule=TriggerRule.ALL_DONE)
def t(x):
@@ -1760,6 +1767,8 @@ def
test_setup_mapped_task_group_finished_upstream_before_expand(
tis["t1"].set_state(SKIPPED, session=session)
tis["t2"].set_state(SUCCESS, session=session)
session.flush()
+
+ tis["tg.t3"].refresh_from_task(dag.get_task("tg.t3"))
_test_trigger_rule(
ti=tis["tg.t3"],
session=session,
diff --git a/airflow-core/tests/unit/utils/test_task_group.py
b/airflow-core/tests/unit/utils/test_task_group.py
index 447578329aa..52524ecda2f 100644
--- a/airflow-core/tests/unit/utils/test_task_group.py
+++ b/airflow-core/tests/unit/utils/test_task_group.py
@@ -21,20 +21,21 @@ import pendulum
import pytest
from airflow.api_fastapi.core_api.services.ui.task_group import
task_group_to_dict
-from airflow.models.baseoperator import BaseOperator
-from airflow.models.dag import DAG
+from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.operators.empty import EmptyOperator
+from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk import (
+ DAG,
+ BaseOperator,
+ TaskGroup,
setup,
task as task_decorator,
task_group as task_group_decorator,
teardown,
)
-from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.utils.dag_edges import dag_edges
-from tests_common.test_utils.compat import BashOperator, PythonOperator
from unit.models import DEFAULT_DATE
pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag]
@@ -157,6 +158,7 @@ EXPECTED_JSON_LEGACY = {
EXPECTED_JSON = {
"children": [
{"id": "task1", "label": "task1", "operator": "EmptyOperator", "type":
"task"},
+ {"id": "task5", "label": "task5", "operator": "EmptyOperator", "type":
"task"},
{
"children": [
{
@@ -195,11 +197,10 @@ EXPECTED_JSON = {
"tooltip": "",
"type": "task",
},
- {"id": "task5", "label": "task5", "operator": "EmptyOperator", "type":
"task"},
],
"id": None,
"is_mapped": False,
- "label": None,
+ "label": "",
"tooltip": "",
"type": "task",
}
@@ -276,7 +277,10 @@ def test_task_group_to_dict_with_prefix(dag_maker):
expected_node_id = {
"children": [
{"id": "task1", "label": "task1"},
+ {"id": "task5", "label": "task5"},
{
+ "id": "group234",
+ "label": "group234",
"children": [
{
"children": [
@@ -294,13 +298,10 @@ def test_task_group_to_dict_with_prefix(dag_maker):
{"id": "task2", "label": "task2"},
{"id": "group234.upstream_join_id", "label": ""},
],
- "id": "group234",
- "label": "group234",
},
- {"id": "task5", "label": "task5"},
],
"id": None,
- "label": None,
+ "label": "",
}
assert extract_node_id(task_group_to_dict(dag.task_group),
include_label=True) == expected_node_id
@@ -346,6 +347,7 @@ def test_task_group_to_dict_with_task_decorator(dag_maker):
"id": None,
"children": [
{"id": "task_1"},
+ {"id": "task_5"},
{
"id": "group234",
"children": [
@@ -356,7 +358,6 @@ def test_task_group_to_dict_with_task_decorator(dag_maker):
{"id": "group234.downstream_join_id"},
],
},
- {"id": "task_5"},
],
}
@@ -402,6 +403,7 @@ def test_task_group_to_dict_sub_dag(dag_maker):
"id": None,
"children": [
{"id": "task1"},
+ {"id": "task5"},
{
"id": "group234",
"children": [
@@ -416,7 +418,6 @@ def test_task_group_to_dict_sub_dag(dag_maker):
{"id": "group234.upstream_join_id"},
],
},
- {"id": "task5"},
],
}
@@ -477,6 +478,16 @@ def test_task_group_to_dict_and_dag_edges(dag_maker):
expected_node_id = {
"id": None,
"children": [
+ {
+ "id": "group_c",
+ "children": [
+ {"id": "group_c.task6"},
+ {"id": "group_c.task7"},
+ {"id": "group_c.task8"},
+ {"id": "group_c.upstream_join_id"},
+ {"id": "group_c.downstream_join_id"},
+ ],
+ },
{
"id": "group_d",
"children": [
@@ -486,6 +497,8 @@ def test_task_group_to_dict_and_dag_edges(dag_maker):
],
},
{"id": "task1"},
+ {"id": "task10"},
+ {"id": "task9"},
{
"id": "group_a",
"children": [
@@ -503,18 +516,6 @@ def test_task_group_to_dict_and_dag_edges(dag_maker):
{"id": "group_a.downstream_join_id"},
],
},
- {
- "id": "group_c",
- "children": [
- {"id": "group_c.task6"},
- {"id": "group_c.task7"},
- {"id": "group_c.task8"},
- {"id": "group_c.upstream_join_id"},
- {"id": "group_c.downstream_join_id"},
- ],
- },
- {"id": "task10"},
- {"id": "task9"},
],
}
@@ -783,6 +784,7 @@ def test_task_group_context_mix(dag_maker):
node_ids = {
"id": None,
"children": [
+ {"id": "task_end"},
{"id": "task_start"},
{
"id": "section_1",
@@ -802,7 +804,6 @@ def test_task_group_context_mix(dag_maker):
{"id": "section_1.downstream_join_id"},
],
},
- {"id": "task_end"},
],
}
@@ -1184,7 +1185,7 @@ def test_task_group_display_name_used_as_label(dag_maker):
assert tg.label == "my_custom_name"
expected_node_id = {
"id": None,
- "label": None,
+ "label": "",
"children": [
{
"id": "tg",
diff --git a/devel-common/src/tests_common/pytest_plugin.py
b/devel-common/src/tests_common/pytest_plugin.py
index af87248da46..ce588352d4e 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -43,7 +43,6 @@ if TYPE_CHECKING:
from sqlalchemy.orm import Session
from airflow.models.dagrun import DagRun, DagRunType
- from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk import DAG, BaseOperator, Context, TriggerRule
@@ -51,8 +50,8 @@ if TYPE_CHECKING:
from airflow.sdk.definitions.dag import ScheduleArg
from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
- from airflow.sdk.types import DagRunProtocol
- from airflow.serialization.serialized_objects import
SerializedBaseOperator, SerializedDAG
+ from airflow.sdk.types import DagRunProtocol, Operator
+ from airflow.serialization.serialized_objects import SerializedDAG
from airflow.timetables.base import DataInterval
from airflow.typing_compat import Self
from airflow.utils.state import DagRunState, TaskInstanceState
@@ -892,10 +891,17 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
self.dag.__enter__()
if self.want_serialized:
+ factory = self
class DAGProxy(lazy_object_proxy.Proxy):
- # Make `@dag.task` decorator work when need_serialized_dag
marker is set
- task = self.dag.task
+ """Wrapper to make test patterns work with serialized
dag."""
+
+ task = factory.dag.task # Expose the @dag.task decorator.
+
+ # When adding a task to the dag, automatically
re-serialize.
+ def add_task(self, task):
+ factory.dag.add_task(task)
+ factory._make_serdag(factory.dag)
return DAGProxy(self._serialized_dag)
return self.dag
@@ -2310,13 +2316,12 @@ def create_runtime_ti(mocked_parse):
from airflow.sdk import DAG
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails
- from airflow.serialization.serialized_objects import SerializedDAG
from airflow.timetables.base import TimeRestriction
timezone = _import_timezone()
def _create_task_instance(
- task: MappedOperator | SerializedBaseOperator,
+ task: Operator,
dag_id: str = "test_dag",
run_id: str = "test_run",
logical_date: str | datetime | None = "2024-12-01T01:00:00Z",
@@ -2353,14 +2358,7 @@ def create_runtime_ti(mocked_parse):
ti_id = uuid7()
if not task.has_dag():
- dag = SerializedDAG.deserialize_dag(
- SerializedDAG.serialize_dag(DAG(dag_id=dag_id,
start_date=timezone.datetime(2024, 12, 3)))
- )
- # Fixture only helps in regular base operator tasks, so mypy is
wrong here
- task.dag = dag
- # TODO (GH-52141): Scheduler DAG should contain scheduler tasks,
but
- # currently this inherits from SDK DAG.
- task = dag.task_dict[task.task_id] # type: ignore[assignment]
+ task.dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024,
12, 3))
if TYPE_CHECKING:
assert task.dag is not None
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py
b/task-sdk/src/airflow/sdk/definitions/dag.py
index 03b694f57d1..849de411cdf 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -29,13 +29,7 @@ from collections import abc, defaultdict, deque
from collections.abc import Callable, Collection, Iterable, MutableSet
from datetime import datetime, timedelta
from inspect import signature
-from typing import (
- TYPE_CHECKING,
- Any,
- ClassVar,
- cast,
- overload,
-)
+from typing import TYPE_CHECKING, Any, ClassVar, TypeGuard, cast, overload
from urllib.parse import urlsplit
import attrs
@@ -829,15 +823,9 @@ class DAG:
:param include_direct_upstream: Include all tasks directly upstream of
matched
and downstream (if include_downstream = True) tasks
"""
- from typing import TypeGuard
-
- from airflow.models.mappedoperator import MappedOperator as
DbMappedOperator
from airflow.sdk.definitions.mappedoperator import MappedOperator
- from airflow.serialization.serialized_objects import
SerializedBaseOperator
def is_task(obj) -> TypeGuard[Operator]:
- if isinstance(obj, (DbMappedOperator, SerializedBaseOperator)):
- return True # TODO (GH-52141): Split DAG implementation to
straight this up.
return isinstance(obj, (BaseOperator, MappedOperator))
# deep-copying self.task_dict and self.task_group takes a long time,
and we don't want all
diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
index 45de93b7da8..b2f9aa6b909 100644
--- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -40,8 +40,6 @@ from airflow.sdk.definitions._internal.node import DAGNode,
validate_group_key
from airflow.sdk.exceptions import AirflowDagCycleException
if TYPE_CHECKING:
- from airflow.models.expandinput import SchedulerExpandInput
- from airflow.models.mappedoperator import MappedOperator
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.definitions._internal.abstractoperator import
AbstractOperator
from airflow.sdk.definitions._internal.expandinput import
DictOfListsExpandInput, ListOfDictsExpandInput
@@ -50,7 +48,6 @@ if TYPE_CHECKING:
from airflow.sdk.definitions.edges import EdgeModifier
from airflow.sdk.types import Operator
from airflow.serialization.enums import DagAttributeTypes
- from airflow.serialization.serialized_objects import SerializedBaseOperator
def _default_parent_group() -> TaskGroup | None:
@@ -274,10 +271,14 @@ class TaskGroup(DAGNode):
@property
def group_id(self) -> str | None:
"""group_id of this TaskGroup."""
- if self.parent_group and self.parent_group.prefix_group_id and
self.parent_group._group_id:
+ if (
+ self._group_id
+ and self.parent_group
+ and self.parent_group.prefix_group_id
+ and self.parent_group._group_id
+ ):
# defer to parent whether it adds a prefix
return self.parent_group.child_id(self._group_id)
-
return self._group_id
@property
@@ -585,12 +586,9 @@ class TaskGroup(DAGNode):
yield group
group = group.parent_group
- # TODO (GH-52141): This should only return SDK operators. Have a db
representation for db operators.
- def iter_tasks(self) -> Iterator[AbstractOperator | MappedOperator |
SerializedBaseOperator]:
+ def iter_tasks(self) -> Iterator[AbstractOperator]:
"""Return an iterator of the child tasks."""
- from airflow.models.mappedoperator import MappedOperator
from airflow.sdk.definitions._internal.abstractoperator import
AbstractOperator
- from airflow.serialization.serialized_objects import
SerializedBaseOperator
groups_to_visit = [self]
@@ -598,16 +596,18 @@ class TaskGroup(DAGNode):
visiting = groups_to_visit.pop(0)
for child in visiting.children.values():
- if isinstance(child, (AbstractOperator, MappedOperator,
SerializedBaseOperator)):
+ if isinstance(child, AbstractOperator):
yield child
elif isinstance(child, TaskGroup):
groups_to_visit.append(child)
else:
raise ValueError(
- f"Encountered a DAGNode that is not a TaskGroup or an
AbstractOperator: {type(child).__module__}.{type(child)}"
+ f"Encountered a DAGNode that is not a TaskGroup or an "
+ f"AbstractOperator:
{type(child).__module__}.{type(child)}"
)
[email protected](kw_only=True, repr=False)
class MappedTaskGroup(TaskGroup):
"""
A mapped task group.
@@ -619,22 +619,14 @@ class MappedTaskGroup(TaskGroup):
a ``@task_group`` function instead.
"""
- def __init__(
- self,
- *,
- expand_input: SchedulerExpandInput | DictOfListsExpandInput |
ListOfDictsExpandInput,
- **kwargs: Any,
- ) -> None:
- super().__init__(**kwargs)
- self._expand_input = expand_input
+ _expand_input: DictOfListsExpandInput | ListOfDictsExpandInput =
attrs.field(alias="expand_input")
def __iter__(self):
- from airflow.sdk.definitions._internal.abstractoperator import
AbstractOperator
-
for child in self.children.values():
- if isinstance(child, AbstractOperator) and child.trigger_rule ==
TriggerRule.ALWAYS:
+ if getattr(child, "trigger_rule", None) == TriggerRule.ALWAYS:
raise ValueError(
- "Task-generated mapping within a mapped task group is not
allowed with trigger rule 'always'"
+ "Task-generated mapping within a mapped task group is not "
+ "allowed with trigger rule 'always'"
)
yield from self._iter_child(child)