This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new f971232ab4 Add conditional logic for dataset triggering (#37016)
f971232ab4 is described below

commit f971232ab4a636a1f54a349041a7e22476b8b2dc
Author: Daniel Standish <15932138+dstand...@users.noreply.github.com>
AuthorDate: Wed Feb 21 11:24:21 2024 -0800

    Add conditional logic for dataset triggering (#37016)

    Add conditional logic for dataset-triggered dags so that we can schedule
    based on dataset1 OR dataset2. This PR only implements the underlying
    classes, DatasetAny and DatasetAll. In a follow-up PR we will add more
    convenient syntax for this, specifically the | and & symbols, e.g.
    (dataset1 | dataset2) & dataset3.

    ---------

    Co-authored-by: Ankit Chaurasia <8670962+sunank...@users.noreply.github.com>
    Co-authored-by: Jed Cunningham <66968678+jedcunning...@users.noreply.github.com>
    Co-authored-by: Wei Lee <weilee...@gmail.com>
---
 airflow/models/dag.py                         | 102 +++++++++-----
 airflow/models/dataset.py                     |  49 ++++++-
 airflow/serialization/enums.py                |   2 +
 airflow/serialization/schema.json             |  36 ++++-
 airflow/serialization/serialized_objects.py   |  29 +++-
 airflow/timetables/datasets.py                |  32 +++--
 tests/cli/commands/test_dag_command.py        |  12 +-
 tests/datasets/test_dataset.py                | 196 ++++++++++++++++++++++++++
 tests/serialization/test_dag_serialization.py |  14 +-
 tests/timetables/test_datasets_timetable.py   |   1 -
 10 files changed, 409 insertions(+), 64 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py index dd43568657..237759010a 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -18,7 +18,6 @@ from __future__ import annotations import asyncio -import collections import copy import functools import itertools @@ -31,7 +30,7 @@ import time import traceback import warnings import weakref -from collections import deque +from collections import abc, defaultdict, deque from contextlib import ExitStack from datetime import datetime, timedelta from inspect import signature @@ -99,6 +98,13 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.dagcode import DagCode from airflow.models.dagpickle import DagPickle from airflow.models.dagrun import RUN_ID_REGEX, DagRun +from airflow.models.dataset import ( + DatasetAll, + DatasetAny, + DatasetBooleanCondition, + DatasetDagRunQueue, + DatasetModel, +) from airflow.models.param import DagParam, ParamsDict from airflow.models.taskinstance import ( Context, @@ -462,7 +468,7 @@ class DAG(LoggingMixin): on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, doc_md: str | None = None, - params: collections.abc.MutableMapping | None = None, + params: abc.MutableMapping | None = None, access_control: dict | None = None, is_paused_upon_creation: bool | None = None, jinja_environment_kwargs: dict | None = None, @@ -580,25 +586,28 @@ class DAG(LoggingMixin): self.timetable: Timetable self.schedule_interval: ScheduleInterval - self.dataset_triggers: Collection[Dataset] = [] - + self.dataset_triggers: DatasetBooleanCondition | None = None + if isinstance(schedule, (DatasetAll, DatasetAny)): + self.dataset_triggers = schedule if isinstance(schedule, Collection) and not isinstance(schedule, str): from airflow.datasets import Dataset if not all(isinstance(x, Dataset) for x in schedule): raise ValueError("All elements in 'schedule' should be datasets") - self.dataset_triggers = list(schedule) + self.dataset_triggers = 
DatasetAll(*schedule) elif isinstance(schedule, Timetable): timetable = schedule elif schedule is not NOTSET: schedule_interval = schedule - if self.dataset_triggers: + if isinstance(schedule, DatasetOrTimeSchedule): + self.timetable = schedule + self.dataset_triggers = self.timetable.datasets + self.schedule_interval = self.timetable.summary + elif self.dataset_triggers: self.timetable = DatasetTriggeredTimetable() self.schedule_interval = self.timetable.summary elif timetable: - if isinstance(timetable, DatasetOrTimeSchedule): - self.dataset_triggers = timetable.datasets self.timetable = timetable self.schedule_interval = self.timetable.summary else: @@ -3156,8 +3165,8 @@ class DAG(LoggingMixin): TaskOutletDatasetReference, ) - dag_references = collections.defaultdict(set) - outlet_references = collections.defaultdict(set) + dag_references = defaultdict(set) + outlet_references = defaultdict(set) # We can't use a set here as we want to preserve order outlet_datasets: dict[DatasetModel, None] = {} input_datasets: dict[DatasetModel, None] = {} @@ -3168,12 +3177,13 @@ class DAG(LoggingMixin): # later we'll persist them to the database. for dag in dags: curr_orm_dag = existing_dags.get(dag.dag_id) - if not dag.dataset_triggers: + if dag.dataset_triggers is None: if curr_orm_dag and curr_orm_dag.schedule_dataset_references: curr_orm_dag.schedule_dataset_references = [] - for dataset in dag.dataset_triggers: - dag_references[dag.dag_id].add(dataset.uri) - input_datasets[DatasetModel.from_public(dataset)] = None + else: + for dataset in dag.dataset_triggers.all_datasets().values(): + dag_references[dag.dag_id].add(dataset.uri) + input_datasets[DatasetModel.from_public(dataset)] = None curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references for task in dag.tasks: dataset_outlets = [x for x in task.outlets or [] if isinstance(x, Dataset)] @@ -3229,7 +3239,7 @@ class DAG(LoggingMixin): for obj in dag_refs_stored - dag_refs_needed: session.delete(obj) - existing_task_outlet_refs_dict = collections.defaultdict(set) + existing_task_outlet_refs_dict = defaultdict(set) for dag_id, orm_dag in existing_dags.items(): for todr in orm_dag.task_outlet_dataset_references: existing_task_outlet_refs_dict[(dag_id, todr.task_id)].add(todr) @@ -3512,7 +3522,7 @@ class DagOwnerAttributes(Base): @classmethod def get_all(cls, session) -> dict[str, dict[str, str]]: - dag_links: dict = collections.defaultdict(dict) + dag_links: dict = defaultdict(dict) for obj in session.scalars(select(cls)): dag_links[obj.dag_id].update({obj.owner: obj.link}) return dag_links @@ -3781,23 +3791,43 @@ class DagModel(Base): you should ensure that any scheduling decisions are made in a single transaction -- as soon as the transaction is committed it will be unlocked. """ - from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue as DDRQ - - # these dag ids are triggered by datasets, and they are ready to go. 
- dataset_triggered_dag_info = { - x.dag_id: (x.first_queued_time, x.last_queued_time) - for x in session.execute( - select( - DagScheduleDatasetReference.dag_id, - func.max(DDRQ.created_at).label("last_queued_time"), - func.min(DDRQ.created_at).label("first_queued_time"), - ) - .join(DagScheduleDatasetReference.queue_records, isouter=True) - .group_by(DagScheduleDatasetReference.dag_id) - .having(func.count() == func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0))) - ) - } - dataset_triggered_dag_ids = set(dataset_triggered_dag_info) + from airflow.models.serialized_dag import SerializedDagModel + + def dag_ready(dag_id: str, cond: DatasetBooleanCondition, statuses: dict) -> bool | None: + # if dag was serialized before 2.9 and we *just* upgraded, + # we may be dealing with old version. In that case, + # just wait for the dag to be reserialized. + try: + return cond.evaluate(statuses) + except AttributeError: + log.warning("dag '%s' has old serialization; skipping dag run creation.", dag_id) + return None + + # this loads all the DDRQ records.... may need to limit num dags + all_records = session.scalars(select(DatasetDagRunQueue)).all() + by_dag = defaultdict(list) + for r in all_records: + by_dag[r.target_dag_id].append(r) + del all_records + dag_statuses = {} + for dag_id, records in by_dag.items(): + dag_statuses[dag_id] = {x.dataset.uri: True for x in records} + ser_dags = session.scalars( + select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys())) + ).all() + for ser_dag in ser_dags: + dag_id = ser_dag.dag_id + statuses = dag_statuses[dag_id] + if not dag_ready(dag_id, cond=ser_dag.dag.dataset_triggers, statuses=statuses): + del by_dag[dag_id] + del dag_statuses[dag_id] + del dag_statuses + dataset_triggered_dag_info = {} + for dag_id, records in by_dag.items(): + times = sorted(x.created_at for x in records) + dataset_triggered_dag_info[dag_id] = (times[0], times[-1]) + del by_dag + dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys()) if dataset_triggered_dag_ids: exclusion_list = set( session.scalars( @@ -3908,7 +3938,7 @@ def dag( on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, doc_md: str | None = None, - params: collections.abc.MutableMapping | None = None, + params: abc.MutableMapping | None = None, access_control: dict | None = None, is_paused_upon_creation: bool | None = None, jinja_environment_kwargs: dict | None = None, @@ -4030,7 +4060,7 @@ class DagContext: """ - _context_managed_dags: collections.deque[DAG] = deque() + _context_managed_dags: deque[DAG] = deque() autoregistered_dags: set[tuple[DAG, ModuleType]] = set() current_autoregister_module_name: str | None = None diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py index d9dd8e4bb5..bf28777358 100644 --- a/airflow/models/dataset.py +++ b/airflow/models/dataset.py @@ -17,6 +17,7 @@ # under the License. 
from __future__ import annotations +from typing import Callable, Iterable from urllib.parse import urlsplit import sqlalchemy_jsonfield @@ -208,7 +209,7 @@ class DatasetDagRunQueue(Base): dataset_id = Column(Integer, primary_key=True, nullable=False) target_dag_id = Column(StringID(), primary_key=True, nullable=False) created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) - + dataset = relationship("DatasetModel", viewonly=True) __tablename__ = "dataset_dag_run_queue" __table_args__ = ( PrimaryKeyConstraint(dataset_id, target_dag_id, name="datasetdagrunqueue_pkey"), @@ -336,3 +337,49 @@ class DatasetEvent(Base): ]: args.append(f"{attr}={getattr(self, attr)!r}") return f"{self.__class__.__name__}({', '.join(args)})" + + +class DatasetBooleanCondition: + """ + Base class for boolean logic for dataset triggers. + + :meta private: + """ + + agg_func: Callable[[Iterable], bool] + + def __init__(self, *objects) -> None: + self.objects = objects + + def evaluate(self, statuses: dict[str, bool]) -> bool: + return self.agg_func(self.eval_one(x, statuses) for x in self.objects) + + def eval_one(self, obj: Dataset | DatasetAny | DatasetAll, statuses) -> bool: + if isinstance(obj, Dataset): + return statuses.get(obj.uri, False) + return obj.evaluate(statuses=statuses) + + def all_datasets(self) -> dict[str, Dataset]: + uris = {} + for x in self.objects: + if isinstance(x, Dataset): + if x.uri not in uris: + uris[x.uri] = x + else: + # keep the first instance + for k, v in x.all_datasets().items(): + if k not in uris: + uris[k] = v + return uris + + +class DatasetAny(DatasetBooleanCondition): + """Use to combine datasets schedule references in an "and" relationship.""" + + agg_func = any + + +class DatasetAll(DatasetBooleanCondition): + """Use to combine datasets schedule references in an "or" relationship.""" + + agg_func = all diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py index 4f95c849c8..963dec580e 100644 --- a/airflow/serialization/enums.py +++ b/airflow/serialization/enums.py @@ -50,6 +50,8 @@ class DagAttributeTypes(str, Enum): PARAM = "param" XCOM_REF = "xcomref" DATASET = "dataset" + DATASET_ANY = "dataset_any" + DATASET_ALL = "dataset_all" SIMPLE_TASK_INSTANCE = "simple_task_instance" BASE_JOB = "Job" TASK_INSTANCE = "task_instance" diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json index ae7121fd14..71ee0c8006 100644 --- a/airflow/serialization/schema.json +++ b/airflow/serialization/schema.json @@ -81,6 +81,36 @@ ], "additionalProperties": false }, + "typed_dataset_cond": { + "type": "object", + "properties": { + "__type": { + "anyOf": [{ + "type": "string", + "constant": "dataset_or" + }, + { + "type": "string", + "constant": "dataset_and" + } + ] + }, + "__var": { + "type": "array", + "items": { + "anyOf": [ + {"$ref": "#/definitions/typed_dataset"}, + { "$ref": "#/definitions/typed_dataset_cond"} + ] + } + } + }, + "required": [ + "__type", + "__var" + ], + "additionalProperties": false + }, "dict": { "description": "A python dictionary containing values of any type", "type": "object" @@ -119,9 +149,9 @@ ] }, "dataset_triggers": { - "type": "array", - "items": { "$ref": "#/definitions/typed_dataset" } - }, + "$ref": "#/definitions/typed_dataset_cond" + +}, "owner_links": { "type": "object" }, "timetable": { "type": "object", diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 7adddbab10..5e6073233e 100644 --- 
a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -42,6 +42,7 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.connection import Connection from airflow.models.dag import DAG, DagModel, create_timetable from airflow.models.dagrun import DagRun +from airflow.models.dataset import DatasetAll, DatasetAny from airflow.models.expandinput import EXPAND_INPUT_EMPTY, create_expand_input, get_map_type_key from airflow.models.mappedoperator import MappedOperator from airflow.models.param import Param, ParamsDict @@ -404,6 +405,8 @@ class BaseSerialization: serialized_object[key] = cls.serialize(value) elif key == "timetable" and value is not None: serialized_object[key] = encode_timetable(value) + elif key == "dataset_triggers": + serialized_object[key] = cls.serialize(value) else: value = cls.serialize(value) if isinstance(value, dict) and Encoding.TYPE in value: @@ -497,6 +500,22 @@ class BaseSerialization: return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF) elif isinstance(var, Dataset): return cls._encode({"uri": var.uri, "extra": var.extra}, type_=DAT.DATASET) + elif isinstance(var, DatasetAll): + return cls._encode( + [ + cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models) + for x in var.objects + ], + type_=DAT.DATASET_ALL, + ) + elif isinstance(var, DatasetAny): + return cls._encode( + [ + cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models) + for x in var.objects + ], + type_=DAT.DATASET_ANY, + ) elif isinstance(var, SimpleTaskInstance): return cls._encode( cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models), @@ -587,6 +606,10 @@ class BaseSerialization: return _XComRef(var) # Delay deserializing XComArg objects until we have the entire DAG. 
elif type_ == DAT.DATASET: return Dataset(**var) + elif type_ == DAT.DATASET_ANY: + return DatasetAny(*(cls.deserialize(x) for x in var)) + elif type_ == DAT.DATASET_ALL: + return DatasetAll(*(cls.deserialize(x) for x in var)) elif type_ == DAT.SIMPLE_TASK_INSTANCE: return SimpleTaskInstance(**cls.deserialize(var)) elif type_ == DAT.CONNECTION: @@ -763,12 +786,14 @@ class DependencyDetector: """Detect dependencies set directly on the DAG object.""" if not dag: return - for x in dag.dataset_triggers: + if not dag.dataset_triggers: + return + for uri in dag.dataset_triggers.all_datasets().keys(): yield DagDependency( source="dataset", target=dag.dag_id, dependency_type="dataset", - dependency_id=x.uri, + dependency_id=uri, ) diff --git a/airflow/timetables/datasets.py b/airflow/timetables/datasets.py index 4904c64e9c..c755df964e 100644 --- a/airflow/timetables/datasets.py +++ b/airflow/timetables/datasets.py @@ -17,28 +17,31 @@ from __future__ import annotations -import collections.abc import typing -import attrs - -from airflow.datasets import Dataset from airflow.exceptions import AirflowTimetableInvalid +from airflow.models.dataset import DatasetAll, DatasetBooleanCondition from airflow.timetables.simple import DatasetTriggeredTimetable as DatasetTriggeredSchedule from airflow.utils.types import DagRunType if typing.TYPE_CHECKING: + from collections.abc import Collection + import pendulum + from airflow.datasets import Dataset from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable class DatasetOrTimeSchedule(DatasetTriggeredSchedule): """Combine time-based scheduling with event-based scheduling.""" - def __init__(self, timetable: Timetable, datasets: collections.abc.Collection[Dataset]) -> None: + def __init__(self, timetable: Timetable, datasets: Collection[Dataset] | DatasetBooleanCondition) -> None: self.timetable = timetable - self.datasets = datasets + if isinstance(datasets, DatasetBooleanCondition): + self.datasets = datasets + else: + self.datasets = DatasetAll(*datasets) self.description = f"Triggered by datasets or {timetable.description}" self.periodic = timetable.periodic @@ -52,24 +55,23 @@ class DatasetOrTimeSchedule(DatasetTriggeredSchedule): from airflow.serialization.serialized_objects import decode_timetable return cls( - timetable=decode_timetable(data["timetable"]), datasets=[Dataset(**d) for d in data["datasets"]] + timetable=decode_timetable(data["timetable"]), + # don't need the datasets after deserialization + # they are already stored on dataset_triggers attr on DAG + # and this is what scheduler looks at + datasets=[], ) def serialize(self) -> dict[str, typing.Any]: from airflow.serialization.serialized_objects import encode_timetable - return { - "timetable": encode_timetable(self.timetable), - "datasets": [attrs.asdict(e) for e in self.datasets], - } + return {"timetable": encode_timetable(self.timetable)} def validate(self) -> None: if isinstance(self.timetable, DatasetTriggeredSchedule): raise AirflowTimetableInvalid("cannot nest dataset timetables") - if not isinstance(self.datasets, collections.abc.Collection) or not all( - isinstance(d, Dataset) for d in self.datasets - ): - raise AirflowTimetableInvalid("all elements in 'event' must be datasets") + if not isinstance(self.datasets, DatasetBooleanCondition): + raise AirflowTimetableInvalid("all elements in 'datasets' must be datasets") @property def summary(self) -> str: diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py index 
0df2c36f7d..ca47309721 100644 --- a/tests/cli/commands/test_dag_command.py +++ b/tests/cli/commands/test_dag_command.py @@ -392,16 +392,24 @@ class TestCliDags: disable_retry=False, ) - @mock.patch("workday.AfterWorkdayTimetable") + @mock.patch("workday.AfterWorkdayTimetable.get_next_workday") @mock.patch("airflow.models.taskinstance.TaskInstance.dry_run") @mock.patch("airflow.cli.commands.dag_command.DagRun") - def test_backfill_with_custom_timetable(self, mock_dagrun, mock_dry_run, mock_AfterWorkdayTimetable): + def test_backfill_with_custom_timetable(self, mock_dagrun, mock_dry_run, mock_get_next_workday): """ when calling `dags backfill` on dag with custom timetable, the DagRun object should be created with data_intervals. """ + start_date = DEFAULT_DATE + timedelta(days=1) end_date = start_date + timedelta(days=1) + workdays = [ + start_date, + start_date + timedelta(days=1), + start_date + timedelta(days=2), + ] + mock_get_next_workday.side_effect = workdays + cli_args = self.parser.parse_args( [ "dags", diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py index 9e9ca99513..dfc8b82ba1 100644 --- a/tests/datasets/test_dataset.py +++ b/tests/datasets/test_dataset.py @@ -18,11 +18,25 @@ from __future__ import annotations import os +from collections import defaultdict import pytest +from sqlalchemy.sql import select from airflow.datasets import Dataset +from airflow.models.dataset import DatasetAll, DatasetAny, DatasetDagRunQueue, DatasetModel +from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.empty import EmptyOperator +from airflow.serialization.serialized_objects import BaseSerialization, SerializedDAG + + +@pytest.fixture +def clear_datasets(): + from tests.test_utils.db import clear_db_datasets + + clear_db_datasets() + yield + clear_db_datasets() @pytest.mark.parametrize( @@ -54,3 +68,185 @@ def test_fspath(): uri = "s3://example_dataset" dataset = Dataset(uri=uri) assert os.fspath(dataset) == uri + + +@pytest.mark.parametrize( + "inputs, scenario, expected", + [ + # Scenarios for DatasetAny + ((True, True, True), "any", True), + ((True, True, False), "any", True), + ((True, False, True), "any", True), + ((True, False, False), "any", True), + ((False, False, True), "any", True), + ((False, True, False), "any", True), + ((False, True, True), "any", True), + ((False, False, False), "any", False), + # Scenarios for DatasetAll + ((True, True, True), "all", True), + ((True, True, False), "all", False), + ((True, False, True), "all", False), + ((True, False, False), "all", False), + ((False, False, True), "all", False), + ((False, True, False), "all", False), + ((False, True, True), "all", False), + ((False, False, False), "all", False), + ], +) +def test_dataset_logical_conditions_evaluation_and_serialization(inputs, scenario, expected): + class_ = DatasetAny if scenario == "any" else DatasetAll + datasets = [Dataset(uri=f"s3://abc/{i}") for i in range(123, 126)] + condition = class_(*datasets) + + statuses = {dataset.uri: status for dataset, status in zip(datasets, inputs)} + assert ( + condition.evaluate(statuses) == expected + ), f"Condition evaluation failed for inputs {inputs} and scenario '{scenario}'" + + # Serialize and deserialize the condition to test persistence + serialized = BaseSerialization.serialize(condition) + deserialized = BaseSerialization.deserialize(serialized) + assert deserialized.evaluate(statuses) == expected, "Serialization round-trip failed" + + +@pytest.mark.parametrize( + "status_values, 
expected_evaluation", + [ + ((False, True, True), False), # DatasetAll requires all conditions to be True, but d1 is False + ((True, True, True), True), # All conditions are True + ((True, False, True), True), # d1 is True, and DatasetAny condition (d2 or d3 being True) is met + ((True, False, False), False), # d1 is True, but neither d2 nor d3 meet the DatasetAny condition + ], +) +def test_nested_dataset_conditions_with_serialization(status_values, expected_evaluation): + # Define datasets + d1 = Dataset(uri="s3://abc/123") + d2 = Dataset(uri="s3://abc/124") + d3 = Dataset(uri="s3://abc/125") + + # Create a nested condition: DatasetAll with d1 and DatasetAny with d2 and d3 + nested_condition = DatasetAll(d1, DatasetAny(d2, d3)) + + statuses = { + d1.uri: status_values[0], + d2.uri: status_values[1], + d3.uri: status_values[2], + } + + assert nested_condition.evaluate(statuses) == expected_evaluation, "Initial evaluation mismatch" + + serialized_condition = BaseSerialization.serialize(nested_condition) + deserialized_condition = BaseSerialization.deserialize(serialized_condition) + + assert ( + deserialized_condition.evaluate(statuses) == expected_evaluation + ), "Post-serialization evaluation mismatch" + + +@pytest.fixture +def create_test_datasets(session): + """Fixture to create test datasets and corresponding models.""" + datasets = [Dataset(uri=f"hello{i}") for i in range(1, 3)] + for dataset in datasets: + session.add(DatasetModel(uri=dataset.uri)) + session.commit() + return datasets + + +@pytest.mark.db_test +@pytest.mark.usefixtures("clear_datasets") +def test_dataset_trigger_setup_and_serialization(session, dag_maker, create_test_datasets): + datasets = create_test_datasets + + # Create DAG with dataset triggers + with dag_maker(schedule=DatasetAny(*datasets)) as dag: + EmptyOperator(task_id="hello") + + # Verify dataset triggers are set up correctly + assert isinstance( + dag.dataset_triggers, DatasetAny + ), "DAG dataset triggers should be an instance of DatasetAny" + + # Serialize and deserialize DAG dataset triggers + serialized_trigger = SerializedDAG.serialize(dag.dataset_triggers) + deserialized_trigger = SerializedDAG.deserialize(serialized_trigger) + + # Verify serialization and deserialization integrity + assert isinstance( + deserialized_trigger, DatasetAny + ), "Deserialized trigger should maintain type DatasetAny" + assert ( + deserialized_trigger.objects == dag.dataset_triggers.objects + ), "Deserialized trigger objects should match original" + + +@pytest.mark.db_test +@pytest.mark.usefixtures("clear_datasets") +def test_dataset_dag_run_queue_processing(session, clear_datasets, dag_maker, create_test_datasets): + datasets = create_test_datasets + dataset_models = session.query(DatasetModel).all() + + with dag_maker(schedule=DatasetAny(*datasets)) as dag: + EmptyOperator(task_id="hello") + + # Add DatasetDagRunQueue entries to simulate dataset event processing + for dm in dataset_models: + session.add(DatasetDagRunQueue(dataset_id=dm.id, target_dag_id=dag.dag_id)) + session.commit() + + # Fetch and evaluate dataset triggers for all DAGs affected by dataset events + records = session.scalars(select(DatasetDagRunQueue)).all() + dag_statuses = defaultdict(lambda: defaultdict(bool)) + for record in records: + dag_statuses[record.target_dag_id][record.dataset.uri] = True + + serialized_dags = session.execute( + select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys())) + ).fetchall() + + for (serialized_dag,) in serialized_dags: + dag = 
SerializedDAG.deserialize(serialized_dag.data) + for dataset_uri, status in dag_statuses[dag.dag_id].items(): + assert dag.dataset_triggers.evaluate({dataset_uri: status}), "DAG trigger evaluation failed" + + +@pytest.mark.db_test +@pytest.mark.usefixtures("clear_datasets") +def test_dag_with_complex_dataset_triggers(session, dag_maker): + # Create Dataset instances + d1 = Dataset(uri="hello1") + d2 = Dataset(uri="hello2") + + # Create and add DatasetModel instances to the session + dm1 = DatasetModel(uri=d1.uri) + dm2 = DatasetModel(uri=d2.uri) + session.add_all([dm1, dm2]) + session.commit() + + # Setup a DAG with complex dataset triggers (DatasetAny with DatasetAll) + with dag_maker(schedule=DatasetAny(d1, DatasetAll(d2, d1))) as dag: + EmptyOperator(task_id="hello") + + assert isinstance( + dag.dataset_triggers, DatasetAny + ), "DAG's dataset trigger should be an instance of DatasetAny" + assert any( + isinstance(trigger, DatasetAll) for trigger in dag.dataset_triggers.objects + ), "DAG's dataset trigger should include DatasetAll" + + serialized_triggers = SerializedDAG.serialize(dag.dataset_triggers) + + deserialized_triggers = SerializedDAG.deserialize(serialized_triggers) + + assert isinstance( + deserialized_triggers, DatasetAny + ), "Deserialized triggers should be an instance of DatasetAny" + assert any( + isinstance(trigger, DatasetAll) for trigger in deserialized_triggers.objects + ), "Deserialized triggers should include DatasetAll" + + serialized_dag_dict = SerializedDAG.to_dict(dag)["dag"] + assert "dataset_triggers" in serialized_dag_dict, "Serialized DAG should contain 'dataset_triggers'" + assert isinstance( + serialized_dag_dict["dataset_triggers"], dict + ), "Serialized 'dataset_triggers' should be a dict" diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 8a122592fd..2adc956b6f 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -60,6 +60,7 @@ from airflow.sensors.bash import BashSensor from airflow.serialization.enums import Encoding from airflow.serialization.json_schema import load_dag_schema_dict from airflow.serialization.serialized_objects import ( + BaseSerialization, DagDependency, DependencyDetector, SerializedBaseOperator, @@ -212,7 +213,6 @@ serialized_simple_dag_ground_truth = { }, ], "schedule_interval": {"__type": "timedelta", "__var": 86400.0}, - "dataset_triggers": [], "timezone": "UTC", "_access_control": { "__type": "dict", @@ -551,11 +551,17 @@ class TestStringifiedDAGs: "params", "_processor_dags_folder", } + compare_serialization_list = { + "dataset_triggers", + } fields_to_check = dag.get_serialized_fields() - exclusion_list for field in fields_to_check: - assert getattr(serialized_dag, field) == getattr( - dag, field - ), f"{dag.dag_id}.{field} does not match" + actual = getattr(serialized_dag, field) + expected = getattr(dag, field) + if field in compare_serialization_list: + actual = BaseSerialization.serialize(actual) + expected = BaseSerialization.serialize(expected) + assert actual == expected, f"{dag.dag_id}.{field} does not match" # _processor_dags_folder is only populated at serialization time # it's only used when relying on serialized dag to determine a dag's relative path assert dag._processor_dags_folder is None diff --git a/tests/timetables/test_datasets_timetable.py b/tests/timetables/test_datasets_timetable.py index 8e293888ca..ce58c42a6b 100644 --- a/tests/timetables/test_datasets_timetable.py +++ 
b/tests/timetables/test_datasets_timetable.py @@ -127,7 +127,6 @@ def test_serialization(dataset_timetable: DatasetOrTimeSchedule, monkeypatch: An serialized = dataset_timetable.serialize() assert serialized == { "timetable": "mock_serialized_timetable", - "datasets": [{"uri": "test_dataset", "extra": None}], }
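
For readers skimming the diff: below is a minimal usage sketch of the new conditional scheduling. It is not part of the commit; it is modeled on the tests added in tests/datasets/test_dataset.py, and the dag_id, dataset URIs, and start_date are illustrative only.

import datetime

from airflow.datasets import Dataset
from airflow.models.dag import DAG
from airflow.models.dataset import DatasetAll, DatasetAny
from airflow.operators.empty import EmptyOperator

d1 = Dataset(uri="s3://abc/123")
d2 = Dataset(uri="s3://abc/124")
d3 = Dataset(uri="s3://abc/125")

# Trigger when d1 is updated, or when both d2 and d3 have been updated.
with DAG(
    dag_id="example_conditional_dataset_schedule",  # hypothetical name
    start_date=datetime.datetime(2024, 1, 1),
    schedule=DatasetAny(d1, DatasetAll(d2, d3)),
):
    EmptyOperator(task_id="consume")

# The scheduler evaluates the same condition against a {dataset uri: seen}
# mapping built from DatasetDagRunQueue records (see dags_needing_dagruns above).
cond = DatasetAny(d1, DatasetAll(d2, d3))
assert cond.evaluate({d1.uri: True})                # d1 alone satisfies the "any"
assert not cond.evaluate({d2.uri: True})            # the "all" branch still needs d3
assert cond.evaluate({d2.uri: True, d3.uri: True})  # the DatasetAll branch satisfies the "any"

Note that the | and & operator syntax mentioned in the commit message is not part of this commit; for now DatasetAny and DatasetAll are instantiated directly.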