This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 229d1b2d564 Always create serdag in dagmaker fixture (#50359)
229d1b2d564 is described below

commit 229d1b2d56476a0117a451ff9d430976165dda87
Author: Daniel Standish <15932138+dstand...@users.noreply.github.com>
AuthorDate: Wed May 14 07:59:23 2025 -0700

    Always create serdag in dagmaker fixture (#50359)
    
    This PR makes it so dag_maker always creates serialized dag and
    dagversion objects, along with the dag model.
    
    It's not really sensible anymore to omit these other objects: they are
    always there in production, increasingly we need them for code to work
    right, and their omission can leave some bugs hidden (some of which are
    resolved as part of this change).
    
    Initially, I was going to just remove the option, but it also controls
    the type of dag object returned by dag_maker (serdag vs dag), so for now
    I leave that as is.
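
    For illustration, a minimal sketch of what a test can now rely on after
    using the fixture (a hypothetical test; the SerializedDagModel.has_dag
    and DagVersion query patterns mirror ones used elsewhere in this diff):

        from sqlalchemy import select

        from airflow.models.dag_version import DagVersion
        from airflow.models.serialized_dag import SerializedDagModel
        from airflow.providers.standard.operators.empty import EmptyOperator

        def test_serdag_always_created(dag_maker, session):
            with dag_maker("example_dag", session=session):
                EmptyOperator(task_id="task_1")

            # serdag and dag version rows now exist even though
            # serialized=True was never passed to dag_maker
            assert SerializedDagModel.has_dag("example_dag") is True
            assert session.scalar(
                select(DagVersion).where(DagVersion.dag_id == "example_dag")
            ) is not None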
    
    ---------
    
    Co-authored-by: Tzu-ping Chung <uranu...@gmail.com>
---
 .../airflow/serialization/serialized_objects.py    | 41 +++++++++--
 airflow-core/tests/unit/api_fastapi/conftest.py    | 15 ++--
 .../core_api/routes/public/test_dag_versions.py    | 10 +--
 .../versions/head/test_task_instances.py           | 25 +++++--
 .../bundles/test_dag_bundle_manager.py             |  2 -
 airflow-core/tests/unit/jobs/test_scheduler_job.py | 10 ++-
 airflow-core/tests/unit/models/test_cleartasks.py  | 17 +++--
 airflow-core/tests/unit/models/test_dagbag.py      | 22 +++++-
 .../tests/unit/models/test_serialized_dag.py       | 38 +++++-----
 .../tests/unit/models/test_taskinstance.py         |  7 +-
 devel-common/src/tests_common/pytest_plugin.py     | 76 ++++++++++----------
 .../tests/unit/docker/decorators/test_docker.py    | 81 ++++++++++++----------
 12 files changed, 211 insertions(+), 133 deletions(-)

diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index d77d992d739..83a6dbf5cf7 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -230,12 +230,26 @@ class _PriorityWeightStrategyNotRegistered(AirflowException):
 
 
 def _encode_trigger(trigger: BaseEventTrigger | dict):
+    def _ensure_serialized(d):
+        """
+        Make sure the kwargs dict is JSON-serializable.
+
+        This is done with BaseSerialization logic. A simple check is added to
+        ensure we don't double-serialize, which is possible when a trigger goes
+        through multiple serialization layers.
+        """
+        if isinstance(d, dict) and Encoding.TYPE in d:
+            return d
+        return BaseSerialization.serialize(d)
+
     if isinstance(trigger, dict):
-        return trigger
-    classpath, kwargs = trigger.serialize()
+        classpath = trigger["classpath"]
+        kwargs = trigger["kwargs"]
+    else:
+        classpath, kwargs = trigger.serialize()
     return {
         "classpath": classpath,
-        "kwargs": kwargs,
+        "kwargs": {k: _ensure_serialized(v) for k, v in kwargs.items()},
     }
 
 
@@ -303,6 +317,18 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset:
 
 
 def decode_asset(var: dict[str, Any]):
+    def _smart_decode_trigger_kwargs(d):
+        """
+        Slightly clean up kwargs for display.
+
+        This detects one level of BaseSerialization and tries to deserialize the
+        content, removing some __type __var ugliness when the value is displayed
+        in UI to the user.
+        """
+        if not isinstance(d, dict) or Encoding.TYPE not in d:
+            return d
+        return BaseSerialization.deserialize(d)
+
     watchers = var.get("watchers", [])
     return Asset(
         name=var["name"],
@@ -310,7 +336,14 @@ def decode_asset(var: dict[str, Any]):
         group=var["group"],
         extra=var["extra"],
         watchers=[
-            SerializedAssetWatcher(name=watcher["name"], trigger=watcher["trigger"]) for watcher in watchers
+            SerializedAssetWatcher(
+                name=watcher["name"],
+                trigger={
+                    "classpath": watcher["trigger"]["classpath"],
+                    "kwargs": 
_smart_decode_trigger_kwargs(watcher["trigger"]["kwargs"]),
+                },
+            )
+            for watcher in watchers
         ],
     )
 
diff --git a/airflow-core/tests/unit/api_fastapi/conftest.py b/airflow-core/tests/unit/api_fastapi/conftest.py
index eb98893d45b..b39f4c4f743 100644
--- a/airflow-core/tests/unit/api_fastapi/conftest.py
+++ b/airflow-core/tests/unit/api_fastapi/conftest.py
@@ -27,7 +27,6 @@ from fastapi.testclient import TestClient
 from airflow.api_fastapi.app import create_app
 from airflow.api_fastapi.auth.managers.simple.user import SimpleAuthManagerUser
 from airflow.models import Connection
-from airflow.models.serialized_dag import SerializedDagModel
 from airflow.providers.standard.operators.empty import EmptyOperator
 
 from tests_common.test_utils.config import conf_vars
@@ -141,7 +140,7 @@ def configure_git_connection_for_dag_bundle(session):
 
 
 @pytest.fixture
-def make_dag_with_multiple_versions(dag_maker, configure_git_connection_for_dag_bundle):
+def make_dag_with_multiple_versions(dag_maker, configure_git_connection_for_dag_bundle, session):
     """
     Create DAG with multiple versions
 
@@ -151,17 +150,19 @@ def make_dag_with_multiple_versions(dag_maker, configure_git_connection_for_dag_
     """
     dag_id = "dag_with_multiple_versions"
     for version_number in range(1, 4):
-        with dag_maker(dag_id) as dag:
+        with dag_maker(
+            dag_id,
+            session=session,
+            bundle_version=f"some_commit_hash{version_number}",
+        ):
             for task_number in range(version_number):
                 EmptyOperator(task_id=f"task{task_number + 1}")
-        SerializedDagModel.write_dag(
-            dag, bundle_name="dag_maker", 
bundle_version=f"some_commit_hash{version_number}"
-        )
         dag_maker.create_dagrun(
             run_id=f"run{version_number}",
             logical_date=datetime.datetime(2020, 1, version_number, tzinfo=datetime.timezone.utc),
+            session=session,
         )
-        dag.sync_to_db()
+        session.commit()
 
 
 @pytest.fixture(scope="module")
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_versions.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_versions.py
index ced2530e2f3..f4b34f8c548 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_versions.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_versions.py
@@ -20,7 +20,6 @@ from unittest import mock
 
 import pytest
 
-from airflow.models.serialized_dag import SerializedDagModel
 from airflow.providers.standard.operators.empty import EmptyOperator
 
 from tests_common.test_utils.db import clear_db_dags, clear_db_serialized_dags
@@ -35,16 +34,11 @@ class TestDagVersionEndpoint:
         clear_db_serialized_dags()
 
         with dag_maker(
-            "ANOTHER_DAG_ID",
-        ) as dag:
+            dag_id="ANOTHER_DAG_ID", bundle_version="some_commit_hash", 
bundle_name="another_bundle_name"
+        ):
             EmptyOperator(task_id="task_1")
             EmptyOperator(task_id="task_2")
 
-        dag.sync_to_db()
-        SerializedDagModel.write_dag(
-            dag, bundle_name="another_bundle_name", 
bundle_version="some_commit_hash"
-        )
-
 
 class TestGetDagVersion(TestDagVersionEndpoint):
     @pytest.mark.parametrize(
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 95631484087..cc2b1baa64a 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 import operator
 from datetime import datetime
 from unittest import mock
+from uuid import uuid4
 
 import pytest
 import uuid6
@@ -37,7 +38,13 @@ from airflow.sdk import TaskGroup
 from airflow.utils import timezone
 from airflow.utils.state import State, TaskInstanceState, TerminalTIState
 
-from tests_common.test_utils.db import clear_db_assets, clear_db_runs, clear_rendered_ti_fields
+from tests_common.test_utils.db import (
+    clear_db_assets,
+    clear_db_dags,
+    clear_db_runs,
+    clear_db_serialized_dags,
+    clear_rendered_ti_fields,
+)
 
 pytestmark = pytest.mark.db_test
 
@@ -114,9 +121,13 @@ def test_id_matches_sub_claim(client, session, create_task_instance):
 class TestTIRunState:
     def setup_method(self):
         clear_db_runs()
+        clear_db_serialized_dags()
+        clear_db_dags()
 
     def teardown_method(self):
         clear_db_runs()
+        clear_db_serialized_dags()
+        clear_db_dags()
 
     @pytest.mark.parametrize(
         "max_tries, should_retry",
@@ -147,6 +158,7 @@ class TestTIRunState:
             state=State.QUEUED,
             session=session,
             start_date=instant,
+            dag_id=str(uuid4()),
         )
         ti.max_tries = max_tries
         session.commit()
@@ -165,7 +177,7 @@ class TestTIRunState:
         assert response.status_code == 200
         assert response.json() == {
             "dag_run": {
-                "dag_id": "dag",
+                "dag_id": ti.dag_id,
                 "run_id": "test",
                 "clear_number": 0,
                 "logical_date": instant_str,
@@ -179,7 +191,7 @@ class TestTIRunState:
                 "consumed_asset_events": [],
             },
             "task_reschedule_count": 0,
-            "upstream_map_indexes": None,
+            "upstream_map_indexes": {},
             "max_tries": max_tries,
             "should_retry": should_retry,
             "variables": [],
@@ -235,6 +247,7 @@ class TestTIRunState:
             state=State.QUEUED,
             session=session,
             start_date=instant,
+            dag_id=str(uuid4()),
         )
 
         ti.next_method = "execute_complete"
@@ -258,7 +271,7 @@ class TestTIRunState:
         assert response.json() == {
             "dag_run": mock.ANY,
             "task_reschedule_count": 0,
-            "upstream_map_indexes": None,
+            "upstream_map_indexes": {},
             "max_tries": 0,
             "should_retry": False,
             "variables": [],
@@ -282,6 +295,7 @@ class TestTIRunState:
             state=State.QUEUED,
             session=session,
             start_date=orig_task_start_time,
+            dag_id=str(uuid4()),
         )
 
         ti.start_date = orig_task_start_time
@@ -320,7 +334,7 @@ class TestTIRunState:
         assert response.json() == {
             "dag_run": mock.ANY,
             "task_reschedule_count": 0,
-            "upstream_map_indexes": None,
+            "upstream_map_indexes": {},
             "max_tries": 0,
             "should_retry": False,
             "variables": [],
@@ -385,6 +399,7 @@ class TestTIRunState:
             state=State.RUNNING,
             session=session,
             start_date=instant,
+            dag_id=str(uuid4()),
         )
         session.commit()
 
diff --git a/airflow-core/tests/unit/dag_processing/bundles/test_dag_bundle_manager.py b/airflow-core/tests/unit/dag_processing/bundles/test_dag_bundle_manager.py
index f61f6ff4896..9f20ed49f8d 100644
--- a/airflow-core/tests/unit/dag_processing/bundles/test_dag_bundle_manager.py
+++ b/airflow-core/tests/unit/dag_processing/bundles/test_dag_bundle_manager.py
@@ -163,8 +163,6 @@ def test_sync_bundles_to_db(clear_db, dag_maker):
     # Create DAG version with 'my-test-bundle'
     with dag_maker(dag_id="test_dag", schedule=None):
         EmptyOperator(task_id="mytask")
-    with create_session() as session:
-        session.add(DagVersion(dag_id="test_dag", version_number=1, bundle_name="my-test-bundle"))
 
     # simulate bundle config change (now 'dags-folder' is active, 'my-test-bundle' becomes inactive)
     manager = DagBundlesManager()
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 2f0d71ac89f..c1f926cfdbe 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -4115,7 +4115,7 @@ class TestSchedulerJob:
         # Test that custom_task has no Operator Links (after de-serialization) in the Scheduling Loop
         assert not custom_task.operator_extra_links
 
-    def test_scheduler_create_dag_runs_does_not_raise_error(self, caplog, dag_maker):
+    def test_scheduler_create_dag_runs_does_not_raise_error_when_no_serdag(self, caplog, dag_maker):
         """
         Test that scheduler._create_dag_runs does not raise an error when the DAG does not exist
         in serialized_dag table
@@ -4137,11 +4137,19 @@ class TestSchedulerJob:
                 logger="airflow.jobs.scheduler_job_runner",
             ),
         ):
+            self._clear_serdags(dag_id=dag_maker.dag.dag_id, session=session)
             self.job_runner._create_dag_runs([dag_maker.dag_model], session)
             assert caplog.messages == [
                 "DAG 'test_scheduler_create_dag_runs_does_not_raise_error' not 
found in serialized_dag table",
             ]
 
+    def _clear_serdags(self, dag_id, session):
+        SDM = SerializedDagModel
+        sdms = session.scalars(select(SDM).where(SDM.dag_id == dag_id))
+        for sdm in sdms:
+            session.delete(sdm)
+        session.commit()
+
     def test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run(self, dag_maker, testing_dag_bundle):
         """
         Test that externally triggered Dag Runs should not affect (by skipping) next
diff --git a/airflow-core/tests/unit/models/test_cleartasks.py b/airflow-core/tests/unit/models/test_cleartasks.py
index d4a622d8c4e..9ece8824619 100644
--- a/airflow-core/tests/unit/models/test_cleartasks.py
+++ b/airflow-core/tests/unit/models/test_cleartasks.py
@@ -700,14 +700,18 @@ class TestClearTasks:
         ti1, ti2 = sorted(dr.get_task_instances(session=session), key=lambda ti: ti.task_id)
         ti1.task = op1
         ti2.task = op2
-
-        session.get(TaskInstance, ti2.id).try_number += 1
+        ti2.refresh_from_db(session=session)
+        ti2.try_number += 1
         session.commit()
-        ti2.run(session=session)
+
         # Dependency not met
         assert ti2.try_number == 1
         assert ti2.max_tries == 1
 
+        ti1.refresh_from_db(session=session)
+        assert ti1.max_tries == 0
+        assert ti1.try_number == 0
+
         op2.clear(upstream=True, session=session)
         ti1.refresh_from_db(session)
         ti2.refresh_from_db(session)
@@ -716,14 +720,9 @@ class TestClearTasks:
         # max tries will be set to retries + curr try number == 1 + 1 == 2
         assert ti2.max_tries == 2
 
-        ti1.try_number += 1
-        session.merge(ti1)
-        session.commit()
-
-        ti1.run(session=session)
         ti1.refresh_from_db(session)
         ti2.refresh_from_db(session)
-        assert ti1.try_number == 1
+        assert ti1.try_number == 0
 
         ti2 = _get_ti(ti2)
         ti2.try_number += 1
diff --git a/airflow-core/tests/unit/models/test_dagbag.py b/airflow-core/tests/unit/models/test_dagbag.py
index a9ba6669af9..d72fe10f7d6 100644
--- a/airflow-core/tests/unit/models/test_dagbag.py
+++ b/airflow-core/tests/unit/models/test_dagbag.py
@@ -32,6 +32,7 @@ from unittest.mock import patch
 
 import pytest
 import time_machine
+from sqlalchemy import select
 
 import airflow.example_dags
 from airflow import settings
@@ -480,7 +481,7 @@ class TestDagBag:
         assert dag_id == dag.dag_id
         assert dagbag.process_file_calls == 2
 
-    def test_dag_removed_if_serialized_dag_is_removed(self, dag_maker, tmp_path):
+    def test_dag_removed_if_serialized_dag_is_removed(self, dag_maker, tmp_path, session):
         """
         Test that if a DAG does not exist in serialized_dag table (as the DAG file was removed),
         remove dags from the DagBag
@@ -493,14 +494,31 @@ class TestDagBag:
             start_date=tz.datetime(2021, 10, 12),
         ) as dag:
             EmptyOperator(task_id="task_1")
-        dag_maker.create_dagrun()
+
         dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False, read_dags_from_db=True)
         dagbag.dags = {dag.dag_id: SerializedDAG.from_dict(SerializedDAG.to_dict(dag))}
         dagbag.dags_last_fetched = {dag.dag_id: (tz.utcnow() - timedelta(minutes=2))}
         dagbag.dags_hash = {dag.dag_id: mock.ANY}
 
+        # observe we have serdag and dag is in dagbag
+        assert SerializedDagModel.has_dag(dag.dag_id) is True
+        assert dagbag.get_dag(dag.dag_id) is not None
+
+        # now delete serdags for this dag
+        SDM = SerializedDagModel
+        sdms = session.scalars(select(SDM).where(SDM.dag_id == dag.dag_id))
+        for sdm in sdms:
+            session.delete(sdm)
+        session.commit()
+
+        # first, confirm that serdags are gone for this dag
         assert SerializedDagModel.has_dag(dag.dag_id) is False
 
+        # now see the dag is still in dagbag
+        assert dagbag.get_dag(dag.dag_id) is not None
+
+        # but, let's recreate the dagbag and see if the dag will be there
+        dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False, read_dags_from_db=True)
         assert dagbag.get_dag(dag.dag_id) is None
         assert dag.dag_id not in dagbag.dags
         assert dag.dag_id not in dagbag.dags_last_fetched
diff --git a/airflow-core/tests/unit/models/test_serialized_dag.py b/airflow-core/tests/unit/models/test_serialized_dag.py
index 0667701eaa0..a962311f0ec 100644
--- a/airflow-core/tests/unit/models/test_serialized_dag.py
+++ b/airflow-core/tests/unit/models/test_serialized_dag.py
@@ -484,22 +484,26 @@ class TestSerializedDagModel:
 
         db.clear_db_assets()
 
-    @pytest.mark.parametrize("min_update_interval", [0, 10])
-    @mock.patch.object(DagVersion, "get_latest_version")
-    def test_min_update_interval_is_respected(
-        self, mock_dv_get_latest_version, min_update_interval, dag_maker
-    ):
-        mock_dv_get_latest_version.return_value = None
+    @pytest.mark.parametrize(
+        "provide_interval, new_task, should_write",
+        [
+            (True, True, False),
+            (True, False, False),
+            (False, True, True),
+            (False, False, False),
+        ],
+    )
+    def test_min_update_interval_is_respected(self, provide_interval, new_task, should_write, dag_maker):
+        min_update_interval = 10 if provide_interval else 0
         with dag_maker("dag1") as dag:
             PythonOperator(task_id="task1", python_callable=lambda: None)
-        dag.sync_to_db()
-        SDM.write_dag(dag, bundle_name="testing")
-        # new task
-        PythonOperator(task_id="task2", python_callable=lambda: None, dag=dag)
-        SDM.write_dag(dag, bundle_name="testing", 
min_update_interval=min_update_interval)
-        if min_update_interval:
-            # Because min_update_interval is 10, DagVersion.get_latest_version would
-            # be called only once:
-            mock_dv_get_latest_version.assert_called_once()
-        else:
-            assert mock_dv_get_latest_version.call_count == 2
+
+        if new_task:
+            PythonOperator(task_id="task2", python_callable=lambda: None, 
dag=dag)
+
+        did_write = SDM.write_dag(
+            dag,
+            bundle_name="testing",
+            min_update_interval=min_update_interval,
+        )
+        assert did_write is should_write
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index 8f9350d0530..54290a00f5d 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -4719,7 +4719,7 @@ class TestMappedTaskInstanceReceiveValue:
 
         known_versions = []
 
-        with dag_maker(dag_id="test", session=session) as dag:
+        with dag_maker(dag_id="test_89eug7u6f7y", session=session) as dag:
 
             @dag.task
             def show(value, *, ti):
@@ -4727,9 +4727,8 @@ class TestMappedTaskInstanceReceiveValue:
                 known_versions.append(ti.dag_version_id)
 
             show.expand(value=[1, 2, 3])
-        # ensure that there is a dag_version record in the db
-        dag_version = session.merge(DagVersion(dag_id="test", bundle_name="test"))
-        session.commit()
+        # get the dag version for the dag
+        dag_version = session.scalar(select(DagVersion).where(DagVersion.dag_id == dag.dag_id))
         dag_maker.create_dagrun(session=session)
         task = dag.get_task("show")
         for ti in session.scalars(select(TI)):
diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index f3052e5f935..145e2879308 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -894,7 +894,6 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
         def __exit__(self, type, value, traceback):
             from airflow.configuration import conf
             from airflow.models import DagModel
-            from airflow.models.serialized_dag import SerializedDagModel
 
             dag = self.dag
             dag.__exit__(type, value, traceback)
@@ -903,7 +902,7 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
 
             dag.clear(session=self.session)
             if AIRFLOW_V_3_0_PLUS:
-                dag.bulk_write_to_db(self.bundle_name, None, [dag], session=self.session)
+                dag.bulk_write_to_db(self.bundle_name, self.bundle_version, [dag], session=self.session)
             else:
                 dag.sync_to_db(session=self.session)
 
@@ -916,46 +915,45 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
                 security_manager.sync_perm_for_dag(dag.dag_id, dag.access_control)
             self.dag_model = self.session.get(DagModel, dag.dag_id)
 
-            if self.want_serialized:
-                self.serialized_model = SerializedDagModel(dag)
-                sdm = self.session.scalar(
-                    select(SerializedDagModel).where(
-                        SerializedDagModel.dag_id == dag.dag_id,
-                        SerializedDagModel.dag_hash == self.serialized_model.dag_hash,
-                    )
-                )
-
-                if AIRFLOW_V_3_0_PLUS and self.serialized_model != sdm:
-                    from airflow.models.dag_version import DagVersion
-                    from airflow.models.dagcode import DagCode
+            self._make_serdag(dag)
+            self._bag_dag_compat(self.dag)
 
-                    dagv = DagVersion.write_dag(
-                        dag_id=dag.dag_id,
-                        bundle_name=self.dag_model.bundle_name,
-                        bundle_version=self.dag_model.bundle_version,
-                        session=self.session,
-                    )
-                    self.session.add(dagv)
-                    self.session.flush()
-                    dag_code = DagCode(dagv, dag.fileloc, "Source")
-                    self.session.merge(dag_code)
-                    self.serialized_model.dag_version = dagv
-                    if self.want_activate_assets:
-                        self._activate_assets()
-                if sdm:
-                    sdm._SerializedDagModel__data_cache = (
-                        self.serialized_model._SerializedDagModel__data_cache
-                    )
-                    sdm._data = self.serialized_model._data
-                    self.serialized_model = sdm
-                else:
-                    self.session.merge(self.serialized_model)
-                serialized_dag = self._serialized_dag()
-                self._bag_dag_compat(serialized_dag)
+        def _make_serdag(self, dag):
+            from airflow.models.serialized_dag import SerializedDagModel
 
+            self.serialized_model = SerializedDagModel(dag)
+            sdm = self.session.scalar(
+                select(SerializedDagModel).where(
+                    SerializedDagModel.dag_id == dag.dag_id,
+                    SerializedDagModel.dag_hash == self.serialized_model.dag_hash,
+                )
+            )
+            if AIRFLOW_V_3_0_PLUS and self.serialized_model != sdm:
+                from airflow.models.dag_version import DagVersion
+                from airflow.models.dagcode import DagCode
+
+                dagv = DagVersion.write_dag(
+                    dag_id=dag.dag_id,
+                    bundle_name=self.dag_model.bundle_name,
+                    bundle_version=self.dag_model.bundle_version,
+                    session=self.session,
+                )
+                self.session.add(dagv)
                 self.session.flush()
+                dag_code = DagCode(dagv, dag.fileloc, "Source")
+                self.session.merge(dag_code)
+                self.serialized_model.dag_version = dagv
+                if self.want_activate_assets:
+                    self._activate_assets()
+            if sdm:
+                sdm._SerializedDagModel__data_cache = self.serialized_model._SerializedDagModel__data_cache
+                sdm._data = self.serialized_model._data
+                self.serialized_model = sdm
             else:
-                self._bag_dag_compat(self.dag)
+                self.session.merge(self.serialized_model)
+            serialized_dag = self._serialized_dag()
+            self._bag_dag_compat(serialized_dag)
+            self.session.flush()
 
         def create_dagrun(self, *, logical_date=NOTSET, **kwargs):
             from airflow.utils import timezone
@@ -1074,6 +1072,7 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
             fileloc=None,
             relative_fileloc=None,
             bundle_name=None,
+            bundle_version=None,
             session=None,
             **kwargs,
         ):
@@ -1108,6 +1107,7 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
             self.want_serialized = serialized
             self.want_activate_assets = activate_assets
             self.bundle_name = bundle_name or "dag_maker"
+            self.bundle_version = bundle_version
             if AIRFLOW_V_3_0_PLUS:
                 from airflow.models.dagbundle import DagBundleModel
 
diff --git a/providers/docker/tests/unit/docker/decorators/test_docker.py b/providers/docker/tests/unit/docker/decorators/test_docker.py
index 03478dfb728..dba160afab9 100644
--- a/providers/docker/tests/unit/docker/decorators/test_docker.py
+++ b/providers/docker/tests/unit/docker/decorators/test_docker.py
@@ -47,35 +47,37 @@ CLOUDPICKLE_MARKER = pytest.mark.skipif(not CLOUDPICKLE_INSTALLED, reason="`clou
 
 
 class TestDockerDecorator:
-    def test_basic_docker_operator(self, dag_maker):
+    def test_basic_docker_operator(self, dag_maker, session):
         @task.docker(image="python:3.9-slim", auto_remove="force")
         def f():
             import random
 
             return [random.random() for _ in range(100)]
 
-        with dag_maker():
+        with dag_maker(session=session):
             ret = f()
-
-        dr = dag_maker.create_dagrun()
-        ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date)
-        ti = dr.get_task_instances()[0]
+        session.commit()
+        dr = dag_maker.create_dagrun(session=session)
+        session.expunge_all()
+        ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, session=session)
+        ti = dr.get_task_instances(session=session)[0]
         assert len(ti.xcom_pull()) == 100
 
-    def test_basic_docker_operator_with_param(self, dag_maker):
+    def test_basic_docker_operator_with_param(self, dag_maker, session):
         @task.docker(image="python:3.9-slim", auto_remove="force")
         def f(num_results):
             import random
 
             return [random.random() for _ in range(num_results)]
 
-        with dag_maker():
+        with dag_maker(session=session):
             ret = f(50)
 
-        dr = dag_maker.create_dagrun()
+        dr = dag_maker.create_dagrun(session=session)
+        session.expunge_all()
         ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date)
-        ti = dr.get_task_instances()[0]
-        result = ti.xcom_pull()
+        ti = dr.get_task_instances(session=session)[0]
+        result = ti.xcom_pull(session=session)
         assert isinstance(result, list)
         assert len(result) == 50
 
@@ -92,35 +94,39 @@ class TestDockerDecorator:
         rendered = ti.render_templates()
         assert rendered.container_name == f"python_{dr.dag_id}"
 
-    def test_basic_docker_operator_multiple_output(self, dag_maker):
+    def test_basic_docker_operator_multiple_output(self, dag_maker, session):
         @task.docker(image="python:3.9-slim", multiple_outputs=True, 
auto_remove="force")
         def return_dict(number: int):
             return {"number": number + 1, "43": 43}
 
         test_number = 10
-        with dag_maker():
+        with dag_maker(session=session):
             ret = return_dict(test_number)
 
-        dr = dag_maker.create_dagrun()
-        ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date)
+        dr = dag_maker.create_dagrun(session=session)
+        session.expunge_all()
+
+        ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, session=session)
 
-        ti = dr.get_task_instances()[0]
-        assert ti.xcom_pull(key="number") == test_number + 1
-        assert ti.xcom_pull(key="43") == 43
-        assert ti.xcom_pull() == {"number": test_number + 1, "43": 43}
+        ti = dr.get_task_instances(session=session)[0]
+        assert ti.xcom_pull(key="number", session=session) == test_number + 1
+        assert ti.xcom_pull(key="43", session=session) == 43
+        assert ti.xcom_pull(session=session) == {"number": test_number + 1, 
"43": 43}
 
-    def test_no_return(self, dag_maker):
+    def test_no_return(self, dag_maker, session):
         @task.docker(image="python:3.9-slim", auto_remove="force")
         def f():
             pass
 
-        with dag_maker():
+        with dag_maker(session=session):
             ret = f()
 
-        dr = dag_maker.create_dagrun()
-        ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date)
-        ti = dr.get_task_instances()[0]
-        assert ti.xcom_pull() is None
+        dr = dag_maker.create_dagrun(session=session)
+        session.expunge_all()
+
+        ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, session=session)
+        ti = dr.get_task_instances(session=session)[0]
+        assert ti.xcom_pull(session=session) is None
 
     def test_call_decorated_multiple_times(self):
         """Test calling decorated function 21 times in a DAG"""
@@ -159,21 +165,22 @@ class TestDockerDecorator:
             ({"skip_on_exit_code": (100,)}, 101, TaskInstanceState.FAILED),
         ],
     )
-    def test_skip_docker_operator(self, kwargs, actual_exit_code, expected_state, dag_maker):
+    def test_skip_docker_operator(self, kwargs, actual_exit_code, expected_state, dag_maker, session):
         @task.docker(image="python:3.9-slim", auto_remove="force", **kwargs)
         def f(exit_code):
             raise SystemExit(exit_code)
 
-        with dag_maker():
+        with dag_maker(session=session):
             ret = f(actual_exit_code)
 
-        dr = dag_maker.create_dagrun()
+        dr = dag_maker.create_dagrun(session=session)
+        session.expunge_all()
         if expected_state == TaskInstanceState.FAILED:
             with pytest.raises(AirflowException):
-                ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date)
+                ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, session=session)
         else:
-            ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date)
-            ti = dr.get_task_instances()[0]
+            ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, session=session)
+            ti = dr.get_task_instances(session=session)[0]
             assert ti.state == expected_state
 
     def test_setup_decorator_with_decorated_docker_task(self, dag_maker):
@@ -308,7 +315,7 @@ class TestDockerDecorator:
 
         assert ret.operator.docker_url == "unix://var/run/docker.sock"
 
-    def test_failing_task(self, dag_maker):
+    def test_failing_task(self, dag_maker, session):
         """Test regression #39319
 
         Check the log content of the DockerOperator when the task fails.
@@ -324,13 +331,15 @@ class TestDockerDecorator:
         log_capture_string = StringBuffer()
         ch = logging.StreamHandler(log_capture_string)
         docker_operator_logger.addHandler(ch)
-        with dag_maker():
+        with dag_maker(session=session):
             ret = f()
 
-        dr = dag_maker.create_dagrun()
+        dr = dag_maker.create_dagrun(session=session)
+        session.expunge_all()
+
         with pytest.raises(AirflowException):
-            ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date)
-        ti = dr.get_task_instances()[0]
+            ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, session=session)
+        ti = dr.get_task_instances(session=session)[0]
         assert ti.state == TaskInstanceState.FAILED
 
         log_content = str(log_capture_string.getvalue())
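
For reference, a minimal, hypothetical sketch of the new bundle_version
parameter that this change adds to dag_maker (it mirrors the
make_dag_with_multiple_versions fixture in the diff above; the test name
is illustrative):

    import datetime

    from airflow.providers.standard.operators.empty import EmptyOperator

    def test_dag_with_bundle_versions(dag_maker, session):
        for version_number in range(1, 4):
            with dag_maker(
                "my_versioned_dag",
                session=session,
                bundle_version=f"some_commit_hash{version_number}",
            ):
                for task_number in range(version_number):
                    EmptyOperator(task_id=f"task{task_number + 1}")
            dag_maker.create_dagrun(
                run_id=f"run{version_number}",
                logical_date=datetime.datetime(
                    2020, 1, version_number, tzinfo=datetime.timezone.utc
                ),
                session=session,
            )
            session.commit()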
