turbaszek commented on a change in pull request #10153: URL: https://github.com/apache/airflow/pull/10153#discussion_r484264466
########## File path: airflow/utils/task_group.py ########## @@ -0,0 +1,392 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +A TaskGroup is a collection of closely related tasks on the same DAG that should be grouped +together when the DAG is displayed graphically. +""" + +from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Set, Union + +from airflow.exceptions import AirflowException, DuplicateTaskIdFound + +if TYPE_CHECKING: + from airflow.models.baseoperator import BaseOperator + from airflow.models.dag import DAG + + +class TaskGroup: + """ + A collection of tasks. When set_downstream() or set_upstream() are called on the + TaskGroup, it is applied across all tasks within the group if necessary. + + :param group_id: a unique, meaningful id for the TaskGroup. group_id must not conflict + with group_id of TaskGroup or task_id of tasks in the DAG. Root TaskGroup has group_id + set to None. + :type group_id: str + :param prefix_group_id: If set to True, child task_id and group_id will be prefixed with + this TaskGroup's group_id. If set to False, child task_id and group_id are not prefixed. + Default is True. + :type prerfix_group_id: bool + :param parent_group: The parent TaskGroup of this TaskGroup. parent_group is set to None + for the root TaskGroup. + :type parent_group: TaskGroup + :param dag: The DAG that this TaskGroup belongs to. + :type dag: airflow.models.DAG + :param tooltip: The tooltip of the TaskGroup node when displayed in the UI + :type tooltip: str + :param ui_color: The fill color of the TaskGroup node when displayed in the UI + :type ui_color: str + :param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI + :type ui_fgcolor: str + """ + + def __init__( + self, + group_id: Optional[str], + prefix_group_id: bool = True, + parent_group: Optional["TaskGroup"] = None, + dag: Optional["DAG"] = None, + tooltip: str = "", + ui_color: str = "CornflowerBlue", + ui_fgcolor: str = "#000", + ): + from airflow.models.dag import DagContext + + self.prefix_group_id = prefix_group_id + + if group_id is None: + # This creates a root TaskGroup. + if parent_group: + raise AirflowException("Root TaskGroup cannot have parent_group") + # used_group_ids is shared across all TaskGroups in the same DAG to keep track + # of used group_id to avoid duplication. + self.used_group_ids: Set[Optional[str]] = set() + self._parent_group = None + else: + if not isinstance(group_id, str): + raise ValueError("group_id must be str") + if not group_id: + raise ValueError("group_id must not be empty") + + dag = dag or DagContext.get_current_dag() + + if not parent_group and not dag: + raise AirflowException("TaskGroup can only be used inside a dag") + + self._parent_group = parent_group or TaskGroupContext.get_current_task_group(dag) + if not self._parent_group: + raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup") + self.used_group_ids = self._parent_group.used_group_ids + + self._group_id = group_id + if self.group_id in self.used_group_ids: + raise DuplicateTaskIdFound(f"group_id '{self.group_id}' has already been added to the DAG") + self.used_group_ids.add(self.group_id) + self.used_group_ids.add(self.downstream_join_id) + self.used_group_ids.add(self.upstream_join_id) + self.children: Dict[str, Union["BaseOperator", "TaskGroup"]] = {} + if self._parent_group: + self._parent_group.add(self) + + self.tooltip = tooltip + self.ui_color = ui_color + self.ui_fgcolor = ui_fgcolor + + # Keep track of TaskGroups or tasks that depend on this entire TaskGroup separately + # so that we can optimize the number of edges when entire TaskGroups depend on each other. + self.upstream_group_ids: Set[Optional[str]] = set() + self.downstream_group_ids: Set[Optional[str]] = set() + self.upstream_task_ids: Set[Optional[str]] = set() + self.downstream_task_ids: Set[Optional[str]] = set() + + @classmethod + def create_root(cls, dag: "DAG"): + """ + Create a root TaskGroup with no group_id or parent. + """ + return cls(group_id=None, dag=dag) + + @property + def is_root(self): + """ + Returns True if this TaskGroup is the root TaskGroup. Otherwise False + """ + return not self.group_id + + def __iter__(self): + for child in self.children.values(): + if isinstance(child, TaskGroup): + for inner_task in child: + yield inner_task + else: + yield child + + def add(self, task: Union["BaseOperator", "TaskGroup"]) -> None: + """ + Add a task to this TaskGroup. + """ + key = task.group_id if isinstance(task, TaskGroup) else task.task_id + + if key in self.children: + raise DuplicateTaskIdFound(f"Task id '{key}' has already been added to the DAG") + + if isinstance(task, TaskGroup): + if task.children: + raise AirflowException("Cannot add a non-empty TaskGroup") + + self.children[key] = task # type: ignore + + @property + def group_id(self) -> Optional[str]: + """ + group_id of this TaskGroup. + """ + if self._parent_group and self._parent_group.prefix_group_id and self._parent_group.group_id: + return self._parent_group.child_id(self._group_id) + + return self._group_id + + @property + def label(self): + """ + group_id excluding parent's group_id used as the node label in UI. + """ + return self._group_id + + def _set_relative( + self, + task_or_task_list: Union['BaseOperator', Sequence['BaseOperator'], "TaskGroup"], + upstream: bool = False + ) -> None: + """ + Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup. + Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids. + """ + from airflow.models.baseoperator import BaseOperator + + if upstream: + for task in self.get_roots(): + task.set_upstream(task_or_task_list) + else: + for task in self.get_leaves(): + task.set_downstream(task_or_task_list) + + # Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids + # accordingly so that we can reduce the number of edges when displaying Graph View. + if isinstance(task_or_task_list, TaskGroup): + # Handles TaskGroup and TaskGroup + if upstream: + parent, child = (self, task_or_task_list) + else: + parent, child = (task_or_task_list, self) + + parent.upstream_group_ids.add(child.group_id) + child.downstream_group_ids.add(parent.group_id) + else: + # Handles TaskGroup and task or list of tasks + try: + task_list = list(task_or_task_list) # type: ignore + except TypeError: + task_list = [task_or_task_list] # type: ignore + + for task in task_list: + if not isinstance(task, BaseOperator): + raise AirflowException("Relationships can only be set between TaskGroup or operators; " + f"received {task.__class__.__name__}") + + if upstream: + self.upstream_task_ids.add(task.task_id) + else: + self.downstream_task_ids.add(task.task_id) + + def set_downstream( + self, task_or_task_list: Union['BaseOperator', Sequence['BaseOperator'], "TaskGroup"] + ) -> None: + """ + Set a TaskGroup/task/list of task downstream of this TaskGroup. + """ + self._set_relative(task_or_task_list, upstream=False) + + def set_upstream( + self, task_or_task_list: Union['BaseOperator', Sequence['BaseOperator'], "TaskGroup"] + ) -> None: + """ + Set a TaskGroup/task/list of task upstream of this TaskGroup. + """ + self._set_relative(task_or_task_list, upstream=True) + + def __enter__(self): + TaskGroupContext.push_context_managed_task_group(self) + return self + + def __exit__(self, _type, _value, _tb): + TaskGroupContext.pop_context_managed_task_group() + + def has_task(self, task: "BaseOperator") -> bool: + """ + Returns True if this TaskGroup or its children TaskGroups contains the given task. + """ + if task.task_id in self.children: + return True + + return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup)) + + def get_roots(self) -> Generator["BaseOperator", None, None]: + """ + Returns a generator of tasks that are root tasks, i.e. those with no upstream + dependencies within the TaskGroup. + """ + for task in self: + if not any(self.has_task(parent) for parent in task.get_direct_relatives(upstream=True)): + yield task + + def get_leaves(self) -> Generator["BaseOperator", None, None]: + """ + Returns a generator of tasks that are leaf tasks, i.e. those with no downstream + dependencies within the TaskGroup + """ + for task in self: + if not any(self.has_task(child) for child in task.get_direct_relatives(upstream=False)): + yield task + + def __rshift__(self, other): + """ + Implements Self >> Other == self.set_downstream(other) + """ + self.set_downstream(other) + return other + + def __lshift__(self, other): + """ + Implements Self << Other == self.set_upstream(other) + """ + self.set_upstream(other) + return other + + def __rrshift__(self, other): + """ + Called for Operator >> [Operator] because list don't have + __rshift__ operators. + """ + self.__lshift__(other) + return self + + def __rlshift__(self, other): + """ + Called for Operator << [Operator] because list don't have + __lshift__ operators. + """ + self.__rshift__(other) + return self + + def child_id(self, label): + """ + Prefix label with group_id if prefix_group_id is True. Otherwise return the label + as-is. + """ + if self.prefix_group_id and self.group_id: + return f"{self.group_id}.{label}" + + return label + + @property + def upstream_join_id(self): + """ + If this TaskGroup has immediate upstream TaskGroups or tasks, a dummy node called + upstream_join_id will be created in Graph View to join the outgoing edges from this + TaskGroup to reduce the total number of edges needed to be displayed. + """ + return f"{self.group_id}.upstream_join_id" + + @property + def downstream_join_id(self): + """ + If this TaskGroup has immediate downstream TaskGroups or tasks, a dummy node called + downstream_join_id will be created in Graph View to join the outgoing edges from this + TaskGroup to reduce the total number of edges needed to be displayed. + """ + return f"{self.group_id}.downstream_join_id" + + def get_task_group_dict(self) -> Dict[str, "TaskGroup"]: + """ + Returns a flat dictionary of group_id: TaskGroup + """ + task_group_map = {} + + def build_map(task_group): + if not isinstance(task_group, TaskGroup): + return + + task_group_map[task_group.group_id] = task_group + + for child in task_group.children.values(): + build_map(child) + + build_map(self) + return task_group_map + + def get_child_by_label(self, label): Review comment: Would you mind adding type hints? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected]
