This is an automated email from the ASF dual-hosted git repository.
ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new fb3c8117fb3 Include simple context in triggerer async callback
(#55241)
fb3c8117fb3 is described below
commit fb3c8117fb3e5482712d448d702c032b19b29039
Author: Ramit Kataria <[email protected]>
AuthorDate: Mon Sep 8 12:24:08 2025 -0700
Include simple context in triggerer async callback (#55241)
- Added a simple context dict (with `dag_run` and `deadline` entries) to the async callback's kwargs
- Reserved `context` as a kwarg name in callback definitions; passing it now raises a `ValueError`
Eventually, we should probably use the Task SDK API in the triggerer to
fetch the full context, but this solution covers most use cases for now.
---
airflow-core/src/airflow/models/deadline.py | 12 ++++++-
airflow-core/src/airflow/triggers/deadline.py | 11 +++---
airflow-core/tests/unit/models/test_deadline.py | 42 +++++++++++++---------
airflow-core/tests/unit/triggers/test_deadline.py | 11 +++---
task-sdk/src/airflow/sdk/definitions/deadline.py | 4 ++-
.../tests/task_sdk/definitions/test_deadline.py | 11 ++++++
6 files changed, 62 insertions(+), 29 deletions(-)
diff --git a/airflow-core/src/airflow/models/deadline.py
b/airflow-core/src/airflow/models/deadline.py
index f41fe648418..cfb99160a11 100644
--- a/airflow-core/src/airflow/models/deadline.py
+++ b/airflow-core/src/airflow/models/deadline.py
@@ -202,10 +202,20 @@ class Deadline(Base):
"""Handle a missed deadline by running the callback in the appropriate
host and updating the `callback_state`."""
from airflow.sdk.definitions.deadline import AsyncCallback,
SyncCallback
+ def get_simple_context():
+ from airflow.api_fastapi.core_api.datamodels.dag_run import
DAGRunResponse
+
+ # TODO: Use the TaskAPI from within Triggerer to fetch full
context instead of sending this context
+ # from the scheduler
+ return {
+ "dag_run":
DAGRunResponse.model_validate(self.dagrun).model_dump(mode="json"),
+ "deadline": {"id": self.id, "deadline_time":
self.deadline_time},
+ }
+
if isinstance(self.callback, AsyncCallback):
callback_trigger = DeadlineCallbackTrigger(
callback_path=self.callback.path,
- callback_kwargs=self.callback.kwargs,
+ callback_kwargs=(self.callback.kwargs or {}) | {"context":
get_simple_context()},
)
trigger_orm = Trigger.from_object(callback_trigger)
session.add(trigger_orm)
diff --git a/airflow-core/src/airflow/triggers/deadline.py
b/airflow-core/src/airflow/triggers/deadline.py
index bcff27fd1b2..bd8a665b9fc 100644
--- a/airflow-core/src/airflow/triggers/deadline.py
+++ b/airflow-core/src/airflow/triggers/deadline.py
@@ -49,15 +49,13 @@ class DeadlineCallbackTrigger(BaseTrigger):
from airflow.models.deadline import DeadlineCallbackState # to avoid
cyclic imports
try:
- callback = import_string(self.callback_path)
yield TriggerEvent({PAYLOAD_STATUS_KEY:
DeadlineCallbackState.RUNNING})
+ callback = import_string(self.callback_path)
- # TODO: get airflow context
- context: dict = {}
-
- result = await callback(**self.callback_kwargs, context=context)
- log.info("Deadline callback completed with return value: %s",
result)
+ # TODO: get full context and run template rendering. Right now, a
simple context is included in `callback_kwargs`
+ result = await callback(**self.callback_kwargs)
yield TriggerEvent({PAYLOAD_STATUS_KEY:
DeadlineCallbackState.SUCCESS, PAYLOAD_BODY_KEY: result})
+
except Exception as e:
if isinstance(e, ImportError):
message = "Failed to import this deadline callback on the
triggerer"
@@ -65,6 +63,7 @@ class DeadlineCallbackTrigger(BaseTrigger):
message = "Failed to run this deadline callback because it is
not awaitable"
else:
message = "An error occurred during execution of this deadline
callback"
+
log.exception("%s: %s; kwargs: %s\n%s", message,
self.callback_path, self.callback_kwargs, e)
yield TriggerEvent(
{
diff --git a/airflow-core/tests/unit/models/test_deadline.py
b/airflow-core/tests/unit/models/test_deadline.py
index 1412b16bdbb..5e152935707 100644
--- a/airflow-core/tests/unit/models/test_deadline.py
+++ b/airflow-core/tests/unit/models/test_deadline.py
@@ -23,6 +23,7 @@ import pytest
import time_machine
from sqlalchemy.exc import SQLAlchemyError
+from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse
from airflow.models import DagRun, Trigger
from airflow.models.deadline import Deadline, DeadlineCallbackState,
ReferenceModels, _fetch_from_db
from airflow.providers.standard.operators.empty import EmptyOperator
@@ -160,16 +161,37 @@ class TestDeadline:
)
@pytest.mark.db_test
- def test_handle_miss_async_callback(self, dagrun, deadline_orm, session):
+ @pytest.mark.parametrize(
+ "kwargs",
+ [
+ pytest.param(TEST_CALLBACK_KWARGS, id="non-empty kwargs"),
+ pytest.param(None, id="null kwargs"),
+ ],
+ )
+ def test_handle_miss_async_callback(self, dagrun, session, kwargs):
+ deadline_orm = Deadline(
+ deadline_time=DEFAULT_DATE,
+ callback=AsyncCallback(TEST_CALLBACK_PATH, kwargs),
+ dagrun_id=dagrun.id,
+ )
+ session.add(deadline_orm)
+ session.flush()
deadline_orm.handle_miss(session=session)
session.flush()
assert deadline_orm.trigger_id is not None
-
trigger = session.query(Trigger).filter(Trigger.id ==
deadline_orm.trigger_id).one()
assert trigger is not None
+
assert trigger.kwargs["callback_path"] == TEST_CALLBACK_PATH
- assert trigger.kwargs["callback_kwargs"] == TEST_CALLBACK_KWARGS
+
+ trigger_kwargs = trigger.kwargs["callback_kwargs"]
+ context = trigger_kwargs.pop("context")
+ assert trigger_kwargs == (kwargs or {})
+
+ assert context["deadline"]["id"] == str(deadline_orm.id)
+ assert context["deadline"]["deadline_time"].timestamp() ==
deadline_orm.deadline_time.timestamp()
+ assert context["dag_run"] ==
DAGRunResponse.model_validate(dagrun).model_dump(mode="json")
@pytest.mark.db_test
def test_handle_miss_sync_callback(self, dagrun, session):
@@ -232,20 +254,6 @@ class TestDeadline:
else:
assert deadline_orm.callback_state == DeadlineCallbackState.QUEUED
- def test_handle_miss_creates_trigger(self, dagrun, deadline_orm, session):
- """Test that handle_miss creates a trigger with correct parameters."""
- deadline_orm.handle_miss(session)
- session.flush()
-
- # Check trigger was created
- trigger = session.query(Trigger).first()
- assert trigger is not None
- assert deadline_orm.trigger_id == trigger.id
-
- # Check trigger has correct kwargs
- assert trigger.kwargs["callback_path"] == TEST_CALLBACK_PATH
- assert trigger.kwargs["callback_kwargs"] == TEST_CALLBACK_KWARGS
-
def test_handle_miss_sets_callback_state(self, dagrun, deadline_orm,
session):
"""Test that handle_miss sets the callback state to QUEUED."""
deadline_orm.handle_miss(session)
diff --git a/airflow-core/tests/unit/triggers/test_deadline.py
b/airflow-core/tests/unit/triggers/test_deadline.py
index 72bea33f188..955b6cb49c0 100644
--- a/airflow-core/tests/unit/triggers/test_deadline.py
+++ b/airflow-core/tests/unit/triggers/test_deadline.py
@@ -27,7 +27,7 @@ from airflow.triggers.deadline import PAYLOAD_BODY_KEY,
PAYLOAD_STATUS_KEY, Dead
TEST_MESSAGE = "test_message"
TEST_CALLBACK_PATH = "classpath.test_callback_for_deadline"
-TEST_CALLBACK_KWARGS = {"message": TEST_MESSAGE}
+TEST_CALLBACK_KWARGS = {"message": TEST_MESSAGE, "context": {"dag_run":
"test"}}
TEST_TRIGGER = DeadlineCallbackTrigger(callback_path=TEST_CALLBACK_PATH,
callback_kwargs=TEST_CALLBACK_KWARGS)
@@ -85,7 +85,7 @@ class TestDeadlineCallbackTrigger:
success_event = await anext(trigger_gen)
mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH)
- mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS,
context=mock.ANY)
+ mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS)
assert success_event.payload[PAYLOAD_STATUS_KEY] ==
DeadlineCallbackState.SUCCESS
assert success_event.payload[PAYLOAD_BODY_KEY] == callback_return_value
@@ -102,7 +102,10 @@ class TestDeadlineCallbackTrigger:
success_event = await anext(trigger_gen)
mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH)
assert success_event.payload[PAYLOAD_STATUS_KEY] ==
DeadlineCallbackState.SUCCESS
- assert success_event.payload[PAYLOAD_BODY_KEY] == f"Async
notification: {TEST_MESSAGE}, context: {{}}"
+ assert (
+ success_event.payload[PAYLOAD_BODY_KEY]
+ == f"Async notification: {TEST_MESSAGE}, context: {{'dag_run':
'test'}}"
+ )
@pytest.mark.asyncio
async def test_run_failure(self, mock_import_string):
@@ -117,6 +120,6 @@ class TestDeadlineCallbackTrigger:
failure_event = await anext(trigger_gen)
mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH)
- mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS,
context=mock.ANY)
+ mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS)
assert failure_event.payload[PAYLOAD_STATUS_KEY] ==
DeadlineCallbackState.FAILED
assert all(s in failure_event.payload[PAYLOAD_BODY_KEY] for s in
["raise", "RuntimeError", exc_msg])
diff --git a/task-sdk/src/airflow/sdk/definitions/deadline.py
b/task-sdk/src/airflow/sdk/definitions/deadline.py
index 966e2b926a6..46e5eeb7be2 100644
--- a/task-sdk/src/airflow/sdk/definitions/deadline.py
+++ b/task-sdk/src/airflow/sdk/definitions/deadline.py
@@ -120,8 +120,10 @@ class Callback(ABC):
path: str
kwargs: dict | None
- def __init__(self, callback_callable: Callable | str, kwargs: dict | None
= None):
+ def __init__(self, callback_callable: Callable | str, kwargs: dict[str,
Any] | None = None):
self.path = self.get_callback_path(callback_callable)
+ if kwargs and "context" in kwargs:
+ raise ValueError("context is a reserved kwarg for this class")
self.kwargs = kwargs
@classmethod
diff --git a/task-sdk/tests/task_sdk/definitions/test_deadline.py
b/task-sdk/tests/task_sdk/definitions/test_deadline.py
index 8bb70a7fad2..654cc41e2b8 100644
--- a/task-sdk/tests/task_sdk/definitions/test_deadline.py
+++ b/task-sdk/tests/task_sdk/definitions/test_deadline.py
@@ -160,6 +160,17 @@ class TestDeadlineAlert:
class TestCallback:
+ @pytest.mark.parametrize(
+ "subclass, callable",
+ [
+ pytest.param(AsyncCallback,
empty_async_callback_for_deadline_tests, id="async"),
+ pytest.param(SyncCallback, empty_sync_callback_for_deadline_tests,
id="sync"),
+ ],
+ )
+ def test_init_error_reserved_kwarg(self, subclass, callable):
+ with pytest.raises(ValueError, match="context is a reserved kwarg for
this class"):
+ subclass(callable, {"context": None})
+
@pytest.mark.parametrize(
"callback_callable, expected_path",
[