This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 13e54646f99 Ensure that you can create a second DAG whilst another one
is already "active" (#44484)
13e54646f99 is described below
commit 13e54646f997da0ddbdad154679529c770cafa75
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Fri Nov 29 18:44:22 2024 +0000
Ensure that you can create a second DAG whilst another one is already
"active" (#44484)
Why would you want to do this? Who knows, maybe you are calling a dag
factory
from inside a `with DAG` block. Either way, this exposed a subtle bug in
`TaskGroup.create_root()`.
This is the other half of the fix for the flakey tests fixed in #44480, and
after much digging with @kaxil and @potiuk we've finally worked out why it
was
flakey:
It was the "Non-DB" test jobs that were failing sometimes, and those tests use
xdist to parallelize the tests. Couple that with the fact that
`get_serialized_fields()` caches the answer on the class object, the test
would only fail when nothing else in the current test process had previously
called `DAG.get_serialized_fields()`.
And to make this less likely to occur in future, the __serialized_fields is
moved to being created eagerly at parse time, no more lazy loaded cache!
---
task_sdk/src/airflow/sdk/definitions/dag.py | 58 +++++++++++------------
task_sdk/src/airflow/sdk/definitions/taskgroup.py | 2 +-
task_sdk/tests/defintions/test_dag.py | 6 +++
3 files changed, 36 insertions(+), 30 deletions(-)
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py
b/task_sdk/src/airflow/sdk/definitions/dag.py
index d02480c8ab1..2b1270d6f25 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -357,7 +357,7 @@ class DAG:
:param dag_display_name: The display name of the DAG which appears on the
UI.
"""
- __serialized_fields: ClassVar[frozenset[str] | None] = None
+ __serialized_fields: ClassVar[frozenset[str]]
# Note: mypy gets very confused about the use of `@${attr}.default` for
attrs without init=False -- and it
# doesn't correctly track/notice that they have default values (it gives
errors about `Missing positional
@@ -964,34 +964,6 @@ class DAG:
@classmethod
def get_serialized_fields(cls):
"""Stringified DAGs and operators contain exactly these fields."""
- if not cls.__serialized_fields:
- exclusion_list = {
- "schedule_asset_references",
- "schedule_asset_alias_references",
- "task_outlet_asset_references",
- "_old_context_manager_dags",
- "safe_dag_id",
- "last_loaded",
- "user_defined_filters",
- "user_defined_macros",
- "partial",
- "params",
- "_log",
- "task_dict",
- "template_searchpath",
- # "sla_miss_callback",
- "on_success_callback",
- "on_failure_callback",
- "template_undefined",
- "jinja_environment_kwargs",
- # has_on_*_callback are only stored if the value is True, as
the default is False
- "has_on_success_callback",
- "has_on_failure_callback",
- "auto_register",
- "fail_stop",
- "schedule",
- }
- cls.__serialized_fields = frozenset(a.name for a in
attrs.fields(cls)) - exclusion_list
return cls.__serialized_fields
def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) ->
EdgeInfoType:
@@ -1030,6 +1002,34 @@ class DAG:
)
+# Since we define all the attributes of the class with attrs, we can compute
this statically at parse time
+DAG._DAG__serialized_fields = frozenset(a.name for a in attrs.fields(DAG)) - {
# type: ignore[attr-defined]
+ "schedule_asset_references",
+ "schedule_asset_alias_references",
+ "task_outlet_asset_references",
+ "_old_context_manager_dags",
+ "safe_dag_id",
+ "last_loaded",
+ "user_defined_filters",
+ "user_defined_macros",
+ "partial",
+ "params",
+ "_log",
+ "task_dict",
+ "template_searchpath",
+ # "sla_miss_callback",
+ "on_success_callback",
+ "on_failure_callback",
+ "template_undefined",
+ "jinja_environment_kwargs",
+ # has_on_*_callback are only stored if the value is True, as the default
is False
+ "has_on_success_callback",
+ "has_on_failure_callback",
+ "auto_register",
+ "fail_stop",
+ "schedule",
+}
+
if TYPE_CHECKING:
# NOTE: Please keep the list of arguments in sync with DAG.__init__.
# Only exception: dag_id here should have a default value, but not in DAG.
diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py
b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
index 1c8d1ded824..7395b341740 100644
--- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -175,7 +175,7 @@ class TaskGroup(DAGNode):
@classmethod
def create_root(cls, dag: DAG) -> TaskGroup:
"""Create a root TaskGroup with no group_id or parent."""
- return cls(group_id=None, dag=dag)
+ return cls(group_id=None, dag=dag, parent_group=None)
@property
def node_id(self):
diff --git a/task_sdk/tests/defintions/test_dag.py
b/task_sdk/tests/defintions/test_dag.py
index 49699b67353..f0e634f19b6 100644
--- a/task_sdk/tests/defintions/test_dag.py
+++ b/task_sdk/tests/defintions/test_dag.py
@@ -417,3 +417,9 @@ class TestDagDecorator:
# Test that if arg is not passed it raises a type error as expected.
with pytest.raises(TypeError):
noop_pipeline()
+
+ def test_create_dag_while_active_context(self):
+ """Test that we can safely create a DAG whilst a DAG is activated via
``with dag1:``."""
+ with DAG(dag_id="simple_dag"):
+ DAG(dag_id="dag2")
+ # No asserts needed, it just needs to not fail