This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new 229d1b2d564  Always create serdag in dagmaker fixture (#50359)
229d1b2d564 is described below

commit 229d1b2d56476a0117a451ff9d430976165dda87
Author: Daniel Standish <15932138+dstand...@users.noreply.github.com>
AuthorDate: Wed May 14 07:59:23 2025 -0700

    Always create serdag in dagmaker fixture (#50359)

    This PR makes dag_maker always create the serialized dag and dag version
    objects along with the dag model. It is no longer sensible to omit these
    objects: they are always present in production, code increasingly needs
    them to work correctly, and their omission can leave bugs hidden (some of
    those bugs are fixed as part of this PR). Initially I was going to remove
    the option entirely, but it also controls the type of dag object returned
    by dag_maker (serdag vs dag), so for now that behavior is left as is.

    ---------

    Co-authored-by: Tzu-ping Chung <uranu...@gmail.com>
---
 .../airflow/serialization/serialized_objects.py    | 41 +++++++++--
 airflow-core/tests/unit/api_fastapi/conftest.py    | 15 ++--
 .../core_api/routes/public/test_dag_versions.py    | 10 +--
 .../versions/head/test_task_instances.py           | 25 +++++--
 .../bundles/test_dag_bundle_manager.py             |  2 -
 airflow-core/tests/unit/jobs/test_scheduler_job.py | 10 ++-
 airflow-core/tests/unit/models/test_cleartasks.py  | 17 +++--
 airflow-core/tests/unit/models/test_dagbag.py      | 22 +++++-
 .../tests/unit/models/test_serialized_dag.py       | 38 +++++-----
 .../tests/unit/models/test_taskinstance.py         |  7 +-
 devel-common/src/tests_common/pytest_plugin.py     | 76 ++++++++++---------
 .../tests/unit/docker/decorators/test_docker.py    | 81 ++++++++++++----------
 12 files changed, 211 insertions(+), 133 deletions(-)

diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index d77d992d739..83a6dbf5cf7 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -230,12 +230,26 @@ class _PriorityWeightStrategyNotRegistered(AirflowException):
 
 
 def _encode_trigger(trigger: BaseEventTrigger | dict):
+    def _ensure_serialized(d):
+        """
+        Make sure the kwargs dict is JSON-serializable.
+
+        This is done with BaseSerialization logic. A simple check is added to
+        ensure we don't double-serialize, which is possible when a trigger goes
+        through multiple serialization layers.
+        """
+        if isinstance(d, dict) and Encoding.TYPE in d:
+            return d
+        return BaseSerialization.serialize(d)
+
     if isinstance(trigger, dict):
-        return trigger
-    classpath, kwargs = trigger.serialize()
+        classpath = trigger["classpath"]
+        kwargs = trigger["kwargs"]
+    else:
+        classpath, kwargs = trigger.serialize()
     return {
         "classpath": classpath,
-        "kwargs": kwargs,
+        "kwargs": {k: _ensure_serialized(v) for k, v in kwargs.items()},
     }
 
 
@@ -303,6 +317,18 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset:
 
 
 def decode_asset(var: dict[str, Any]):
+    def _smart_decode_trigger_kwargs(d):
+        """
+        Slightly clean up kwargs for display.
+
+        This detects one level of BaseSerialization and tries to deserialize the
+        content, removing some __type __var ugliness when the value is displayed
+        in UI to the user.
+ """ + if not isinstance(d, dict) or Encoding.TYPE not in d: + return d + return BaseSerialization.deserialize(d) + watchers = var.get("watchers", []) return Asset( name=var["name"], @@ -310,7 +336,14 @@ def decode_asset(var: dict[str, Any]): group=var["group"], extra=var["extra"], watchers=[ - SerializedAssetWatcher(name=watcher["name"], trigger=watcher["trigger"]) for watcher in watchers + SerializedAssetWatcher( + name=watcher["name"], + trigger={ + "classpath": watcher["trigger"]["classpath"], + "kwargs": _smart_decode_trigger_kwargs(watcher["trigger"]["kwargs"]), + }, + ) + for watcher in watchers ], ) diff --git a/airflow-core/tests/unit/api_fastapi/conftest.py b/airflow-core/tests/unit/api_fastapi/conftest.py index eb98893d45b..b39f4c4f743 100644 --- a/airflow-core/tests/unit/api_fastapi/conftest.py +++ b/airflow-core/tests/unit/api_fastapi/conftest.py @@ -27,7 +27,6 @@ from fastapi.testclient import TestClient from airflow.api_fastapi.app import create_app from airflow.api_fastapi.auth.managers.simple.user import SimpleAuthManagerUser from airflow.models import Connection -from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.standard.operators.empty import EmptyOperator from tests_common.test_utils.config import conf_vars @@ -141,7 +140,7 @@ def configure_git_connection_for_dag_bundle(session): @pytest.fixture -def make_dag_with_multiple_versions(dag_maker, configure_git_connection_for_dag_bundle): +def make_dag_with_multiple_versions(dag_maker, configure_git_connection_for_dag_bundle, session): """ Create DAG with multiple versions @@ -151,17 +150,19 @@ def make_dag_with_multiple_versions(dag_maker, configure_git_connection_for_dag_ """ dag_id = "dag_with_multiple_versions" for version_number in range(1, 4): - with dag_maker(dag_id) as dag: + with dag_maker( + dag_id, + session=session, + bundle_version=f"some_commit_hash{version_number}", + ): for task_number in range(version_number): EmptyOperator(task_id=f"task{task_number + 1}") - SerializedDagModel.write_dag( - dag, bundle_name="dag_maker", bundle_version=f"some_commit_hash{version_number}" - ) dag_maker.create_dagrun( run_id=f"run{version_number}", logical_date=datetime.datetime(2020, 1, version_number, tzinfo=datetime.timezone.utc), + session=session, ) - dag.sync_to_db() + session.commit() @pytest.fixture(scope="module") diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_versions.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_versions.py index ced2530e2f3..f4b34f8c548 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_versions.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_versions.py @@ -20,7 +20,6 @@ from unittest import mock import pytest -from airflow.models.serialized_dag import SerializedDagModel from airflow.providers.standard.operators.empty import EmptyOperator from tests_common.test_utils.db import clear_db_dags, clear_db_serialized_dags @@ -35,16 +34,11 @@ class TestDagVersionEndpoint: clear_db_serialized_dags() with dag_maker( - "ANOTHER_DAG_ID", - ) as dag: + dag_id="ANOTHER_DAG_ID", bundle_version="some_commit_hash", bundle_name="another_bundle_name" + ): EmptyOperator(task_id="task_1") EmptyOperator(task_id="task_2") - dag.sync_to_db() - SerializedDagModel.write_dag( - dag, bundle_name="another_bundle_name", bundle_version="some_commit_hash" - ) - class TestGetDagVersion(TestDagVersionEndpoint): @pytest.mark.parametrize( diff --git 
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 95631484087..cc2b1baa64a 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 import operator
 from datetime import datetime
 from unittest import mock
+from uuid import uuid4
 
 import pytest
 import uuid6
@@ -37,7 +38,13 @@ from airflow.sdk import TaskGroup
 from airflow.utils import timezone
 from airflow.utils.state import State, TaskInstanceState, TerminalTIState
 
-from tests_common.test_utils.db import clear_db_assets, clear_db_runs, clear_rendered_ti_fields
+from tests_common.test_utils.db import (
+    clear_db_assets,
+    clear_db_dags,
+    clear_db_runs,
+    clear_db_serialized_dags,
+    clear_rendered_ti_fields,
+)
 
 pytestmark = pytest.mark.db_test
 
@@ -114,9 +121,13 @@ def test_id_matches_sub_claim(client, session, create_task_instance):
 class TestTIRunState:
     def setup_method(self):
         clear_db_runs()
+        clear_db_serialized_dags()
+        clear_db_dags()
 
     def teardown_method(self):
         clear_db_runs()
+        clear_db_serialized_dags()
+        clear_db_dags()
 
     @pytest.mark.parametrize(
         "max_tries, should_retry",
@@ -147,6 +158,7 @@ class TestTIRunState:
             state=State.QUEUED,
             session=session,
             start_date=instant,
+            dag_id=str(uuid4()),
         )
         ti.max_tries = max_tries
         session.commit()
@@ -165,7 +177,7 @@ class TestTIRunState:
         assert response.status_code == 200
         assert response.json() == {
             "dag_run": {
-                "dag_id": "dag",
+                "dag_id": ti.dag_id,
                 "run_id": "test",
                 "clear_number": 0,
                 "logical_date": instant_str,
@@ -179,7 +191,7 @@ class TestTIRunState:
                 "consumed_asset_events": [],
             },
             "task_reschedule_count": 0,
-            "upstream_map_indexes": None,
+            "upstream_map_indexes": {},
             "max_tries": max_tries,
             "should_retry": should_retry,
             "variables": [],
@@ -235,6 +247,7 @@ class TestTIRunState:
             state=State.QUEUED,
             session=session,
             start_date=instant,
+            dag_id=str(uuid4()),
         )
 
         ti.next_method = "execute_complete"
@@ -258,7 +271,7 @@ class TestTIRunState:
         assert response.json() == {
             "dag_run": mock.ANY,
             "task_reschedule_count": 0,
-            "upstream_map_indexes": None,
+            "upstream_map_indexes": {},
             "max_tries": 0,
             "should_retry": False,
             "variables": [],
@@ -282,6 +295,7 @@ class TestTIRunState:
             state=State.QUEUED,
             session=session,
             start_date=orig_task_start_time,
+            dag_id=str(uuid4()),
        )
 
         ti.start_date = orig_task_start_time
@@ -320,7 +334,7 @@ class TestTIRunState:
         assert response.json() == {
             "dag_run": mock.ANY,
             "task_reschedule_count": 0,
-            "upstream_map_indexes": None,
+            "upstream_map_indexes": {},
             "max_tries": 0,
             "should_retry": False,
             "variables": [],
@@ -385,6 +399,7 @@ class TestTIRunState:
             state=State.RUNNING,
             session=session,
             start_date=instant,
+            dag_id=str(uuid4()),
         )
 
         session.commit()
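The `dag_id=str(uuid4())` arguments added throughout test_task_instances.py give each test a DAG id that cannot collide with rows left behind by other tests. The same idea expressed as a reusable pytest fixture — a sketch only; the `unique_dag_id` fixture name is hypothetical and not part of Airflow:

from uuid import uuid4

import pytest


@pytest.fixture
def unique_dag_id():
    # Rows written under this dag_id belong to exactly one test invocation.
    return f"dag_{uuid4().hex}"


def test_with_isolated_dag_id(unique_dag_id):
    assert unique_dag_id.startswith("dag_")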
session.add(DagVersion(dag_id="test_dag", version_number=1, bundle_name="my-test-bundle")) # simulate bundle config change (now 'dags-folder' is active, 'my-test-bundle' becomes inactive) manager = DagBundlesManager() diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 2f0d71ac89f..c1f926cfdbe 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -4115,7 +4115,7 @@ class TestSchedulerJob: # Test that custom_task has no Operator Links (after de-serialization) in the Scheduling Loop assert not custom_task.operator_extra_links - def test_scheduler_create_dag_runs_does_not_raise_error(self, caplog, dag_maker): + def test_scheduler_create_dag_runs_does_not_raise_error_when_no_serdag(self, caplog, dag_maker): """ Test that scheduler._create_dag_runs does not raise an error when the DAG does not exist in serialized_dag table @@ -4137,11 +4137,19 @@ class TestSchedulerJob: logger="airflow.jobs.scheduler_job_runner", ), ): + self._clear_serdags(dag_id=dag_maker.dag.dag_id, session=session) self.job_runner._create_dag_runs([dag_maker.dag_model], session) assert caplog.messages == [ "DAG 'test_scheduler_create_dag_runs_does_not_raise_error' not found in serialized_dag table", ] + def _clear_serdags(self, dag_id, session): + SDM = SerializedDagModel + sdms = session.scalars(select(SDM).where(SDM.dag_id == dag_id)) + for sdm in sdms: + session.delete(sdm) + session.commit() + def test_bulk_write_to_db_external_trigger_dont_skip_scheduled_run(self, dag_maker, testing_dag_bundle): """ Test that externally triggered Dag Runs should not affect (by skipping) next diff --git a/airflow-core/tests/unit/models/test_cleartasks.py b/airflow-core/tests/unit/models/test_cleartasks.py index d4a622d8c4e..9ece8824619 100644 --- a/airflow-core/tests/unit/models/test_cleartasks.py +++ b/airflow-core/tests/unit/models/test_cleartasks.py @@ -700,14 +700,18 @@ class TestClearTasks: ti1, ti2 = sorted(dr.get_task_instances(session=session), key=lambda ti: ti.task_id) ti1.task = op1 ti2.task = op2 - - session.get(TaskInstance, ti2.id).try_number += 1 + ti2.refresh_from_db(session=session) + ti2.try_number += 1 session.commit() - ti2.run(session=session) + # Dependency not met assert ti2.try_number == 1 assert ti2.max_tries == 1 + ti1.refresh_from_db(session=session) + assert ti1.max_tries == 0 + assert ti1.try_number == 0 + op2.clear(upstream=True, session=session) ti1.refresh_from_db(session) ti2.refresh_from_db(session) @@ -716,14 +720,9 @@ class TestClearTasks: # max tries will be set to retries + curr try number == 1 + 1 == 2 assert ti2.max_tries == 2 - ti1.try_number += 1 - session.merge(ti1) - session.commit() - - ti1.run(session=session) ti1.refresh_from_db(session) ti2.refresh_from_db(session) - assert ti1.try_number == 1 + assert ti1.try_number == 0 ti2 = _get_ti(ti2) ti2.try_number += 1 diff --git a/airflow-core/tests/unit/models/test_dagbag.py b/airflow-core/tests/unit/models/test_dagbag.py index a9ba6669af9..d72fe10f7d6 100644 --- a/airflow-core/tests/unit/models/test_dagbag.py +++ b/airflow-core/tests/unit/models/test_dagbag.py @@ -32,6 +32,7 @@ from unittest.mock import patch import pytest import time_machine +from sqlalchemy import select import airflow.example_dags from airflow import settings @@ -480,7 +481,7 @@ class TestDagBag: assert dag_id == dag.dag_id assert dagbag.process_file_calls == 2 - def test_dag_removed_if_serialized_dag_is_removed(self, dag_maker, 
diff --git a/airflow-core/tests/unit/models/test_dagbag.py b/airflow-core/tests/unit/models/test_dagbag.py
index a9ba6669af9..d72fe10f7d6 100644
--- a/airflow-core/tests/unit/models/test_dagbag.py
+++ b/airflow-core/tests/unit/models/test_dagbag.py
@@ -32,6 +32,7 @@ from unittest.mock import patch
 
 import pytest
 import time_machine
+from sqlalchemy import select
 
 import airflow.example_dags
 from airflow import settings
@@ -480,7 +481,7 @@ class TestDagBag:
             assert dag_id == dag.dag_id
         assert dagbag.process_file_calls == 2
 
-    def test_dag_removed_if_serialized_dag_is_removed(self, dag_maker, tmp_path):
+    def test_dag_removed_if_serialized_dag_is_removed(self, dag_maker, tmp_path, session):
         """
         Test that if a DAG does not exist in serialized_dag table (as the DAG file was removed),
         remove dags from the DagBag
@@ -493,14 +494,31 @@ class TestDagBag:
             start_date=tz.datetime(2021, 10, 12),
         ) as dag:
             EmptyOperator(task_id="task_1")
-        dag_maker.create_dagrun()
+
         dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False, read_dags_from_db=True)
         dagbag.dags = {dag.dag_id: SerializedDAG.from_dict(SerializedDAG.to_dict(dag))}
         dagbag.dags_last_fetched = {dag.dag_id: (tz.utcnow() - timedelta(minutes=2))}
         dagbag.dags_hash = {dag.dag_id: mock.ANY}
 
+        # observe we have serdag and dag is in dagbag
+        assert SerializedDagModel.has_dag(dag.dag_id) is True
+        assert dagbag.get_dag(dag.dag_id) is not None
+
+        # now delete serdags for this dag
+        SDM = SerializedDagModel
+        sdms = session.scalars(select(SDM).where(SDM.dag_id == dag.dag_id))
+        for sdm in sdms:
+            session.delete(sdm)
+        session.commit()
+
+        # first, confirm that serdags are gone for this dag
         assert SerializedDagModel.has_dag(dag.dag_id) is False
 
+        # now see the dag is still in dagbag
+        assert dagbag.get_dag(dag.dag_id) is not None
+
+        # but, let's recreate the dagbag and see if the dag will be there
+        dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False, read_dags_from_db=True)
         assert dagbag.get_dag(dag.dag_id) is None
         assert dag.dag_id not in dagbag.dags
         assert dag.dag_id not in dagbag.dags_last_fetched
diff --git a/airflow-core/tests/unit/models/test_serialized_dag.py b/airflow-core/tests/unit/models/test_serialized_dag.py
index 0667701eaa0..a962311f0ec 100644
--- a/airflow-core/tests/unit/models/test_serialized_dag.py
+++ b/airflow-core/tests/unit/models/test_serialized_dag.py
@@ -484,22 +484,26 @@ class TestSerializedDagModel:
 
         db.clear_db_assets()
 
-    @pytest.mark.parametrize("min_update_interval", [0, 10])
-    @mock.patch.object(DagVersion, "get_latest_version")
-    def test_min_update_interval_is_respected(
-        self, mock_dv_get_latest_version, min_update_interval, dag_maker
-    ):
-        mock_dv_get_latest_version.return_value = None
+    @pytest.mark.parametrize(
+        "provide_interval, new_task, should_write",
+        [
+            (True, True, False),
+            (True, False, False),
+            (False, True, True),
+            (False, False, False),
+        ],
+    )
+    def test_min_update_interval_is_respected(self, provide_interval, new_task, should_write, dag_maker):
+        min_update_interval = 10 if provide_interval else 0
         with dag_maker("dag1") as dag:
             PythonOperator(task_id="task1", python_callable=lambda: None)
-        dag.sync_to_db()
-        SDM.write_dag(dag, bundle_name="testing")
-        # new task
-        PythonOperator(task_id="task2", python_callable=lambda: None, dag=dag)
-        SDM.write_dag(dag, bundle_name="testing", min_update_interval=min_update_interval)
-        if min_update_interval:
-            # Because min_update_interval is 10, DagVersion.get_latest_version would
-            # be called only once:
-            mock_dv_get_latest_version.assert_called_once()
-        else:
-            assert mock_dv_get_latest_version.call_count == 2
+
+        if new_task:
+            PythonOperator(task_id="task2", python_callable=lambda: None, dag=dag)
+
+        did_write = SDM.write_dag(
+            dag,
+            bundle_name="testing",
+            min_update_interval=min_update_interval,
+        )
+        assert did_write is should_write
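The rewritten test_serialized_dag.py test above leans on `SDM.write_dag` returning whether a write actually happened, gated by `min_update_interval`. A toy model of that gate, matching the test's truth table; this is assumed semantics for illustration (the real method also tracks versions and hashes in the database):

import time


class DagWriter:
    def __init__(self):
        self._last_written = {}  # dag_id -> (dag_hash, timestamp)

    def write_dag(self, dag_id, dag_hash, min_update_interval=0):
        """Return True if a new version was written, False if the write was skipped."""
        prev = self._last_written.get(dag_id)
        if prev is not None:
            prev_hash, prev_ts = prev
            if time.monotonic() - prev_ts < min_update_interval:
                return False  # last write is too recent, regardless of changes
            if prev_hash == dag_hash:
                return False  # nothing changed, nothing to write
        self._last_written[dag_id] = (dag_hash, time.monotonic())
        return True


w = DagWriter()
w.write_dag("dag1", "v1")                                          # initial write
assert w.write_dag("dag1", "v2", min_update_interval=10) is False  # interval set, new task
assert w.write_dag("dag1", "v1", min_update_interval=10) is False  # interval set, no change
assert w.write_dag("dag1", "v2", min_update_interval=0) is True    # no interval, new task
assert w.write_dag("dag1", "v2", min_update_interval=0) is False   # no interval, no change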
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index 8f9350d0530..54290a00f5d 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -4719,7 +4719,7 @@ class TestMappedTaskInstanceReceiveValue:
 
         known_versions = []
 
-        with dag_maker(dag_id="test", session=session) as dag:
+        with dag_maker(dag_id="test_89eug7u6f7y", session=session) as dag:
 
             @dag.task
             def show(value, *, ti):
@@ -4727,9 +4727,8 @@ class TestMappedTaskInstanceReceiveValue:
                 known_versions.append(ti.dag_version_id)
 
             show.expand(value=[1, 2, 3])
-        # ensure that there is a dag_version record in the db
-        dag_version = session.merge(DagVersion(dag_id="test", bundle_name="test"))
-        session.commit()
+
+        # get the dag version for the dag
+        dag_version = session.scalar(select(DagVersion).where(DagVersion.dag_id == dag.dag_id))
         dag_maker.create_dagrun(session=session)
         task = dag.get_task("show")
         for ti in session.scalars(select(TI)):
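The pytest_plugin.py change below is the heart of the PR: `_make_serdag` now runs unconditionally in dag_maker's `__exit__`, so the SerializedDagModel, DagVersion, and DagModel rows all exist the moment the context manager closes. In practice a test no longer needs manual `SerializedDagModel.write_dag()` or `sync_to_db()` calls — a hedged sketch of post-change test code, runnable only inside an Airflow test environment with the tests_common plugin loaded:

import pytest

from airflow.providers.standard.operators.empty import EmptyOperator

pytestmark = pytest.mark.db_test


def test_serdag_exists_automatically(dag_maker, session):
    from airflow.models.serialized_dag import SerializedDagModel

    with dag_maker("my_dag", session=session):
        EmptyOperator(task_id="mytask")

    # Before this PR these rows only existed with explicit write_dag()/sync_to_db()
    # calls (or with dag_maker(serialized=True)); now they are always there.
    assert SerializedDagModel.has_dag("my_dag") is True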
diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index f3052e5f935..145e2879308 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -894,7 +894,6 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
         def __exit__(self, type, value, traceback):
             from airflow.configuration import conf
             from airflow.models import DagModel
-            from airflow.models.serialized_dag import SerializedDagModel
 
             dag = self.dag
             dag.__exit__(type, value, traceback)
@@ -903,7 +902,7 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
 
             dag.clear(session=self.session)
             if AIRFLOW_V_3_0_PLUS:
-                dag.bulk_write_to_db(self.bundle_name, None, [dag], session=self.session)
+                dag.bulk_write_to_db(self.bundle_name, self.bundle_version, [dag], session=self.session)
             else:
                 dag.sync_to_db(session=self.session)
 
@@ -916,46 +915,45 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
                     security_manager.sync_perm_for_dag(dag.dag_id, dag.access_control)
 
             self.dag_model = self.session.get(DagModel, dag.dag_id)
-            if self.want_serialized:
-                self.serialized_model = SerializedDagModel(dag)
-                sdm = self.session.scalar(
-                    select(SerializedDagModel).where(
-                        SerializedDagModel.dag_id == dag.dag_id,
-                        SerializedDagModel.dag_hash == self.serialized_model.dag_hash,
-                    )
-                )
-
-                if AIRFLOW_V_3_0_PLUS and self.serialized_model != sdm:
-                    from airflow.models.dag_version import DagVersion
-                    from airflow.models.dagcode import DagCode
+            self._make_serdag(dag)
+            self._bag_dag_compat(self.dag)
 
-                    dagv = DagVersion.write_dag(
-                        dag_id=dag.dag_id,
-                        bundle_name=self.dag_model.bundle_name,
-                        bundle_version=self.dag_model.bundle_version,
-                        session=self.session,
-                    )
-                    self.session.add(dagv)
-                    self.session.flush()
-                    dag_code = DagCode(dagv, dag.fileloc, "Source")
-                    self.session.merge(dag_code)
-                    self.serialized_model.dag_version = dagv
-                    if self.want_activate_assets:
-                        self._activate_assets()
-                if sdm:
-                    sdm._SerializedDagModel__data_cache = (
-                        self.serialized_model._SerializedDagModel__data_cache
-                    )
-                    sdm._data = self.serialized_model._data
-                    self.serialized_model = sdm
-                else:
-                    self.session.merge(self.serialized_model)
-                serialized_dag = self._serialized_dag()
-                self._bag_dag_compat(serialized_dag)
+        def _make_serdag(self, dag):
+            from airflow.models.serialized_dag import SerializedDagModel
+
+            self.serialized_model = SerializedDagModel(dag)
+            sdm = self.session.scalar(
+                select(SerializedDagModel).where(
+                    SerializedDagModel.dag_id == dag.dag_id,
+                    SerializedDagModel.dag_hash == self.serialized_model.dag_hash,
+                )
+            )
+            if AIRFLOW_V_3_0_PLUS and self.serialized_model != sdm:
+                from airflow.models.dag_version import DagVersion
+                from airflow.models.dagcode import DagCode
+
+                dagv = DagVersion.write_dag(
+                    dag_id=dag.dag_id,
+                    bundle_name=self.dag_model.bundle_name,
+                    bundle_version=self.dag_model.bundle_version,
+                    session=self.session,
+                )
+                self.session.add(dagv)
                 self.session.flush()
+                dag_code = DagCode(dagv, dag.fileloc, "Source")
+                self.session.merge(dag_code)
+                self.serialized_model.dag_version = dagv
+                if self.want_activate_assets:
+                    self._activate_assets()
+            if sdm:
+                sdm._SerializedDagModel__data_cache = self.serialized_model._SerializedDagModel__data_cache
+                sdm._data = self.serialized_model._data
+                self.serialized_model = sdm
             else:
-                self._bag_dag_compat(self.dag)
+                self.session.merge(self.serialized_model)
+            serialized_dag = self._serialized_dag()
+            self._bag_dag_compat(serialized_dag)
+            self.session.flush()
 
         def create_dagrun(self, *, logical_date=NOTSET, **kwargs):
             from airflow.utils import timezone
@@ -1074,6 +1072,7 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
             fileloc=None,
             relative_fileloc=None,
             bundle_name=None,
+            bundle_version=None,
             session=None,
             **kwargs,
         ):
@@ -1108,6 +1107,7 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
         self.want_serialized = serialized
         self.want_activate_assets = activate_assets
         self.bundle_name = bundle_name or "dag_maker"
+        self.bundle_version = bundle_version
 
         if AIRFLOW_V_3_0_PLUS:
             from airflow.models.dagbundle import DagBundleModel
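The docker decorator tests below thread a single `session` through every call and invoke `session.expunge_all()` after creating the dag run. Expunging evicts the identity map, so later reads reflect what is actually committed in the database rather than stale cached objects. A self-contained illustration with an in-memory SQLite stand-in (`TI` here is a toy model, not Airflow's TaskInstance):

from sqlalchemy import String, create_engine, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class TI(Base):  # toy stand-in for a task-instance row
    __tablename__ = "ti"
    id: Mapped[int] = mapped_column(primary_key=True)
    state: Mapped[str] = mapped_column(String(20))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as setup:
    setup.add(TI(id=1, state="queued"))
    setup.commit()

with Session(engine, expire_on_commit=False) as session:
    session.get(TI, 1)  # now cached in the session's identity map as "queued"
    session.commit()

    # Another session (think: the running task) updates the same row.
    with Session(engine) as other:
        other.execute(update(TI).where(TI.id == 1).values(state="success"))
        other.commit()

    assert session.get(TI, 1).state == "queued"   # stale: served from the identity map
    session.expunge_all()                         # evict all cached instances
    assert session.get(TI, 1).state == "success"  # fresh load from the database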
multiple_outputs=True, auto_remove="force") def return_dict(number: int): return {"number": number + 1, "43": 43} test_number = 10 - with dag_maker(): + with dag_maker(session=session): ret = return_dict(test_number) - dr = dag_maker.create_dagrun() - ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date) + dr = dag_maker.create_dagrun(session=session) + session.expunge_all() + + ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, session=session) - ti = dr.get_task_instances()[0] - assert ti.xcom_pull(key="number") == test_number + 1 - assert ti.xcom_pull(key="43") == 43 - assert ti.xcom_pull() == {"number": test_number + 1, "43": 43} + ti = dr.get_task_instances(session=session)[0] + assert ti.xcom_pull(key="number", session=session) == test_number + 1 + assert ti.xcom_pull(key="43", session=session) == 43 + assert ti.xcom_pull(session=session) == {"number": test_number + 1, "43": 43} - def test_no_return(self, dag_maker): + def test_no_return(self, dag_maker, session): @task.docker(image="python:3.9-slim", auto_remove="force") def f(): pass - with dag_maker(): + with dag_maker(session=session): ret = f() - dr = dag_maker.create_dagrun() - ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date) - ti = dr.get_task_instances()[0] - assert ti.xcom_pull() is None + dr = dag_maker.create_dagrun(session=session) + session.expunge_all() + + ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, session=session) + ti = dr.get_task_instances(session=session)[0] + assert ti.xcom_pull(session=session) is None def test_call_decorated_multiple_times(self): """Test calling decorated function 21 times in a DAG""" @@ -159,21 +165,22 @@ class TestDockerDecorator: ({"skip_on_exit_code": (100,)}, 101, TaskInstanceState.FAILED), ], ) - def test_skip_docker_operator(self, kwargs, actual_exit_code, expected_state, dag_maker): + def test_skip_docker_operator(self, kwargs, actual_exit_code, expected_state, dag_maker, session): @task.docker(image="python:3.9-slim", auto_remove="force", **kwargs) def f(exit_code): raise SystemExit(exit_code) - with dag_maker(): + with dag_maker(session=session): ret = f(actual_exit_code) - dr = dag_maker.create_dagrun() + dr = dag_maker.create_dagrun(session=session) + session.expunge_all() if expected_state == TaskInstanceState.FAILED: with pytest.raises(AirflowException): - ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date) + ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, session=session) else: - ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date) - ti = dr.get_task_instances()[0] + ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, session=session) + ti = dr.get_task_instances(session=session)[0] assert ti.state == expected_state def test_setup_decorator_with_decorated_docker_task(self, dag_maker): @@ -308,7 +315,7 @@ class TestDockerDecorator: assert ret.operator.docker_url == "unix://var/run/docker.sock" - def test_failing_task(self, dag_maker): + def test_failing_task(self, dag_maker, session): """Test regression #39319 Check the log content of the DockerOperator when the task fails. 
@@ -324,13 +331,15 @@ class TestDockerDecorator:
         log_capture_string = StringBuffer()
         ch = logging.StreamHandler(log_capture_string)
         docker_operator_logger.addHandler(ch)
-        with dag_maker():
+        with dag_maker(session=session):
             ret = f()
-        dr = dag_maker.create_dagrun()
+        dr = dag_maker.create_dagrun(session=session)
+        session.expunge_all()
+
         with pytest.raises(AirflowException):
-            ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date)
-        ti = dr.get_task_instances()[0]
+            ret.operator.run(start_date=dr.logical_date, end_date=dr.logical_date, session=session)
+        ti = dr.get_task_instances(session=session)[0]
         assert ti.state == TaskInstanceState.FAILED
         log_content = str(log_capture_string.getvalue())
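Read together with the conftest.py fixture near the top of this diff, the plumbing added in pytest_plugin.py (`bundle_version` passed through to `bulk_write_to_db`) is what lets the multiple-versions fixture drop its manual `SerializedDagModel.write_dag` calls. A condensed sketch of that fixture's loop, runnable only inside an Airflow test environment:

import datetime

from airflow.providers.standard.operators.empty import EmptyOperator


def make_versions(dag_maker, session, dag_id="dag_with_multiple_versions"):
    for version_number in range(1, 4):
        with dag_maker(
            dag_id,
            session=session,
            bundle_version=f"some_commit_hash{version_number}",
        ):
            # One more task per version means a structural change each time,
            # so a new serialized dag / DagVersion row is written on exit.
            for task_number in range(version_number):
                EmptyOperator(task_id=f"task{task_number + 1}")
        dag_maker.create_dagrun(
            run_id=f"run{version_number}",
            logical_date=datetime.datetime(2020, 1, version_number, tzinfo=datetime.timezone.utc),
            session=session,
        )
    session.commit()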