This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f971232ab4 Add conditional logic for dataset triggering (#37016)
f971232ab4 is described below

commit f971232ab4a636a1f54a349041a7e22476b8b2dc
Author: Daniel Standish <15932138+dstand...@users.noreply.github.com>
AuthorDate: Wed Feb 21 11:24:21 2024 -0800

    Add conditional logic for dataset triggering (#37016)
    
    Add conditional logic for dataset-triggered dags so that we can schedule
    based on dataset1 OR dataset2.
    
    This PR only implements the underlying classes, DatasetAny and DatasetAll.
    In a follow-up PR we will add more convenient syntax for this, specifically
    the | and & operators, e.g. (dataset1 | dataset2) & dataset3.
    
    ---------
    
    Co-authored-by: Ankit Chaurasia <8670962+sunank...@users.noreply.github.com>
    Co-authored-by: Jed Cunningham <66968678+jedcunning...@users.noreply.github.com>
    Co-authored-by: Wei Lee <weilee...@gmail.com>
---
 airflow/models/dag.py                         | 102 +++++++++-----
 airflow/models/dataset.py                     |  49 ++++++-
 airflow/serialization/enums.py                |   2 +
 airflow/serialization/schema.json             |  36 ++++-
 airflow/serialization/serialized_objects.py   |  29 +++-
 airflow/timetables/datasets.py                |  32 +++--
 tests/cli/commands/test_dag_command.py        |  12 +-
 tests/datasets/test_dataset.py                | 196 ++++++++++++++++++++++++++
 tests/serialization/test_dag_serialization.py |  14 +-
 tests/timetables/test_datasets_timetable.py   |   1 -
 10 files changed, 409 insertions(+), 64 deletions(-)
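
For orientation, the new classes compose and can be passed directly as a DAG's
schedule. The sketch below is illustrative only (the dag_id, URIs, task, and
start_date are made up); it assumes the API introduced in this commit, where
DatasetAny and DatasetAll accept Dataset objects and nest arbitrarily:

    import datetime

    from airflow import DAG
    from airflow.datasets import Dataset
    from airflow.models.dataset import DatasetAll, DatasetAny
    from airflow.operators.empty import EmptyOperator

    d1 = Dataset(uri="s3://bucket/ds1")
    d2 = Dataset(uri="s3://bucket/ds2")
    d3 = Dataset(uri="s3://bucket/ds3")

    # Run when d3 has a queued event AND at least one of d1/d2 does.
    with DAG(
        dag_id="example_conditional_datasets",
        start_date=datetime.datetime(2024, 1, 1),
        schedule=DatasetAll(d3, DatasetAny(d1, d2)),
    ):
        EmptyOperator(task_id="do_something")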

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index dd43568657..237759010a 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -18,7 +18,6 @@
 from __future__ import annotations
 
 import asyncio
-import collections
 import copy
 import functools
 import itertools
@@ -31,7 +30,7 @@ import time
 import traceback
 import warnings
 import weakref
-from collections import deque
+from collections import abc, defaultdict, deque
 from contextlib import ExitStack
 from datetime import datetime, timedelta
 from inspect import signature
@@ -99,6 +98,13 @@ from airflow.models.baseoperator import BaseOperator
 from airflow.models.dagcode import DagCode
 from airflow.models.dagpickle import DagPickle
 from airflow.models.dagrun import RUN_ID_REGEX, DagRun
+from airflow.models.dataset import (
+    DatasetAll,
+    DatasetAny,
+    DatasetBooleanCondition,
+    DatasetDagRunQueue,
+    DatasetModel,
+)
 from airflow.models.param import DagParam, ParamsDict
 from airflow.models.taskinstance import (
     Context,
@@ -462,7 +468,7 @@ class DAG(LoggingMixin):
         on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
         on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
         doc_md: str | None = None,
-        params: collections.abc.MutableMapping | None = None,
+        params: abc.MutableMapping | None = None,
         access_control: dict | None = None,
         is_paused_upon_creation: bool | None = None,
         jinja_environment_kwargs: dict | None = None,
@@ -580,25 +586,28 @@ class DAG(LoggingMixin):
 
         self.timetable: Timetable
         self.schedule_interval: ScheduleInterval
-        self.dataset_triggers: Collection[Dataset] = []
-
+        self.dataset_triggers: DatasetBooleanCondition | None = None
+        if isinstance(schedule, (DatasetAll, DatasetAny)):
+            self.dataset_triggers = schedule
         if isinstance(schedule, Collection) and not isinstance(schedule, str):
             from airflow.datasets import Dataset
 
             if not all(isinstance(x, Dataset) for x in schedule):
                 raise ValueError("All elements in 'schedule' should be 
datasets")
-            self.dataset_triggers = list(schedule)
+            self.dataset_triggers = DatasetAll(*schedule)
         elif isinstance(schedule, Timetable):
             timetable = schedule
         elif schedule is not NOTSET:
             schedule_interval = schedule
 
-        if self.dataset_triggers:
+        if isinstance(schedule, DatasetOrTimeSchedule):
+            self.timetable = schedule
+            self.dataset_triggers = self.timetable.datasets
+            self.schedule_interval = self.timetable.summary
+        elif self.dataset_triggers:
             self.timetable = DatasetTriggeredTimetable()
             self.schedule_interval = self.timetable.summary
         elif timetable:
-            if isinstance(timetable, DatasetOrTimeSchedule):
-                self.dataset_triggers = timetable.datasets
             self.timetable = timetable
             self.schedule_interval = self.timetable.summary
         else:
@@ -3156,8 +3165,8 @@ class DAG(LoggingMixin):
             TaskOutletDatasetReference,
         )
 
-        dag_references = collections.defaultdict(set)
-        outlet_references = collections.defaultdict(set)
+        dag_references = defaultdict(set)
+        outlet_references = defaultdict(set)
         # We can't use a set here as we want to preserve order
         outlet_datasets: dict[DatasetModel, None] = {}
         input_datasets: dict[DatasetModel, None] = {}
@@ -3168,12 +3177,13 @@ class DAG(LoggingMixin):
         # later we'll persist them to the database.
         for dag in dags:
             curr_orm_dag = existing_dags.get(dag.dag_id)
-            if not dag.dataset_triggers:
+            if dag.dataset_triggers is None:
                 if curr_orm_dag and curr_orm_dag.schedule_dataset_references:
                     curr_orm_dag.schedule_dataset_references = []
-            for dataset in dag.dataset_triggers:
-                dag_references[dag.dag_id].add(dataset.uri)
-                input_datasets[DatasetModel.from_public(dataset)] = None
+            else:
+                for dataset in dag.dataset_triggers.all_datasets().values():
+                    dag_references[dag.dag_id].add(dataset.uri)
+                    input_datasets[DatasetModel.from_public(dataset)] = None
             curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
             for task in dag.tasks:
                 dataset_outlets = [x for x in task.outlets or [] if isinstance(x, Dataset)]
@@ -3229,7 +3239,7 @@ class DAG(LoggingMixin):
             for obj in dag_refs_stored - dag_refs_needed:
                 session.delete(obj)
 
-        existing_task_outlet_refs_dict = collections.defaultdict(set)
+        existing_task_outlet_refs_dict = defaultdict(set)
         for dag_id, orm_dag in existing_dags.items():
             for todr in orm_dag.task_outlet_dataset_references:
                 existing_task_outlet_refs_dict[(dag_id, todr.task_id)].add(todr)
@@ -3512,7 +3522,7 @@ class DagOwnerAttributes(Base):
 
     @classmethod
     def get_all(cls, session) -> dict[str, dict[str, str]]:
-        dag_links: dict = collections.defaultdict(dict)
+        dag_links: dict = defaultdict(dict)
         for obj in session.scalars(select(cls)):
             dag_links[obj.dag_id].update({obj.owner: obj.link})
         return dag_links
@@ -3781,23 +3791,43 @@ class DagModel(Base):
         you should ensure that any scheduling decisions are made in a single transaction -- as soon as the
         transaction is committed it will be unlocked.
         """
-        from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue as DDRQ
-
-        # these dag ids are triggered by datasets, and they are ready to go.
-        dataset_triggered_dag_info = {
-            x.dag_id: (x.first_queued_time, x.last_queued_time)
-            for x in session.execute(
-                select(
-                    DagScheduleDatasetReference.dag_id,
-                    func.max(DDRQ.created_at).label("last_queued_time"),
-                    func.min(DDRQ.created_at).label("first_queued_time"),
-                )
-                .join(DagScheduleDatasetReference.queue_records, isouter=True)
-                .group_by(DagScheduleDatasetReference.dag_id)
-                .having(func.count() == func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)))
-            )
-        }
-        dataset_triggered_dag_ids = set(dataset_triggered_dag_info)
+        from airflow.models.serialized_dag import SerializedDagModel
+
+        def dag_ready(dag_id: str, cond: DatasetBooleanCondition, statuses: dict) -> bool | None:
+            # if dag was serialized before 2.9 and we *just* upgraded,
+            # we may be dealing with old version.  In that case,
+            # just wait for the dag to be reserialized.
+            try:
+                return cond.evaluate(statuses)
+            except AttributeError:
+                log.warning("dag '%s' has old serialization; skipping dag run creation.", dag_id)
+                return None
+
+        # this loads all the DDRQ records.... may need to limit num dags
+        all_records = session.scalars(select(DatasetDagRunQueue)).all()
+        by_dag = defaultdict(list)
+        for r in all_records:
+            by_dag[r.target_dag_id].append(r)
+        del all_records
+        dag_statuses = {}
+        for dag_id, records in by_dag.items():
+            dag_statuses[dag_id] = {x.dataset.uri: True for x in records}
+        ser_dags = session.scalars(
+            select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys()))
+        ).all()
+        for ser_dag in ser_dags:
+            dag_id = ser_dag.dag_id
+            statuses = dag_statuses[dag_id]
+            if not dag_ready(dag_id, cond=ser_dag.dag.dataset_triggers, statuses=statuses):
+                del by_dag[dag_id]
+                del dag_statuses[dag_id]
+        del dag_statuses
+        dataset_triggered_dag_info = {}
+        for dag_id, records in by_dag.items():
+            times = sorted(x.created_at for x in records)
+            dataset_triggered_dag_info[dag_id] = (times[0], times[-1])
+        del by_dag
+        dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys())
         if dataset_triggered_dag_ids:
             exclusion_list = set(
                 session.scalars(
@@ -3908,7 +3938,7 @@ def dag(
     on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
     on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
     doc_md: str | None = None,
-    params: collections.abc.MutableMapping | None = None,
+    params: abc.MutableMapping | None = None,
     access_control: dict | None = None,
     is_paused_upon_creation: bool | None = None,
     jinja_environment_kwargs: dict | None = None,
@@ -4030,7 +4060,7 @@ class DagContext:
 
     """
 
-    _context_managed_dags: collections.deque[DAG] = deque()
+    _context_managed_dags: deque[DAG] = deque()
     autoregistered_dags: set[tuple[DAG, ModuleType]] = set()
     current_autoregister_module_name: str | None = None
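
In short, the scheduling change above wraps a plain list schedule in DatasetAll,
and dags_needing_dagruns now folds the queued DatasetDagRunQueue rows into a
per-DAG {uri: True} mapping before evaluating the DAG's condition against it. A
rough standalone sketch of that evaluation (URIs illustrative, no database
involved):

    from airflow.datasets import Dataset
    from airflow.models.dataset import DatasetAll

    d1 = Dataset(uri="s3://bucket/a")
    d2 = Dataset(uri="s3://bucket/b")

    cond = DatasetAll(d1, d2)      # what schedule=[d1, d2] becomes internally
    statuses = {d1.uri: True}      # only d1 has a queued dataset event so far

    assert cond.evaluate(statuses) is False  # d2 not yet updated -> not ready
    statuses[d2.uri] = True
    assert cond.evaluate(statuses) is True   # both updated -> create the dag run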
 
diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py
index d9dd8e4bb5..bf28777358 100644
--- a/airflow/models/dataset.py
+++ b/airflow/models/dataset.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+from typing import Callable, Iterable
 from urllib.parse import urlsplit
 
 import sqlalchemy_jsonfield
@@ -208,7 +209,7 @@ class DatasetDagRunQueue(Base):
     dataset_id = Column(Integer, primary_key=True, nullable=False)
     target_dag_id = Column(StringID(), primary_key=True, nullable=False)
     created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
-
+    dataset = relationship("DatasetModel", viewonly=True)
     __tablename__ = "dataset_dag_run_queue"
     __table_args__ = (
         PrimaryKeyConstraint(dataset_id, target_dag_id, name="datasetdagrunqueue_pkey"),
@@ -336,3 +337,49 @@ class DatasetEvent(Base):
         ]:
             args.append(f"{attr}={getattr(self, attr)!r}")
         return f"{self.__class__.__name__}({', '.join(args)})"
+
+
+class DatasetBooleanCondition:
+    """
+    Base class for boolean logic for dataset triggers.
+
+    :meta private:
+    """
+
+    agg_func: Callable[[Iterable], bool]
+
+    def __init__(self, *objects) -> None:
+        self.objects = objects
+
+    def evaluate(self, statuses: dict[str, bool]) -> bool:
+        return self.agg_func(self.eval_one(x, statuses) for x in self.objects)
+
+    def eval_one(self, obj: Dataset | DatasetAny | DatasetAll, statuses) -> bool:
+        if isinstance(obj, Dataset):
+            return statuses.get(obj.uri, False)
+        return obj.evaluate(statuses=statuses)
+
+    def all_datasets(self) -> dict[str, Dataset]:
+        uris = {}
+        for x in self.objects:
+            if isinstance(x, Dataset):
+                if x.uri not in uris:
+                    uris[x.uri] = x
+            else:
+                # keep the first instance
+                for k, v in x.all_datasets().items():
+                    if k not in uris:
+                        uris[k] = v
+        return uris
+
+
+class DatasetAny(DatasetBooleanCondition):
+    """Use to combine datasets schedule references in an "and" relationship."""
+
+    agg_func = any
+
+
+class DatasetAll(DatasetBooleanCondition):
+    """Use to combine datasets schedule references in an "or" relationship."""
+
+    agg_func = all
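
To make the semantics of these classes concrete, a rough usage sketch (the URIs
are illustrative): conditions nest, evaluate() reads a {uri: bool} mapping, and
all_datasets() flattens the tree keeping the first Dataset seen per URI.

    from airflow.datasets import Dataset
    from airflow.models.dataset import DatasetAll, DatasetAny

    d1 = Dataset(uri="s3://abc/123")
    d2 = Dataset(uri="s3://abc/124")
    d3 = Dataset(uri="s3://abc/125")

    # d1 must have been updated, plus at least one of d2/d3.
    cond = DatasetAll(d1, DatasetAny(d2, d3))

    assert cond.evaluate({d1.uri: True, d2.uri: False, d3.uri: True}) is True
    assert cond.evaluate({d1.uri: False, d2.uri: True, d3.uri: True}) is False
    assert set(cond.all_datasets()) == {d1.uri, d2.uri, d3.uri}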
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index 4f95c849c8..963dec580e 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -50,6 +50,8 @@ class DagAttributeTypes(str, Enum):
     PARAM = "param"
     XCOM_REF = "xcomref"
     DATASET = "dataset"
+    DATASET_ANY = "dataset_any"
+    DATASET_ALL = "dataset_all"
     SIMPLE_TASK_INSTANCE = "simple_task_instance"
     BASE_JOB = "Job"
     TASK_INSTANCE = "task_instance"
diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json
index ae7121fd14..71ee0c8006 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -81,6 +81,36 @@
       ],
       "additionalProperties": false
     },
+    "typed_dataset_cond": {
+      "type": "object",
+      "properties": {
+        "__type": {
+            "anyOf": [{
+                "type": "string",
+                "constant": "dataset_or"
+            },
+            {
+                "type": "string",
+                "constant": "dataset_and"
+            }
+            ]
+        },
+        "__var": {
+            "type": "array",
+            "items": {
+                "anyOf": [
+                    {"$ref": "#/definitions/typed_dataset"},
+                    {    "$ref": "#/definitions/typed_dataset_cond"}
+                    ]
+            }
+        }
+      },
+    "required": [
+        "__type",
+        "__var"
+        ],
+        "additionalProperties": false
+    },
     "dict": {
       "description": "A python dictionary containing values of any type",
       "type": "object"
@@ -119,9 +149,9 @@
           ]
         },
         "dataset_triggers": {
-            "type": "array",
-            "items": { "$ref": "#/definitions/typed_dataset" }
-        },
+        "$ref": "#/definitions/typed_dataset_cond"
+
+},
         "owner_links": { "type": "object" },
         "timetable": {
           "type": "object",
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 7adddbab10..5e6073233e 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -42,6 +42,7 @@ from airflow.models.baseoperator import BaseOperator
 from airflow.models.connection import Connection
 from airflow.models.dag import DAG, DagModel, create_timetable
 from airflow.models.dagrun import DagRun
+from airflow.models.dataset import DatasetAll, DatasetAny
 from airflow.models.expandinput import EXPAND_INPUT_EMPTY, create_expand_input, get_map_type_key
 from airflow.models.mappedoperator import MappedOperator
 from airflow.models.param import Param, ParamsDict
@@ -404,6 +405,8 @@ class BaseSerialization:
                 serialized_object[key] = cls.serialize(value)
             elif key == "timetable" and value is not None:
                 serialized_object[key] = encode_timetable(value)
+            elif key == "dataset_triggers":
+                serialized_object[key] = cls.serialize(value)
             else:
                 value = cls.serialize(value)
                 if isinstance(value, dict) and Encoding.TYPE in value:
@@ -497,6 +500,22 @@ class BaseSerialization:
             return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)
         elif isinstance(var, Dataset):
             return cls._encode({"uri": var.uri, "extra": var.extra}, 
type_=DAT.DATASET)
+        elif isinstance(var, DatasetAll):
+            return cls._encode(
+                [
+                    cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models)
+                    for x in var.objects
+                ],
+                type_=DAT.DATASET_ALL,
+            )
+        elif isinstance(var, DatasetAny):
+            return cls._encode(
+                [
+                    cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models)
+                    for x in var.objects
+                ],
+                type_=DAT.DATASET_ANY,
+            )
         elif isinstance(var, SimpleTaskInstance):
             return cls._encode(
                 cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models),
@@ -587,6 +606,10 @@ class BaseSerialization:
             return _XComRef(var)  # Delay deserializing XComArg objects until we have the entire DAG.
         elif type_ == DAT.DATASET:
             return Dataset(**var)
+        elif type_ == DAT.DATASET_ANY:
+            return DatasetAny(*(cls.deserialize(x) for x in var))
+        elif type_ == DAT.DATASET_ALL:
+            return DatasetAll(*(cls.deserialize(x) for x in var))
         elif type_ == DAT.SIMPLE_TASK_INSTANCE:
             return SimpleTaskInstance(**cls.deserialize(var))
         elif type_ == DAT.CONNECTION:
@@ -763,12 +786,14 @@ class DependencyDetector:
         """Detect dependencies set directly on the DAG object."""
         if not dag:
             return
-        for x in dag.dataset_triggers:
+        if not dag.dataset_triggers:
+            return
+        for uri in dag.dataset_triggers.all_datasets().keys():
             yield DagDependency(
                 source="dataset",
                 target=dag.dag_id,
                 dependency_type="dataset",
-                dependency_id=x.uri,
+                dependency_id=uri,
             )
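
For reference, a minimal round-trip sketch of the serialization path added here
(the nested condition below is illustrative): DatasetAny/DatasetAll encode to a
DAT.DATASET_ANY / DAT.DATASET_ALL envelope whose __var holds the serialized
child datasets or nested conditions.

    from airflow.datasets import Dataset
    from airflow.models.dataset import DatasetAll, DatasetAny
    from airflow.serialization.serialized_objects import BaseSerialization

    cond = DatasetAny(
        Dataset(uri="s3://abc/1"),
        DatasetAll(Dataset(uri="s3://abc/2"), Dataset(uri="s3://abc/3")),
    )

    encoded = BaseSerialization.serialize(cond)  # {"__type": "dataset_any", "__var": [...]}
    restored = BaseSerialization.deserialize(encoded)

    assert isinstance(restored, DatasetAny)
    assert restored.evaluate({"s3://abc/1": True}) is True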
 
 
diff --git a/airflow/timetables/datasets.py b/airflow/timetables/datasets.py
index 4904c64e9c..c755df964e 100644
--- a/airflow/timetables/datasets.py
+++ b/airflow/timetables/datasets.py
@@ -17,28 +17,31 @@
 
 from __future__ import annotations
 
-import collections.abc
 import typing
 
-import attrs
-
-from airflow.datasets import Dataset
 from airflow.exceptions import AirflowTimetableInvalid
+from airflow.models.dataset import DatasetAll, DatasetBooleanCondition
 from airflow.timetables.simple import DatasetTriggeredTimetable as DatasetTriggeredSchedule
 from airflow.utils.types import DagRunType
 
 if typing.TYPE_CHECKING:
+    from collections.abc import Collection
+
     import pendulum
 
+    from airflow.datasets import Dataset
     from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
 
 
 class DatasetOrTimeSchedule(DatasetTriggeredSchedule):
     """Combine time-based scheduling with event-based scheduling."""
 
-    def __init__(self, timetable: Timetable, datasets: collections.abc.Collection[Dataset]) -> None:
+    def __init__(self, timetable: Timetable, datasets: Collection[Dataset] | DatasetBooleanCondition) -> None:
         self.timetable = timetable
-        self.datasets = datasets
+        if isinstance(datasets, DatasetBooleanCondition):
+            self.datasets = datasets
+        else:
+            self.datasets = DatasetAll(*datasets)
 
         self.description = f"Triggered by datasets or {timetable.description}"
         self.periodic = timetable.periodic
@@ -52,24 +55,23 @@ class DatasetOrTimeSchedule(DatasetTriggeredSchedule):
         from airflow.serialization.serialized_objects import decode_timetable
 
         return cls(
-            timetable=decode_timetable(data["timetable"]), datasets=[Dataset(**d) for d in data["datasets"]]
+            timetable=decode_timetable(data["timetable"]),
+            # don't need the datasets after deserialization
+            # they are already stored on dataset_triggers attr on DAG
+            # and this is what scheduler looks at
+            datasets=[],
         )
 
     def serialize(self) -> dict[str, typing.Any]:
         from airflow.serialization.serialized_objects import encode_timetable
 
-        return {
-            "timetable": encode_timetable(self.timetable),
-            "datasets": [attrs.asdict(e) for e in self.datasets],
-        }
+        return {"timetable": encode_timetable(self.timetable)}
 
     def validate(self) -> None:
         if isinstance(self.timetable, DatasetTriggeredSchedule):
             raise AirflowTimetableInvalid("cannot nest dataset timetables")
-        if not isinstance(self.datasets, collections.abc.Collection) or not all(
-            isinstance(d, Dataset) for d in self.datasets
-        ):
-            raise AirflowTimetableInvalid("all elements in 'event' must be datasets")
+        if not isinstance(self.datasets, DatasetBooleanCondition):
+            raise AirflowTimetableInvalid("all elements in 'datasets' must be datasets")
 
     @property
     def summary(self) -> str:
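
A short sketch of how the updated constructor might be used now that it accepts
either a collection of datasets or a DatasetAny/DatasetAll condition (the cron
expression and URIs are made up; CronTriggerTimetable is just one example of a
time-based timetable):

    from airflow.datasets import Dataset
    from airflow.models.dataset import DatasetAny
    from airflow.timetables.datasets import DatasetOrTimeSchedule
    from airflow.timetables.trigger import CronTriggerTimetable

    schedule = DatasetOrTimeSchedule(
        timetable=CronTriggerTimetable("0 0 * * *", timezone="UTC"),
        datasets=DatasetAny(Dataset(uri="s3://bucket/x"), Dataset(uri="s3://bucket/y")),
    )
    # `schedule` can then be passed as DAG(schedule=schedule, ...).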
diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py
index 0df2c36f7d..ca47309721 100644
--- a/tests/cli/commands/test_dag_command.py
+++ b/tests/cli/commands/test_dag_command.py
@@ -392,16 +392,24 @@ class TestCliDags:
             disable_retry=False,
         )
 
-    @mock.patch("workday.AfterWorkdayTimetable")
+    @mock.patch("workday.AfterWorkdayTimetable.get_next_workday")
     @mock.patch("airflow.models.taskinstance.TaskInstance.dry_run")
     @mock.patch("airflow.cli.commands.dag_command.DagRun")
-    def test_backfill_with_custom_timetable(self, mock_dagrun, mock_dry_run, mock_AfterWorkdayTimetable):
+    def test_backfill_with_custom_timetable(self, mock_dagrun, mock_dry_run, mock_get_next_workday):
         """
         when calling `dags backfill` on dag with custom timetable, the DagRun object should be created with
          data_intervals.
         """
+
         start_date = DEFAULT_DATE + timedelta(days=1)
         end_date = start_date + timedelta(days=1)
+        workdays = [
+            start_date,
+            start_date + timedelta(days=1),
+            start_date + timedelta(days=2),
+        ]
+        mock_get_next_workday.side_effect = workdays
+
         cli_args = self.parser.parse_args(
             [
                 "dags",
diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py
index 9e9ca99513..dfc8b82ba1 100644
--- a/tests/datasets/test_dataset.py
+++ b/tests/datasets/test_dataset.py
@@ -18,11 +18,25 @@
 from __future__ import annotations
 
 import os
+from collections import defaultdict
 
 import pytest
+from sqlalchemy.sql import select
 
 from airflow.datasets import Dataset
+from airflow.models.dataset import DatasetAll, DatasetAny, DatasetDagRunQueue, DatasetModel
+from airflow.models.serialized_dag import SerializedDagModel
 from airflow.operators.empty import EmptyOperator
+from airflow.serialization.serialized_objects import BaseSerialization, SerializedDAG
+
+
+@pytest.fixture
+def clear_datasets():
+    from tests.test_utils.db import clear_db_datasets
+
+    clear_db_datasets()
+    yield
+    clear_db_datasets()
 
 
 @pytest.mark.parametrize(
@@ -54,3 +68,185 @@ def test_fspath():
     uri = "s3://example_dataset"
     dataset = Dataset(uri=uri)
     assert os.fspath(dataset) == uri
+
+
+@pytest.mark.parametrize(
+    "inputs, scenario, expected",
+    [
+        # Scenarios for DatasetAny
+        ((True, True, True), "any", True),
+        ((True, True, False), "any", True),
+        ((True, False, True), "any", True),
+        ((True, False, False), "any", True),
+        ((False, False, True), "any", True),
+        ((False, True, False), "any", True),
+        ((False, True, True), "any", True),
+        ((False, False, False), "any", False),
+        # Scenarios for DatasetAll
+        ((True, True, True), "all", True),
+        ((True, True, False), "all", False),
+        ((True, False, True), "all", False),
+        ((True, False, False), "all", False),
+        ((False, False, True), "all", False),
+        ((False, True, False), "all", False),
+        ((False, True, True), "all", False),
+        ((False, False, False), "all", False),
+    ],
+)
+def test_dataset_logical_conditions_evaluation_and_serialization(inputs, scenario, expected):
+    class_ = DatasetAny if scenario == "any" else DatasetAll
+    datasets = [Dataset(uri=f"s3://abc/{i}") for i in range(123, 126)]
+    condition = class_(*datasets)
+
+    statuses = {dataset.uri: status for dataset, status in zip(datasets, inputs)}
+    assert (
+        condition.evaluate(statuses) == expected
+    ), f"Condition evaluation failed for inputs {inputs} and scenario 
'{scenario}'"
+
+    # Serialize and deserialize the condition to test persistence
+    serialized = BaseSerialization.serialize(condition)
+    deserialized = BaseSerialization.deserialize(serialized)
+    assert deserialized.evaluate(statuses) == expected, "Serialization round-trip failed"
+
+
+@pytest.mark.parametrize(
+    "status_values, expected_evaluation",
+    [
+        ((False, True, True), False),  # DatasetAll requires all conditions to be True, but d1 is False
+        ((True, True, True), True),  # All conditions are True
+        ((True, False, True), True),  # d1 is True, and DatasetAny condition (d2 or d3 being True) is met
+        ((True, False, False), False),  # d1 is True, but neither d2 nor d3 meet the DatasetAny condition
+    ],
+)
+def test_nested_dataset_conditions_with_serialization(status_values, expected_evaluation):
+    # Define datasets
+    d1 = Dataset(uri="s3://abc/123")
+    d2 = Dataset(uri="s3://abc/124")
+    d3 = Dataset(uri="s3://abc/125")
+
+    # Create a nested condition: DatasetAll with d1 and DatasetAny with d2 and d3
+    nested_condition = DatasetAll(d1, DatasetAny(d2, d3))
+
+    statuses = {
+        d1.uri: status_values[0],
+        d2.uri: status_values[1],
+        d3.uri: status_values[2],
+    }
+
+    assert nested_condition.evaluate(statuses) == expected_evaluation, "Initial evaluation mismatch"
+
+    serialized_condition = BaseSerialization.serialize(nested_condition)
+    deserialized_condition = BaseSerialization.deserialize(serialized_condition)
+
+    assert (
+        deserialized_condition.evaluate(statuses) == expected_evaluation
+    ), "Post-serialization evaluation mismatch"
+
+
+@pytest.fixture
+def create_test_datasets(session):
+    """Fixture to create test datasets and corresponding models."""
+    datasets = [Dataset(uri=f"hello{i}") for i in range(1, 3)]
+    for dataset in datasets:
+        session.add(DatasetModel(uri=dataset.uri))
+    session.commit()
+    return datasets
+
+
+@pytest.mark.db_test
+@pytest.mark.usefixtures("clear_datasets")
+def test_dataset_trigger_setup_and_serialization(session, dag_maker, create_test_datasets):
+    datasets = create_test_datasets
+
+    # Create DAG with dataset triggers
+    with dag_maker(schedule=DatasetAny(*datasets)) as dag:
+        EmptyOperator(task_id="hello")
+
+    # Verify dataset triggers are set up correctly
+    assert isinstance(
+        dag.dataset_triggers, DatasetAny
+    ), "DAG dataset triggers should be an instance of DatasetAny"
+
+    # Serialize and deserialize DAG dataset triggers
+    serialized_trigger = SerializedDAG.serialize(dag.dataset_triggers)
+    deserialized_trigger = SerializedDAG.deserialize(serialized_trigger)
+
+    # Verify serialization and deserialization integrity
+    assert isinstance(
+        deserialized_trigger, DatasetAny
+    ), "Deserialized trigger should maintain type DatasetAny"
+    assert (
+        deserialized_trigger.objects == dag.dataset_triggers.objects
+    ), "Deserialized trigger objects should match original"
+
+
+@pytest.mark.db_test
+@pytest.mark.usefixtures("clear_datasets")
+def test_dataset_dag_run_queue_processing(session, clear_datasets, dag_maker, create_test_datasets):
+    datasets = create_test_datasets
+    dataset_models = session.query(DatasetModel).all()
+
+    with dag_maker(schedule=DatasetAny(*datasets)) as dag:
+        EmptyOperator(task_id="hello")
+
+    # Add DatasetDagRunQueue entries to simulate dataset event processing
+    for dm in dataset_models:
+        session.add(DatasetDagRunQueue(dataset_id=dm.id, target_dag_id=dag.dag_id))
+    session.commit()
+
+    # Fetch and evaluate dataset triggers for all DAGs affected by dataset events
+    records = session.scalars(select(DatasetDagRunQueue)).all()
+    dag_statuses = defaultdict(lambda: defaultdict(bool))
+    for record in records:
+        dag_statuses[record.target_dag_id][record.dataset.uri] = True
+
+    serialized_dags = session.execute(
+        select(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_statuses.keys()))
+    ).fetchall()
+
+    for (serialized_dag,) in serialized_dags:
+        dag = SerializedDAG.deserialize(serialized_dag.data)
+        for dataset_uri, status in dag_statuses[dag.dag_id].items():
+            assert dag.dataset_triggers.evaluate({dataset_uri: status}), "DAG trigger evaluation failed"
+
+
+@pytest.mark.db_test
+@pytest.mark.usefixtures("clear_datasets")
+def test_dag_with_complex_dataset_triggers(session, dag_maker):
+    # Create Dataset instances
+    d1 = Dataset(uri="hello1")
+    d2 = Dataset(uri="hello2")
+
+    # Create and add DatasetModel instances to the session
+    dm1 = DatasetModel(uri=d1.uri)
+    dm2 = DatasetModel(uri=d2.uri)
+    session.add_all([dm1, dm2])
+    session.commit()
+
+    # Setup a DAG with complex dataset triggers (DatasetAny with DatasetAll)
+    with dag_maker(schedule=DatasetAny(d1, DatasetAll(d2, d1))) as dag:
+        EmptyOperator(task_id="hello")
+
+    assert isinstance(
+        dag.dataset_triggers, DatasetAny
+    ), "DAG's dataset trigger should be an instance of DatasetAny"
+    assert any(
+        isinstance(trigger, DatasetAll) for trigger in dag.dataset_triggers.objects
+    ), "DAG's dataset trigger should include DatasetAll"
+
+    serialized_triggers = SerializedDAG.serialize(dag.dataset_triggers)
+
+    deserialized_triggers = SerializedDAG.deserialize(serialized_triggers)
+
+    assert isinstance(
+        deserialized_triggers, DatasetAny
+    ), "Deserialized triggers should be an instance of DatasetAny"
+    assert any(
+        isinstance(trigger, DatasetAll) for trigger in deserialized_triggers.objects
+    ), "Deserialized triggers should include DatasetAll"
+
+    serialized_dag_dict = SerializedDAG.to_dict(dag)["dag"]
+    assert "dataset_triggers" in serialized_dag_dict, "Serialized DAG should 
contain 'dataset_triggers'"
+    assert isinstance(
+        serialized_dag_dict["dataset_triggers"], dict
+    ), "Serialized 'dataset_triggers' should be a dict"
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 8a122592fd..2adc956b6f 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -60,6 +60,7 @@ from airflow.sensors.bash import BashSensor
 from airflow.serialization.enums import Encoding
 from airflow.serialization.json_schema import load_dag_schema_dict
 from airflow.serialization.serialized_objects import (
+    BaseSerialization,
     DagDependency,
     DependencyDetector,
     SerializedBaseOperator,
@@ -212,7 +213,6 @@ serialized_simple_dag_ground_truth = {
             },
         ],
         "schedule_interval": {"__type": "timedelta", "__var": 86400.0},
-        "dataset_triggers": [],
         "timezone": "UTC",
         "_access_control": {
             "__type": "dict",
@@ -551,11 +551,17 @@ class TestStringifiedDAGs:
             "params",
             "_processor_dags_folder",
         }
+        compare_serialization_list = {
+            "dataset_triggers",
+        }
         fields_to_check = dag.get_serialized_fields() - exclusion_list
         for field in fields_to_check:
-            assert getattr(serialized_dag, field) == getattr(
-                dag, field
-            ), f"{dag.dag_id}.{field} does not match"
+            actual = getattr(serialized_dag, field)
+            expected = getattr(dag, field)
+            if field in compare_serialization_list:
+                actual = BaseSerialization.serialize(actual)
+                expected = BaseSerialization.serialize(expected)
+            assert actual == expected, f"{dag.dag_id}.{field} does not match"
         # _processor_dags_folder is only populated at serialization time
         # it's only used when relying on serialized dag to determine a dag's relative path
         assert dag._processor_dags_folder is None
diff --git a/tests/timetables/test_datasets_timetable.py b/tests/timetables/test_datasets_timetable.py
index 8e293888ca..ce58c42a6b 100644
--- a/tests/timetables/test_datasets_timetable.py
+++ b/tests/timetables/test_datasets_timetable.py
@@ -127,7 +127,6 @@ def test_serialization(dataset_timetable: DatasetOrTimeSchedule, monkeypatch: An
     serialized = dataset_timetable.serialize()
     assert serialized == {
         "timetable": "mock_serialized_timetable",
-        "datasets": [{"uri": "test_dataset", "extra": None}],
     }
 
 
