turbaszek commented on a change in pull request #10153:
URL: https://github.com/apache/airflow/pull/10153#discussion_r484263443



##########
File path: airflow/utils/task_group.py
##########
@@ -0,0 +1,392 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+A TaskGroup is a collection of closely related tasks on the same DAG that 
should be grouped
+together when the DAG is displayed graphically.
+"""
+
+from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, 
Set, Union
+
+from airflow.exceptions import AirflowException, DuplicateTaskIdFound
+
+if TYPE_CHECKING:
+    from airflow.models.baseoperator import BaseOperator
+    from airflow.models.dag import DAG
+
+
+class TaskGroup:
+    """
+    A collection of tasks. When set_downstream() or set_upstream() are called 
on the
+    TaskGroup, it is applied across all tasks within the group if necessary.
+
+    :param group_id: a unique, meaningful id for the TaskGroup. group_id must 
not conflict
+        with group_id of TaskGroup or task_id of tasks in the DAG. Root 
TaskGroup has group_id
+        set to None.
+    :type group_id: str
+    :param prefix_group_id: If set to True, child task_id and group_id will be 
prefixed with
+        this TaskGroup's group_id. If set to False, child task_id and group_id 
are not prefixed.
+        Default is True.
+    :type prefix_group_id: bool
+    :param parent_group: The parent TaskGroup of this TaskGroup. parent_group 
is set to None
+        for the root TaskGroup.
+    :type parent_group: TaskGroup
+    :param dag: The DAG that this TaskGroup belongs to.
+    :type dag: airflow.models.DAG
+    :param tooltip: The tooltip of the TaskGroup node when displayed in the UI
+    :type tooltip: str
+    :param ui_color: The fill color of the TaskGroup node when displayed in 
the UI
+    :type ui_color: str
+    :param ui_fgcolor: The label color of the TaskGroup node when displayed in 
the UI
+    :type ui_fgcolor: str
+    """
+
+    def __init__(
+        self,
+        group_id: Optional[str],
+        prefix_group_id: bool = True,
+        parent_group: Optional["TaskGroup"] = None,
+        dag: Optional["DAG"] = None,
+        tooltip: str = "",
+        ui_color: str = "CornflowerBlue",
+        ui_fgcolor: str = "#000",
+    ):
+        from airflow.models.dag import DagContext
+
+        self.prefix_group_id = prefix_group_id
+
+        if group_id is None:
+            # This creates a root TaskGroup.
+            if parent_group:
+                raise AirflowException("Root TaskGroup cannot have 
parent_group")
+            # used_group_ids is shared across all TaskGroups in the same DAG 
to keep track
+            # of used group_id to avoid duplication.
+            self.used_group_ids: Set[Optional[str]] = set()
+            self._parent_group = None
+        else:
+            if not isinstance(group_id, str):
+                raise ValueError("group_id must be str")
+            if not group_id:
+                raise ValueError("group_id must not be empty")
+
+            dag = dag or DagContext.get_current_dag()
+
+            if not parent_group and not dag:
+                raise AirflowException("TaskGroup can only be used inside a 
dag")
+
+            self._parent_group = parent_group or 
TaskGroupContext.get_current_task_group(dag)
+            if not self._parent_group:
+                raise AirflowException("TaskGroup must have a parent_group 
except for the root TaskGroup")
+            self.used_group_ids = self._parent_group.used_group_ids
+
+        self._group_id = group_id
+        if self.group_id in self.used_group_ids:
+            raise DuplicateTaskIdFound(f"group_id '{self.group_id}' has 
already been added to the DAG")
+        self.used_group_ids.add(self.group_id)
+        self.used_group_ids.add(self.downstream_join_id)
+        self.used_group_ids.add(self.upstream_join_id)
+        self.children: Dict[str, Union["BaseOperator", "TaskGroup"]] = {}
+        if self._parent_group:
+            self._parent_group.add(self)
+
+        self.tooltip = tooltip
+        self.ui_color = ui_color
+        self.ui_fgcolor = ui_fgcolor
+
+        # Keep track of TaskGroups or tasks that depend on this entire 
TaskGroup separately
+        # so that we can optimize the number of edges when entire TaskGroups 
depend on each other.
+        self.upstream_group_ids: Set[Optional[str]] = set()
+        self.downstream_group_ids: Set[Optional[str]] = set()
+        self.upstream_task_ids: Set[Optional[str]] = set()
+        self.downstream_task_ids: Set[Optional[str]] = set()
+
+    @classmethod
+    def create_root(cls, dag: "DAG"):

Review comment:
       ```suggestion
       def create_root(cls, dag: "DAG") -> "TaskGroup":
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to