This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 13e54646f99 Ensure that you can create a second DAG whilst another one 
is already "active" (#44484)
13e54646f99 is described below

commit 13e54646f997da0ddbdad154679529c770cafa75
Author: Ash Berlin-Taylor <a...@apache.org>
AuthorDate: Fri Nov 29 18:44:22 2024 +0000

    Ensure that you can create a second DAG whilst another one is already 
"active" (#44484)
    
    Why would you want to do this? Who knows, maybe you are calling a dag 
factory
    from inside a `with DAG` block. Either way, this exposed a subtle bug in
    `TaskGroup.create_root()`.
    
    This is the other half of the fix for the flakey tests fixed in #44480, and
    after much digging with @kaxil and @potiuk we've finally worked out why it 
was
    flakey:
    
    It was the "Non-DB" test jobs that were failing sometimes, and those tests use
    xdist to parallelize the tests. Couple that with the fact that
    `get_serialized_fields()` caches the answer on the class object, the test
    would only fail when nothing else in the current test process had previously
    called `DAG.get_serialized_fields()`.
    
    And to make this less likely to occur in future, the __serialized_fields is
    moved to being created eagerly at parse time, no more lazy loaded cache!
---
 task_sdk/src/airflow/sdk/definitions/dag.py       | 58 +++++++++++------------
 task_sdk/src/airflow/sdk/definitions/taskgroup.py |  2 +-
 task_sdk/tests/defintions/test_dag.py             |  6 +++
 3 files changed, 36 insertions(+), 30 deletions(-)

diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py 
b/task_sdk/src/airflow/sdk/definitions/dag.py
index d02480c8ab1..2b1270d6f25 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -357,7 +357,7 @@ class DAG:
     :param dag_display_name: The display name of the DAG which appears on the 
UI.
     """
 
-    __serialized_fields: ClassVar[frozenset[str] | None] = None
+    __serialized_fields: ClassVar[frozenset[str]]
 
     # Note: mypy gets very confused about the use of `@${attr}.default` for 
attrs without init=False -- and it
     # doesn't correctly track/notice that they have default values (it gives 
errors about `Missing positional
@@ -964,34 +964,6 @@ class DAG:
     @classmethod
     def get_serialized_fields(cls):
         """Stringified DAGs and operators contain exactly these fields."""
-        if not cls.__serialized_fields:
-            exclusion_list = {
-                "schedule_asset_references",
-                "schedule_asset_alias_references",
-                "task_outlet_asset_references",
-                "_old_context_manager_dags",
-                "safe_dag_id",
-                "last_loaded",
-                "user_defined_filters",
-                "user_defined_macros",
-                "partial",
-                "params",
-                "_log",
-                "task_dict",
-                "template_searchpath",
-                # "sla_miss_callback",
-                "on_success_callback",
-                "on_failure_callback",
-                "template_undefined",
-                "jinja_environment_kwargs",
-                # has_on_*_callback are only stored if the value is True, as 
the default is False
-                "has_on_success_callback",
-                "has_on_failure_callback",
-                "auto_register",
-                "fail_stop",
-                "schedule",
-            }
-            cls.__serialized_fields = frozenset(a.name for a in 
attrs.fields(cls)) - exclusion_list
         return cls.__serialized_fields
 
     def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> 
EdgeInfoType:
@@ -1030,6 +1002,34 @@ class DAG:
             )
 
 
+# Since we define all the attributes of the class with attrs, we can compute 
this statically at parse time
+DAG._DAG__serialized_fields = frozenset(a.name for a in attrs.fields(DAG)) - { 
 # type: ignore[attr-defined]
+    "schedule_asset_references",
+    "schedule_asset_alias_references",
+    "task_outlet_asset_references",
+    "_old_context_manager_dags",
+    "safe_dag_id",
+    "last_loaded",
+    "user_defined_filters",
+    "user_defined_macros",
+    "partial",
+    "params",
+    "_log",
+    "task_dict",
+    "template_searchpath",
+    # "sla_miss_callback",
+    "on_success_callback",
+    "on_failure_callback",
+    "template_undefined",
+    "jinja_environment_kwargs",
+    # has_on_*_callback are only stored if the value is True, as the default 
is False
+    "has_on_success_callback",
+    "has_on_failure_callback",
+    "auto_register",
+    "fail_stop",
+    "schedule",
+}
+
 if TYPE_CHECKING:
     # NOTE: Please keep the list of arguments in sync with DAG.__init__.
     # Only exception: dag_id here should have a default value, but not in DAG.
diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py 
b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
index 1c8d1ded824..7395b341740 100644
--- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -175,7 +175,7 @@ class TaskGroup(DAGNode):
     @classmethod
     def create_root(cls, dag: DAG) -> TaskGroup:
         """Create a root TaskGroup with no group_id or parent."""
-        return cls(group_id=None, dag=dag)
+        return cls(group_id=None, dag=dag, parent_group=None)
 
     @property
     def node_id(self):
diff --git a/task_sdk/tests/defintions/test_dag.py 
b/task_sdk/tests/defintions/test_dag.py
index 49699b67353..f0e634f19b6 100644
--- a/task_sdk/tests/defintions/test_dag.py
+++ b/task_sdk/tests/defintions/test_dag.py
@@ -417,3 +417,9 @@ class TestDagDecorator:
         # Test that if arg is not passed it raises a type error as expected.
         with pytest.raises(TypeError):
             noop_pipeline()
+
+    def test_create_dag_while_active_context(self):
+        """Test that we can safely create a DAG whilst a DAG is activated via 
``with dag1:``."""
+        with DAG(dag_id="simple_dag"):
+            DAG(dag_id="dag2")
+            # No asserts needed, it just needs to not fail

Reply via email to