This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 28b3f6eb921 Decouple deadline reference types from core in task SDK (#61461)
28b3f6eb921 is described below

commit 28b3f6eb9213b794a55a147c0ed70a1fe8f9ad26
Author: Amogh Desai <[email protected]>
AuthorDate: Wed Mar 11 14:03:57 2026 +0530

    Decouple deadline reference types from core in task SDK (#61461)
    
    Custom deadline references now serialize and deserialize using a wrapper pattern.
---
 airflow-core/docs/howto/deadline-alerts.rst        |   7 +-
 airflow-core/src/airflow/models/deadline.py        |   6 +-
 airflow-core/src/airflow/models/deadline_alert.py  |   7 +-
 airflow-core/src/airflow/serialization/decoders.py |  17 ++-
 .../airflow/serialization/definitions/deadline.py  |  91 ++++++++++-
 airflow-core/src/airflow/serialization/encoders.py |  28 +++-
 airflow-core/tests/unit/models/test_deadline.py    | 103 ++++++++-----
 .../tests/unit/models/test_deadline_alert.py       |  58 ++++++-
 task-sdk/src/airflow/sdk/definitions/deadline.py   | 170 +++++++++++++++++----
 9 files changed, 396 insertions(+), 91 deletions(-)

diff --git a/airflow-core/docs/howto/deadline-alerts.rst b/airflow-core/docs/howto/deadline-alerts.rst
index 1ed9750bf4e..64f39c02440 100644
--- a/airflow-core/docs/howto/deadline-alerts.rst
+++ b/airflow-core/docs/howto/deadline-alerts.rst
@@ -425,17 +425,16 @@ implement an ``_evaluate_with()`` method.
 
 .. code-block:: python
 
-    from airflow.models.deadline import ReferenceModels
     from sqlalchemy.orm import Session
 
     from airflow.sdk import DeadlineReference
-    from airflow.sdk.definitions.deadline import deadline_reference
+    from airflow.sdk.definitions.deadline import BaseDeadlineReference, deadline_reference
     from airflow.sdk.timezone import datetime
 
 
     # By default, the evaluate_with method will be executed when the dagrun is created.
     @deadline_reference()
-    class MyCustomDecoratedReference(ReferenceModels.BaseDeadlineReference):
+    class MyCustomDecoratedReference(BaseDeadlineReference):
         """A custom reference evaluated when Dag runs are created."""
 
         def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
@@ -445,7 +444,7 @@ implement an ``_evaluate_with()`` method.
 
     # You can specify when evaluate_with will be called by providing a DeadlineReference.TYPES value.
     @deadline_reference(DeadlineReference.TYPES.DAGRUN_QUEUED)
-    class MyQueuedReference(ReferenceModels.BaseDeadlineReference):
+    class MyQueuedReference(BaseDeadlineReference):
         """A custom reference evaluated when Dag runs are queued."""
 
         required_kwargs = {"custom_param"}
diff --git a/airflow-core/src/airflow/models/deadline.py b/airflow-core/src/airflow/models/deadline.py
index debfe949b31..ec3ab5ad99c 100644
--- a/airflow-core/src/airflow/models/deadline.py
+++ b/airflow-core/src/airflow/models/deadline.py
@@ -314,7 +314,11 @@ class ReferenceModels:
                 )
 
             if extra_kwargs := kwargs.keys() - filtered_kwargs.keys():
-                self.log.debug("Ignoring unexpected parameters: %s", ", ".join(extra_kwargs))
+                self.log.debug(
+                    "%s ignoring unexpected parameters: %s",
+                    self.reference_name,
+                    ", ".join(extra_kwargs),
+                )
 
             base_time = self._evaluate_with(session=session, **filtered_kwargs)
             return base_time + interval if base_time is not None else None
diff --git a/airflow-core/src/airflow/models/deadline_alert.py b/airflow-core/src/airflow/models/deadline_alert.py
index 8afc35d7560..0b8a8eba9b1 100644
--- a/airflow-core/src/airflow/models/deadline_alert.py
+++ b/airflow-core/src/airflow/models/deadline_alert.py
@@ -86,9 +86,10 @@ class DeadlineAlert(Base):
     @property
     def reference_class(self) -> type[SerializedReferenceModels.SerializedBaseDeadlineReference]:
         """Return the deserialized reference class."""
-        return SerializedReferenceModels.get_reference_class(
-            self.reference[SerializedReferenceModels.REFERENCE_TYPE_FIELD]
-        )
+        ref_name = self.reference.get(SerializedReferenceModels.REFERENCE_TYPE_FIELD)
+        if ref_name and SerializedReferenceModels.is_builtin_reference(ref_name):
+            return SerializedReferenceModels.get_reference_class(ref_name)
+        return SerializedReferenceModels.SerializedCustomReference
 
     @classmethod
     @provide_session
diff --git a/airflow-core/src/airflow/serialization/decoders.py b/airflow-core/src/airflow/serialization/decoders.py
index a438010424f..8d19a196e92 100644
--- a/airflow-core/src/airflow/serialization/decoders.py
+++ b/airflow-core/src/airflow/serialization/decoders.py
@@ -136,6 +136,18 @@ def decode_asset_like(var: dict[str, Any]) -> SerializedAssetBase:
             raise ValueError(f"deserialization not implemented for DAT {data_type!r}")
 
 
+def decode_deadline_reference(reference_data: dict):
+    """Decode a previously serialized deadline reference."""
+    ref_name = reference_data.get(SerializedReferenceModels.REFERENCE_TYPE_FIELD)
+
+    if ref_name and SerializedReferenceModels.is_builtin_reference(ref_name):
+        reference_class = SerializedReferenceModels.get_reference_class(ref_name)
+    else:
+        reference_class = SerializedReferenceModels.SerializedCustomReference
+
+    return reference_class.deserialize_reference(reference_data)
+
+
 def decode_deadline_alert(encoded_data: dict):
     """
     Decode a previously serialized deadline alert.
@@ -147,10 +159,7 @@ def decode_deadline_alert(encoded_data: dict):
     data = encoded_data.get(Encoding.VAR, encoded_data)
 
     reference_data = data[DeadlineAlertFields.REFERENCE]
-    reference_type = reference_data[SerializedReferenceModels.REFERENCE_TYPE_FIELD]
-
-    reference_class = SerializedReferenceModels.get_reference_class(reference_type)
-    reference = reference_class.deserialize_reference(reference_data)
+    reference = decode_deadline_reference(reference_data)
 
     return SerializedDeadlineAlert(
         reference=reference,
diff --git a/airflow-core/src/airflow/serialization/definitions/deadline.py b/airflow-core/src/airflow/serialization/definitions/deadline.py
index 78adc6b9a76..93af9ef19e7 100644
--- a/airflow-core/src/airflow/serialization/definitions/deadline.py
+++ b/airflow-core/src/airflow/serialization/definitions/deadline.py
@@ -20,6 +20,7 @@ import logging
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from datetime import datetime, timedelta
+from inspect import isclass
 from typing import TYPE_CHECKING, Any
 
 import attrs
@@ -62,6 +63,15 @@ class SerializedReferenceModels:
 
     REFERENCE_TYPE_FIELD = "reference_type"
 
+    @classmethod
+    def is_builtin_reference(cls, ref_name: str) -> bool:
+        """Check if a reference type is a built-in reference."""
+        return any(
+            r.__name__ == ref_name
+            for r in vars(cls).values()
+            if isclass(r) and issubclass(r, cls.SerializedBaseDeadlineReference)
+        )
+
     @classmethod
     def get_reference_class(cls, reference_name: str) -> type[SerializedBaseDeadlineReference]:
         """
@@ -99,7 +109,11 @@ class SerializedReferenceModels:
                 )
 
             if extra_kwargs := kwargs.keys() - filtered_kwargs.keys():
-                self.log.debug("Ignoring unexpected parameters: %s", ", ".join(extra_kwargs))
+                self.log.debug(
+                    "%s ignoring unexpected parameters: %s",
+                    self.reference_name,
+                    ", ".join(extra_kwargs),
+                )
 
             base_time = self._evaluate_with(session=session, **filtered_kwargs)
             return base_time + interval if base_time is not None else None
@@ -225,8 +239,19 @@ class SerializedReferenceModels:
                 )
                 return None
 
-            avg_duration_seconds = sum(durations) / len(durations)
-            return timezone.utcnow() + timedelta(seconds=avg_duration_seconds)
+            # Convert to float to handle Decimal types from MySQL while preserving precision
+            # Use Decimal arithmetic for higher precision, then convert to float
+            from decimal import Decimal
+
+            decimal_durations = [Decimal(str(d)) for d in durations]
+            avg_seconds = float(sum(decimal_durations) / len(decimal_durations))
+            logger.info(
+                "Average runtime for dag_id %s (from %d runs): %.2f seconds",
+                dag_id,
+                len(durations),
+                avg_seconds,
+            )
+            return timezone.utcnow() + timedelta(seconds=avg_seconds)
 
         def serialize_reference(self) -> dict:
             return {
@@ -239,6 +264,62 @@ class SerializedReferenceModels:
         def deserialize_reference(cls, reference_data: dict):
             return cls(max_runs=reference_data["max_runs"], min_runs=reference_data.get("min_runs"))
 
+    class SerializedCustomReference(SerializedBaseDeadlineReference):
+        """
+        Wrapper for custom deadline references.
+
+        This class dynamically delegates to the wrapped reference for required_kwargs and evaluation logic.
+        """
+
+        def __init__(self, inner_ref):
+            self.inner_ref = inner_ref
+
+        @property
+        def reference_name(self) -> str:
+            return self.inner_ref.reference_name
+
+        def evaluate_with(self, *, session: Session, interval: timedelta, **kwargs: Any) -> datetime | None:
+            """Validate the provided kwargs and evaluate this deadline with the given conditions."""
+            required_kwargs: set[str] = getattr(self.inner_ref, "required_kwargs", set())
+            filtered_kwargs = {k: v for k, v in kwargs.items() if k in required_kwargs}
+
+            if missing_kwargs := required_kwargs - filtered_kwargs.keys():
+                raise ValueError(
+                    f"{self.inner_ref.__class__.__name__} is missing required parameters: {', '.join(missing_kwargs)}"
+                )
+
+            if extra_kwargs := kwargs.keys() - filtered_kwargs.keys():
+                self.log.debug(
+                    "%s ignoring unexpected parameters: %s",
+                    self.reference_name,
+                    ", ".join(extra_kwargs),
+                )
+
+            deadline = self.inner_ref._evaluate_with(session=session, **filtered_kwargs)
+            return deadline + interval if deadline is not None else None
+
+        def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
+            return self.inner_ref._evaluate_with(session=session, **kwargs)
+
+        def serialize_reference(self) -> dict:
+            return self.inner_ref.serialize_reference()
+
+        @classmethod
+        def deserialize_reference(cls, reference_data: dict):
+            from airflow._shared.module_loading import import_string
+
+            custom_class = import_string(reference_data["__class_path"])
+            inner_ref = custom_class.deserialize_reference(reference_data)
+            return cls(inner_ref)
+
+        def __eq__(self, other) -> bool:
+            if not isinstance(other, SerializedReferenceModels.SerializedCustomReference):
+                return False
+            return self.inner_ref == other.inner_ref
+
+        def __hash__(self) -> int:
+            return hash(self.inner_ref)
+
     class TYPES:
         """Collection of SerializedDeadlineReference types for type checking."""
 
@@ -259,7 +340,9 @@ SerializedReferenceModels.TYPES.DAGRUN_CREATED = (
 )
 SerializedReferenceModels.TYPES.DAGRUN_QUEUED = (SerializedReferenceModels.DagRunQueuedAtDeadline,)
 SerializedReferenceModels.TYPES.DAGRUN = (
-    SerializedReferenceModels.TYPES.DAGRUN_CREATED + SerializedReferenceModels.TYPES.DAGRUN_QUEUED
+    *SerializedReferenceModels.TYPES.DAGRUN_CREATED,
+    *SerializedReferenceModels.TYPES.DAGRUN_QUEUED,
+    SerializedReferenceModels.SerializedCustomReference,
 )
 
 
diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py
index 7db97b844f6..239b2b1b97b 100644
--- a/airflow-core/src/airflow/serialization/encoders.py
+++ b/airflow-core/src/airflow/serialization/encoders.py
@@ -202,19 +202,43 @@ def encode_deadline_alert(d: DeadlineAlert | SerializedDeadlineAlert) -> dict[st
     from airflow.sdk.serde import serialize
 
     return {
-        "reference": d.reference.serialize_reference(),
+        "reference": encode_deadline_reference(d.reference),
         "interval": d.interval.total_seconds(),
         "callback": serialize(d.callback),
     }
 
 
+_BUILTIN_DEADLINE_MODULES = (
+    "airflow.sdk.definitions.deadline",
+    "airflow.serialization.definitions.deadline",
+    # Include airflow.models.deadline to treat core's deadline references as builtins.
+    # This is to maintain backcompat with 3.1.x custom refs that inherit from
+    # airflow.models.deadline.ReferenceModels.BaseDeadlineReference.
+    "airflow.models.deadline",
+)
+
+
 def encode_deadline_reference(ref) -> dict[str, Any]:
     """
     Encode a deadline reference.
 
+    For custom (non-builtin) deadline references, includes the class path
+    so the decoder can import the user's class at runtime.
+
     :meta private:
     """
-    return ref.serialize_reference()
+    from airflow._shared.module_loading import qualname
+
+    serialized = ref.serialize_reference()
+
+    # Custom types (not built-in) need __class_path so the decoder can import them.
+    # Unlike built-in types which are looked up in SerializedReferenceModels,
+    # custom types are discovered via import_string(__class_path) at deserialization time.
+    module = type(ref).__module__
+    if module not in _BUILTIN_DEADLINE_MODULES:
+        serialized["__class_path"] = qualname(ref)
+
+    return serialized
 
 
 def _get_serialized_timetable_import_path(var: BaseTimetable | CoreTimetable) -> str:
diff --git a/airflow-core/tests/unit/models/test_deadline.py b/airflow-core/tests/unit/models/test_deadline.py
index f4f435291ce..94c6977ae0c 100644
--- a/airflow-core/tests/unit/models/test_deadline.py
+++ b/airflow-core/tests/unit/models/test_deadline.py
@@ -28,11 +28,20 @@ from sqlalchemy.exc import SQLAlchemyError
 
 from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse
 from airflow.models import DagRun
-from airflow.models.deadline import Deadline, ReferenceModels, _fetch_from_db
+from airflow.models.deadline import Deadline, _fetch_from_db
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.sdk import timezone
 from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback
-from airflow.sdk.definitions.deadline import DeadlineReference, deadline_reference
+from airflow.sdk.definitions.deadline import (
+    AverageRuntimeDeadline,
+    BaseDeadlineReference,
+    DagRunLogicalDateDeadline,
+    DagRunQueuedAtDeadline,
+    DeadlineReference,
+    FixedDatetimeDeadline,
+    deadline_reference,
+)
+from airflow.serialization.definitions.deadline import SerializedReferenceModels
 from airflow.utils.state import DagRunState
 
 from tests_common.test_utils import db
@@ -46,10 +55,12 @@ INVALID_DAG_ID = "invalid_dag_id"
 INVALID_RUN_ID = -1
 
 REFERENCE_TYPES = [
-    pytest.param(DeadlineReference.DAGRUN_LOGICAL_DATE, id="logical_date"),
-    pytest.param(DeadlineReference.DAGRUN_QUEUED_AT, id="queued_at"),
-    pytest.param(DeadlineReference.FIXED_DATETIME(DEFAULT_DATE), id="fixed_deadline"),
-    pytest.param(DeadlineReference.AVERAGE_RUNTIME(), id="average_runtime"),
+    pytest.param(SerializedReferenceModels.DagRunLogicalDateDeadline(), id="logical_date"),
+    pytest.param(SerializedReferenceModels.DagRunQueuedAtDeadline(), id="queued_at"),
+    pytest.param(SerializedReferenceModels.FixedDatetimeDeadline(DEFAULT_DATE), id="fixed_deadline"),
+    pytest.param(
+        SerializedReferenceModels.AverageRuntimeDeadline(max_runs=10, min_runs=10), id="average_runtime"
+    ),
 ]
 
 
@@ -320,10 +331,20 @@ class TestCalculatedDeadlineDatabaseCalls:
     @pytest.mark.parametrize(
         ("reference", "expected_column"),
         [
-            pytest.param(DeadlineReference.DAGRUN_LOGICAL_DATE, DagRun.logical_date, id="logical_date"),
-            pytest.param(DeadlineReference.DAGRUN_QUEUED_AT, DagRun.queued_at, id="queued_at"),
-            pytest.param(DeadlineReference.FIXED_DATETIME(DEFAULT_DATE), None, id="fixed_deadline"),
-            pytest.param(DeadlineReference.AVERAGE_RUNTIME(), None, id="average_runtime"),
+            pytest.param(
+                SerializedReferenceModels.DagRunLogicalDateDeadline(), DagRun.logical_date, id="logical_date"
+            ),
+            pytest.param(
+                SerializedReferenceModels.DagRunQueuedAtDeadline(), DagRun.queued_at, id="queued_at"
+            ),
+            pytest.param(
+                SerializedReferenceModels.FixedDatetimeDeadline(DEFAULT_DATE), None, id="fixed_deadline"
+            ),
+            pytest.param(
+                SerializedReferenceModels.AverageRuntimeDeadline(max_runs=10, min_runs=10),
+                None,
+                id="average_runtime",
+            ),
         ],
     )
     def test_deadline_database_integration(self, reference, expected_column, session):
@@ -337,13 +358,13 @@ class TestCalculatedDeadlineDatabaseCalls:
         """
         conditions = {"dag_id": DAG_ID, "run_id": "dagrun_1"}
         interval = timedelta(hours=1)
-        with mock.patch("airflow.models.deadline._fetch_from_db") as mock_fetch:
+        with mock.patch("airflow.serialization.definitions.deadline._fetch_from_db") as mock_fetch:
             mock_fetch.return_value = DEFAULT_DATE
 
             if expected_column is not None:
                 result = reference.evaluate_with(session=session, interval=interval, **conditions)
                 mock_fetch.assert_called_once_with(expected_column, session=session, **conditions)
-            elif reference == DeadlineReference.AVERAGE_RUNTIME():
+            elif isinstance(reference, SerializedReferenceModels.AverageRuntimeDeadline):
                 with mock.patch("airflow._shared.timezones.timezone.utcnow") as mock_utcnow:
                     mock_utcnow.return_value = DEFAULT_DATE
                     # No DAG runs exist, so it should use 24-hour default
@@ -380,7 +401,7 @@ class TestCalculatedDeadlineDatabaseCalls:
         session.commit()
 
         # Test with default max_runs (10)
-        reference = DeadlineReference.AVERAGE_RUNTIME()
+        reference = SerializedReferenceModels.AverageRuntimeDeadline(max_runs=10, min_runs=10)
         interval = timedelta(hours=1)
 
         with mock.patch("airflow._shared.timezones.timezone.utcnow") as mock_utcnow:
@@ -417,7 +438,7 @@ class TestCalculatedDeadlineDatabaseCalls:
 
         session.commit()
 
-        reference = DeadlineReference.AVERAGE_RUNTIME()
+        reference = SerializedReferenceModels.AverageRuntimeDeadline(max_runs=10, min_runs=10)
         interval = timedelta(hours=1)
 
         with mock.patch("airflow._shared.timezones.timezone.utcnow") as mock_utcnow:
@@ -451,7 +472,7 @@ class TestCalculatedDeadlineDatabaseCalls:
         session.commit()
 
         # Test with min_runs=2, should work with 3 runs
-        reference = DeadlineReference.AVERAGE_RUNTIME(max_runs=10, min_runs=2)
+        reference = SerializedReferenceModels.AverageRuntimeDeadline(max_runs=10, min_runs=2)
         interval = timedelta(hours=1)
 
         with mock.patch("airflow._shared.timezones.timezone.utcnow") as mock_utcnow:
@@ -465,7 +486,7 @@ class TestCalculatedDeadlineDatabaseCalls:
             assert result.replace(second=0, microsecond=0) == expected.replace(second=0, microsecond=0)
 
         # Test with min_runs=5, should return None with only 3 runs
-        reference = DeadlineReference.AVERAGE_RUNTIME(max_runs=10, min_runs=5)
+        reference = SerializedReferenceModels.AverageRuntimeDeadline(max_runs=10, min_runs=5)
 
         with mock.patch("airflow._shared.timezones.timezone.utcnow") as mock_utcnow:
             mock_utcnow.return_value = DEFAULT_DATE
@@ -535,17 +556,17 @@ class TestDeadlineReference:
     def test_deadline_reference_creation(self):
         """Test that DeadlineReference provides consistent interface and types."""
         fixed_reference = DeadlineReference.FIXED_DATETIME(DEFAULT_DATE)
-        assert isinstance(fixed_reference, ReferenceModels.FixedDatetimeDeadline)
+        assert isinstance(fixed_reference, FixedDatetimeDeadline)
         assert fixed_reference._datetime == DEFAULT_DATE
 
         logical_date_reference = DeadlineReference.DAGRUN_LOGICAL_DATE
-        assert isinstance(logical_date_reference, ReferenceModels.DagRunLogicalDateDeadline)
+        assert isinstance(logical_date_reference, DagRunLogicalDateDeadline)
 
         queued_reference = DeadlineReference.DAGRUN_QUEUED_AT
-        assert isinstance(queued_reference, ReferenceModels.DagRunQueuedAtDeadline)
+        assert isinstance(queued_reference, DagRunQueuedAtDeadline)
 
         average_runtime_reference = DeadlineReference.AVERAGE_RUNTIME()
-        assert isinstance(average_runtime_reference, ReferenceModels.AverageRuntimeDeadline)
+        assert isinstance(average_runtime_reference, AverageRuntimeDeadline)
         assert average_runtime_reference.max_runs == 10
         assert average_runtime_reference.min_runs == 10
 
@@ -556,14 +577,14 @@ class TestDeadlineReference:
 
 
 class TestCustomDeadlineReference:
-    class MyCustomRef(ReferenceModels.BaseDeadlineReference):
+    class MyCustomRef(BaseDeadlineReference):
         def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
             return timezone.datetime(DEFAULT_DATE)
 
     class MyInvalidCustomRef:
         pass
 
-    class MyCustomRefWithKwargs(ReferenceModels.BaseDeadlineReference):
+    class MyCustomRefWithKwargs(BaseDeadlineReference):
         required_kwargs = {"custom_id"}
 
         def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
@@ -573,7 +594,6 @@ class TestCustomDeadlineReference:
         self.original_dagrun_created = DeadlineReference.TYPES.DAGRUN_CREATED
         self.original_dagrun_queued = DeadlineReference.TYPES.DAGRUN_QUEUED
         self.original_dagrun = DeadlineReference.TYPES.DAGRUN
-        self.original_attrs = set(dir(ReferenceModels))
         self.original_deadline_attrs = set(dir(DeadlineReference))
 
     def teardown_method(self):
@@ -581,10 +601,6 @@ class TestCustomDeadlineReference:
         DeadlineReference.TYPES.DAGRUN_QUEUED = self.original_dagrun_queued
         DeadlineReference.TYPES.DAGRUN = self.original_dagrun
 
-        for attr in set(dir(ReferenceModels)):
-            if attr not in self.original_attrs:
-                delattr(ReferenceModels, attr)
-
         for attr in set(dir(DeadlineReference)):
             if attr not in self.original_deadline_attrs:
                 delattr(DeadlineReference, attr)
@@ -613,7 +629,7 @@ class TestCustomDeadlineReference:
             expected_timing = timing
 
         assert result is reference
-        assert getattr(ReferenceModels, reference.__name__) is reference
+        assert hasattr(DeadlineReference, reference.__name__)
         assert getattr(DeadlineReference, reference.__name__).__class__ is reference
 
         assert_correct_timing(reference, expected_timing)
@@ -637,12 +653,15 @@ class TestCustomDeadlineReference:
         ):
             DeadlineReference.register_custom_reference(self.MyCustomRef, invalid_timing)
 
-    def test_custom_reference_discoverable_by_get_reference_class(self):
+    def test_custom_reference_discoverable_on_deadline_reference(self):
+        # Custom references are only registered on DeadlineReference, not on ReferenceModels.
+        # During deserialization, custom refs are discovered via __class_path in the
+        # serialized data (using import_string), not through ReferenceModels lookup.
         DeadlineReference.register_custom_reference(self.MyCustomRef)
 
-        found_class = ReferenceModels.get_reference_class(self.MyCustomRef.__name__)
-
-        assert found_class is self.MyCustomRef
+        assert hasattr(DeadlineReference, self.MyCustomRef.__name__)
+        found_instance = getattr(DeadlineReference, self.MyCustomRef.__name__)
+        assert isinstance(found_instance, self.MyCustomRef)
 
 
 class TestDeadlineReferenceDecorator:
@@ -650,21 +669,21 @@ class TestDeadlineReferenceDecorator:
         self.original_dagrun_created = DeadlineReference.TYPES.DAGRUN_CREATED
         self.original_dagrun_queued = DeadlineReference.TYPES.DAGRUN_QUEUED
         self.original_dagrun = DeadlineReference.TYPES.DAGRUN
-        self.original_attrs = set(dir(ReferenceModels))
+        self.original_deadline_attrs = set(dir(DeadlineReference))
 
     def teardown_method(self):
         DeadlineReference.TYPES.DAGRUN_CREATED = self.original_dagrun_created
         DeadlineReference.TYPES.DAGRUN_QUEUED = self.original_dagrun_queued
         DeadlineReference.TYPES.DAGRUN = self.original_dagrun
 
-        for attr in set(dir(ReferenceModels)):
-            if attr not in self.original_attrs:
-                delattr(ReferenceModels, attr)
+        for attr in set(dir(DeadlineReference)):
+            if attr not in self.original_deadline_attrs:
+                delattr(DeadlineReference, attr)
 
     @staticmethod
     def create_decorated_custom_ref():
         @deadline_reference()
-        class DecoratedCustomRef(ReferenceModels.BaseDeadlineReference):
+        class DecoratedCustomRef(BaseDeadlineReference):
             def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
                 return timezone.datetime(DEFAULT_DATE)
 
@@ -673,7 +692,7 @@ class TestDeadlineReferenceDecorator:
     @staticmethod
     def create_decorated_custom_ref_with_kwargs():
         @deadline_reference()
-        class DecoratedCustomRefWithKwargs(ReferenceModels.BaseDeadlineReference):
+        class DecoratedCustomRefWithKwargs(BaseDeadlineReference):
             required_kwargs = {"custom_id"}
 
             def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
@@ -684,7 +703,7 @@ class TestDeadlineReferenceDecorator:
     @staticmethod
     def create_decorated_custom_ref_queued():
         @deadline_reference(DeadlineReference.TYPES.DAGRUN_QUEUED)
-        class DecoratedCustomRefQueued(ReferenceModels.BaseDeadlineReference):
+        class DecoratedCustomRefQueued(BaseDeadlineReference):
             def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
                 return timezone.datetime(DEFAULT_DATE)
 
@@ -713,7 +732,7 @@ class TestDeadlineReferenceDecorator:
     def test_deadline_reference_decorator(self, reference_factory, expected_timing):
         reference = reference_factory()
 
-        assert getattr(ReferenceModels, reference.__name__) is reference
+        assert hasattr(DeadlineReference, reference.__name__)
         assert getattr(DeadlineReference, reference.__name__).__class__ is reference
 
         assert_correct_timing(reference, expected_timing)
@@ -741,7 +760,7 @@ class TestDeadlineReferenceDecorator:
         ):
 
             @deadline_reference(invalid_timing)
-            class DecoratedCustomRef(ReferenceModels.BaseDeadlineReference):
+            class DecoratedCustomRef(BaseDeadlineReference):
                 def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
                     return timezone.datetime(DEFAULT_DATE)
 
@@ -750,7 +769,7 @@ class TestDeadlineReferenceDecorator:
         timing = DeadlineReference.TYPES.DAGRUN_QUEUED
 
         @deadline_reference(timing)
-        class DecoratedCustomRef(ReferenceModels.BaseDeadlineReference):
+        class DecoratedCustomRef(BaseDeadlineReference):
             def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
                 return timezone.datetime(DEFAULT_DATE)
 
diff --git a/airflow-core/tests/unit/models/test_deadline_alert.py b/airflow-core/tests/unit/models/test_deadline_alert.py
index 879203814b3..9d69577a6d6 100644
--- a/airflow-core/tests/unit/models/test_deadline_alert.py
+++ b/airflow-core/tests/unit/models/test_deadline_alert.py
@@ -16,13 +16,17 @@
 # under the License.
 from __future__ import annotations
 
+from datetime import timedelta
+from unittest.mock import Mock
+
 import pytest
 import time_machine
 from sqlalchemy import select
 
+from airflow._shared.timezones import timezone
 from airflow.models.deadline_alert import DeadlineAlert
 from airflow.models.serialized_dag import SerializedDagModel
-from airflow.sdk.definitions.deadline import DeadlineReference
+from airflow.sdk.definitions.deadline import BaseDeadlineReference, DeadlineReference
 from airflow.serialization.definitions.deadline import SerializedReferenceModels
 
 from tests_common.test_utils import db
@@ -172,3 +176,55 @@ class TestDeadlineAlert:
         nonexistent_uuid = "00000000-0000-7000-8000-000000000000"
         with pytest.raises(NoResultFound, match="No DeadlineAlert found"):
             DeadlineAlert.get_by_id(nonexistent_uuid, session=session)
+
+    def test_serialized_custom_reference_kwargs_handling(self):
+        """Test that SerializedCustomReference properly filters and validates kwargs."""
+
+        class StrictCustomRef(BaseDeadlineReference):
+            reference_name = "StrictCustomRef"
+            required_kwargs = {"dag_id", "run_id"}
+
+            def _evaluate_with(self, *, session, dag_id, run_id):
+                return timezone.utcnow()
+
+        inner_ref = StrictCustomRef()
+        inner_ref._evaluate_with = Mock(return_value=timezone.utcnow())
+
+        wrapper = SerializedReferenceModels.SerializedCustomReference(inner_ref)
+
+        wrapper.evaluate_with(
+            session=None,
+            interval=timedelta(hours=1),
+            dag_id="test_dag",
+            run_id="test_run",
+            extra_param="should_be_filtered",
+        )
+
+        inner_ref._evaluate_with.assert_called_once_with(session=None, dag_id="test_dag", run_id="test_run")
+
+        # try calling with missing required parameters
+        with pytest.raises(ValueError, match="missing required parameters: run_id"):
+            wrapper.evaluate_with(
+                session=None,
+                interval=timedelta(hours=1),
+                dag_id="test_dag",
+            )
+
+    def test_core_deadline_reference_treated_as_builtins(self):
+        """Test that refs from airflow.models.deadline are still treated as builtins."""
+        from airflow.models.deadline import ReferenceModels
+        from airflow.serialization.encoders import encode_deadline_reference
+
+        ref = ReferenceModels.DagRunLogicalDateDeadline()
+        serialized = encode_deadline_reference(ref)
+
+        assert "__class_path" not in serialized
+        assert serialized["reference_type"] == "DagRunLogicalDateDeadline"
+
+    def test_is_builtin_reference(self):
+        """Test that is_builtin_reference correctly identifies built-in vs custom references."""
+        assert SerializedReferenceModels.is_builtin_reference("DagRunLogicalDateDeadline") is True
+        assert SerializedReferenceModels.is_builtin_reference("DagRunQueuedAtDeadline") is True
+        assert SerializedReferenceModels.is_builtin_reference("AverageRuntimeDeadline") is True
+
+        assert SerializedReferenceModels.is_builtin_reference("MyCustomRef") is False
diff --git a/task-sdk/src/airflow/sdk/definitions/deadline.py b/task-sdk/src/airflow/sdk/definitions/deadline.py
index 8c55e10d45c..2fe220e789d 100644
--- a/task-sdk/src/airflow/sdk/definitions/deadline.py
+++ b/task-sdk/src/airflow/sdk/definitions/deadline.py
@@ -17,10 +17,11 @@
 from __future__ import annotations
 
 import logging
+from abc import ABC
+from dataclasses import dataclass
 from datetime import datetime, timedelta
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
-from airflow.models.deadline import DeadlineReferenceType, ReferenceModels
 from airflow.sdk.definitions.callback import AsyncCallback, Callback, SyncCallback
 
 if TYPE_CHECKING:
@@ -29,7 +30,111 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-DeadlineReferenceTypes: TypeAlias = tuple[type[ReferenceModels.BaseDeadlineReference], ...]
+# Field name used in serialization - must be in sync with SerializedReferenceModels.REFERENCE_TYPE_FIELD
+REFERENCE_TYPE_FIELD = "reference_type"
+
+
+class BaseDeadlineReference(ABC):
+    """
+    Base class for all Deadline Reference implementations.
+
+    This is a lightweight SDK class for DAG authoring. It only handles serialization.
+    The actual evaluation logic (_evaluate_with) is in Core's SerializedReferenceModels.
+
+    For custom deadline references, users should inherit from this class and implement
+    _evaluate_with() with deferred Core imports (imports inside the method body).
+    """
+
+    @property
+    def reference_name(self) -> str:
+        """Return the class name as the reference identifier."""
+        return self.__class__.__name__
+
+    def serialize_reference(self) -> dict[str, Any]:
+        """
+        Serialize this reference type into a dictionary representation.
+
+        Override this method in subclasses if additional data is needed for serialization.
+        """
+        return {REFERENCE_TYPE_FIELD: self.reference_name}
+
+    @classmethod
+    def deserialize_reference(cls, reference_data: dict[str, Any]) -> BaseDeadlineReference:
+        """
+        Deserialize a reference type from its dictionary representation.
+
+        :param reference_data: Dictionary containing serialized reference data.
+        """
+        return cls()
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, BaseDeadlineReference):
+            return NotImplemented
+        return self.serialize_reference() == other.serialize_reference()
+
+    def __hash__(self) -> int:
+        return hash(frozenset(self.serialize_reference().items()))
+
+
+class DagRunLogicalDateDeadline(BaseDeadlineReference):
+    """A deadline that returns a DagRun's logical date."""
+
+
+class DagRunQueuedAtDeadline(BaseDeadlineReference):
+    """A deadline that returns when a DagRun was queued."""
+
+
+@dataclass
+class FixedDatetimeDeadline(BaseDeadlineReference):
+    """A deadline that always returns a fixed datetime."""
+
+    _datetime: datetime
+
+    def serialize_reference(self) -> dict[str, Any]:
+        return {
+            REFERENCE_TYPE_FIELD: self.reference_name,
+            "datetime": self._datetime.timestamp(),
+        }
+
+    @classmethod
+    def deserialize_reference(cls, reference_data: dict[str, Any]) -> FixedDatetimeDeadline:
+        from airflow._shared.timezones import timezone
+
+        return cls(_datetime=timezone.from_timestamp(reference_data["datetime"]))
+
+
+@dataclass
+class AverageRuntimeDeadline(BaseDeadlineReference):
+    """A deadline that calculates the average runtime from past DAG runs."""
+
+    DEFAULT_LIMIT = 10
+    max_runs: int
+    min_runs: int | None = None
+
+    def __post_init__(self):
+        if self.min_runs is None:
+            self.min_runs = self.max_runs
+        if self.min_runs < 1:
+            raise ValueError("min_runs must be at least 1")
+
+    def serialize_reference(self) -> dict[str, Any]:
+        return {
+            REFERENCE_TYPE_FIELD: self.reference_name,
+            "max_runs": self.max_runs,
+            "min_runs": self.min_runs,
+        }
+
+    @classmethod
+    def deserialize_reference(cls, reference_data: dict[str, Any]) -> AverageRuntimeDeadline:
+        max_runs = reference_data.get("max_runs", cls.DEFAULT_LIMIT)
+        min_runs = reference_data.get("min_runs", max_runs)
+        if min_runs < 1:
+            raise ValueError("min_runs must be at least 1")
+        return cls(max_runs=max_runs, min_runs=min_runs)
+
+
+DeadlineReferenceType: TypeAlias = BaseDeadlineReference
+DeadlineReferenceTypes: TypeAlias = tuple[type[BaseDeadlineReference], ...]
 
 
 class DeadlineAlert:
@@ -118,33 +223,31 @@ class DeadlineReference:
 
         # Deadlines that should be created when the DagRun is created.
         DAGRUN_CREATED: DeadlineReferenceTypes = (
-            ReferenceModels.DagRunLogicalDateDeadline,
-            ReferenceModels.FixedDatetimeDeadline,
-            ReferenceModels.AverageRuntimeDeadline,
+            DagRunLogicalDateDeadline,
+            FixedDatetimeDeadline,
+            AverageRuntimeDeadline,
         )
 
         # Deadlines that should be created when the DagRun is queued.
-        DAGRUN_QUEUED: DeadlineReferenceTypes = (ReferenceModels.DagRunQueuedAtDeadline,)
+        DAGRUN_QUEUED: DeadlineReferenceTypes = (DagRunQueuedAtDeadline,)
 
         # All DagRun-related deadline types.
         DAGRUN: DeadlineReferenceTypes = DAGRUN_CREATED + DAGRUN_QUEUED
 
-    from airflow.models.deadline import ReferenceModels
-
-    DAGRUN_LOGICAL_DATE: DeadlineReferenceType = ReferenceModels.DagRunLogicalDateDeadline()
-    DAGRUN_QUEUED_AT: DeadlineReferenceType = ReferenceModels.DagRunQueuedAtDeadline()
+    DAGRUN_LOGICAL_DATE: DeadlineReferenceType = DagRunLogicalDateDeadline()
+    DAGRUN_QUEUED_AT: DeadlineReferenceType = DagRunQueuedAtDeadline()
 
     @classmethod
     def AVERAGE_RUNTIME(cls, max_runs: int = 0, min_runs: int | None = None) -> DeadlineReferenceType:
         if max_runs == 0:
-            max_runs = cls.ReferenceModels.AverageRuntimeDeadline.DEFAULT_LIMIT
+            max_runs = AverageRuntimeDeadline.DEFAULT_LIMIT
         if min_runs is None:
             min_runs = max_runs
-        return cls.ReferenceModels.AverageRuntimeDeadline(max_runs, min_runs)
+        return AverageRuntimeDeadline(max_runs, min_runs)
 
     @classmethod
-    def FIXED_DATETIME(cls, datetime: datetime) -> DeadlineReferenceType:
-        return cls.ReferenceModels.FixedDatetimeDeadline(datetime)
+    def FIXED_DATETIME(cls, dt: datetime) -> DeadlineReferenceType:
+        return FixedDatetimeDeadline(dt)
 
     # TODO: Remove this once other deadline types exist.
     #   This is a temporary reference type used only in tests to verify that
@@ -152,16 +255,16 @@ class DeadlineReference:
     #   It should be replaced with a real non-dagrun deadline type when one is available.
     _TEMPORARY_TEST_REFERENCE = type(
         "TemporaryTestDeadlineForTypeChecking",
-        (DeadlineReferenceType,),
-        {"_evaluate_with": lambda self, **kwargs: datetime.now()},
+        (BaseDeadlineReference,),
+        {"serialize_reference": lambda self: {REFERENCE_TYPE_FIELD: "TemporaryTestDeadlineForTypeChecking"}},
     )()
 
     @classmethod
     def register_custom_reference(
         cls,
-        reference_class: type[ReferenceModels.BaseDeadlineReference],
+        reference_class: type[BaseDeadlineReference],
         deadline_reference_type: DeadlineReferenceTypes | None = None,
-    ) -> type[ReferenceModels.BaseDeadlineReference]:
+    ) -> type[BaseDeadlineReference]:
         """
         Register a custom deadline reference class.
 
@@ -169,18 +272,18 @@ class DeadlineReference:
         :param deadline_reference_type: A DeadlineReference.TYPES for when the deadline should be evaluated ("DAGRUN_CREATED",
             "DAGRUN_QUEUED", etc.); defaults to DeadlineReference.TYPES.DAGRUN_CREATED
         """
-        from airflow.models.deadline import ReferenceModels
-
         # Default to DAGRUN_CREATED if no deadline_reference_type specified
         if deadline_reference_type is None:
             deadline_reference_type = cls.TYPES.DAGRUN_CREATED
 
         # Validate the reference class inherits from BaseDeadlineReference
-        if not issubclass(reference_class, ReferenceModels.BaseDeadlineReference):
+        # Accept both sdk and core base classes for backward compatibility for now
+        from airflow.models.deadline import ReferenceModels
+
+        if not issubclass(reference_class, (BaseDeadlineReference, ReferenceModels.BaseDeadlineReference)):
             raise ValueError(f"{reference_class.__name__} must inherit from BaseDeadlineReference")
 
-        # Register the new reference with ReferenceModels and DeadlineReference for discoverability
-        setattr(ReferenceModels, reference_class.__name__, reference_class)
+        # Register the new reference with DeadlineReference for discoverability
         setattr(cls, reference_class.__name__, reference_class())
         logger.info("Registered DeadlineReference %s", reference_class.__name__)
 
@@ -203,29 +306,36 @@ class DeadlineReference:
 
 def deadline_reference(
     deadline_reference_type: DeadlineReferenceTypes | None = None,
-) -> Callable[[type[ReferenceModels.BaseDeadlineReference]], type[ReferenceModels.BaseDeadlineReference]]:
+) -> Callable[[type[BaseDeadlineReference]], type[BaseDeadlineReference]]:
     """
     Decorate a class to register a custom deadline reference.
 
     Usage:
         @deadline_reference()
-        class MyCustomReference(ReferenceModels.BaseDeadlineReference):
+        class MyCustomReference(BaseDeadlineReference):
             # By default, evaluate_with will be called when a new dagrun is created.
             def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
-                # Put your business logic here
+                # Put your business logic here (use deferred imports for Core types)
+                from airflow.models import DagRun
                 return some_datetime
 
+            def serialize_reference(self) -> dict:
+                return {"reference_type": self.reference_name}
+
         @deadline_reference(DeadlineReference.TYPES.DAGRUN_QUEUED)
-        class MyQueuedRef(ReferenceModels.BaseDeadlineReference):
+        class MyQueuedRef(BaseDeadlineReference):
             # Optionally, you can specify when you want it calculated by providing a DeadlineReference.TYPES
             def _evaluate_with(self, *, session: Session, **kwargs) -> datetime:
                  # Put your business logic here
                 return some_datetime
+
+            def serialize_reference(self) -> dict:
+                return {"reference_type": self.reference_name}
     """
 
     def decorator(
-        reference_class: type[ReferenceModels.BaseDeadlineReference],
-    ) -> type[ReferenceModels.BaseDeadlineReference]:
+        reference_class: type[BaseDeadlineReference],
+    ) -> type[BaseDeadlineReference]:
         DeadlineReference.register_custom_reference(reference_class, deadline_reference_type)
         return reference_class
 


Reply via email to