This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 13e54646f99 Ensure that you can create a second DAG whilst another one is already "active" (#44484) 13e54646f99 is described below commit 13e54646f997da0ddbdad154679529c770cafa75 Author: Ash Berlin-Taylor <a...@apache.org> AuthorDate: Fri Nov 29 18:44:22 2024 +0000 Ensure that you can create a second DAG whilst another one is already "active" (#44484) Why would you want to do this? Who knows, maybe you are calling a dag factory from inside a `with DAG` block. Either way, this exposed a subtle bug in `TaskGroup.create_root()`. This is the other half of the fix for the flakey tests fixed in #44480, and after much digging with @kaxil and @potiuk we've finally worked out why it was flakey: It was the "Non-DB" test jobs that were failing sometimes, and those tests use xdist to parallelize the tests. Couple that with the fact that `get_serialized_fields()` caches the answer on the class object, the test would only fail when nothing else in the current test process had previously called `DAG.get_serialized_fields()`. And to make this less likely to occur in future, the __serialized_fields is moved to being created eagerly at parse time, no more lazy loaded cache! --- task_sdk/src/airflow/sdk/definitions/dag.py | 58 +++++++++++------------ task_sdk/src/airflow/sdk/definitions/taskgroup.py | 2 +- task_sdk/tests/defintions/test_dag.py | 6 +++ 3 files changed, 36 insertions(+), 30 deletions(-) diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py index d02480c8ab1..2b1270d6f25 100644 --- a/task_sdk/src/airflow/sdk/definitions/dag.py +++ b/task_sdk/src/airflow/sdk/definitions/dag.py @@ -357,7 +357,7 @@ class DAG: :param dag_display_name: The display name of the DAG which appears on the UI. 
""" - __serialized_fields: ClassVar[frozenset[str] | None] = None + __serialized_fields: ClassVar[frozenset[str]] # Note: mypy gets very confused about the use of `@${attr}.default` for attrs without init=False -- and it # doesn't correctly track/notice that they have default values (it gives errors about `Missing positional @@ -964,34 +964,6 @@ class DAG: @classmethod def get_serialized_fields(cls): """Stringified DAGs and operators contain exactly these fields.""" - if not cls.__serialized_fields: - exclusion_list = { - "schedule_asset_references", - "schedule_asset_alias_references", - "task_outlet_asset_references", - "_old_context_manager_dags", - "safe_dag_id", - "last_loaded", - "user_defined_filters", - "user_defined_macros", - "partial", - "params", - "_log", - "task_dict", - "template_searchpath", - # "sla_miss_callback", - "on_success_callback", - "on_failure_callback", - "template_undefined", - "jinja_environment_kwargs", - # has_on_*_callback are only stored if the value is True, as the default is False - "has_on_success_callback", - "has_on_failure_callback", - "auto_register", - "fail_stop", - "schedule", - } - cls.__serialized_fields = frozenset(a.name for a in attrs.fields(cls)) - exclusion_list return cls.__serialized_fields def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType: @@ -1030,6 +1002,34 @@ class DAG: ) +# Since we define all the attributes of the class with attrs, we can compute this statically at parse time +DAG._DAG__serialized_fields = frozenset(a.name for a in attrs.fields(DAG)) - { # type: ignore[attr-defined] + "schedule_asset_references", + "schedule_asset_alias_references", + "task_outlet_asset_references", + "_old_context_manager_dags", + "safe_dag_id", + "last_loaded", + "user_defined_filters", + "user_defined_macros", + "partial", + "params", + "_log", + "task_dict", + "template_searchpath", + # "sla_miss_callback", + "on_success_callback", + "on_failure_callback", + "template_undefined", + 
"jinja_environment_kwargs", + # has_on_*_callback are only stored if the value is True, as the default is False + "has_on_success_callback", + "has_on_failure_callback", + "auto_register", + "fail_stop", + "schedule", +} + if TYPE_CHECKING: # NOTE: Please keep the list of arguments in sync with DAG.__init__. # Only exception: dag_id here should have a default value, but not in DAG. diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py b/task_sdk/src/airflow/sdk/definitions/taskgroup.py index 1c8d1ded824..7395b341740 100644 --- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py @@ -175,7 +175,7 @@ class TaskGroup(DAGNode): @classmethod def create_root(cls, dag: DAG) -> TaskGroup: """Create a root TaskGroup with no group_id or parent.""" - return cls(group_id=None, dag=dag) + return cls(group_id=None, dag=dag, parent_group=None) @property def node_id(self): diff --git a/task_sdk/tests/defintions/test_dag.py b/task_sdk/tests/defintions/test_dag.py index 49699b67353..f0e634f19b6 100644 --- a/task_sdk/tests/defintions/test_dag.py +++ b/task_sdk/tests/defintions/test_dag.py @@ -417,3 +417,9 @@ class TestDagDecorator: # Test that if arg is not passed it raises a type error as expected. with pytest.raises(TypeError): noop_pipeline() + + def test_create_dag_while_active_context(self): + """Test that we can safely create a DAG whilst a DAG is activated via ``with dag1:``.""" + with DAG(dag_id="simple_dag"): + DAG(dag_id="dag2") + # No asserts needed, it just needs to not fail