This is an automated email from the ASF dual-hosted git repository.
o-nikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2970152b190 Fix CallbackKey type for more accurate type checking
(#66973)
2970152b190 is described below
commit 2970152b190664d4ccc825f5fe008396ca7b0343
Author: D. Ferruzzi <[email protected]>
AuthorDate: Wed May 20 09:59:34 2026 -0700
Fix CallbackKey type for more accurate type checking (#66973)
* Fix CallbackKey type for more accurate type checking
* additional cleanup
---
.../src/airflow/executors/base_executor.py | 12 ++---
.../src/airflow/executors/workloads/callback.py | 6 ++-
.../src/airflow/executors/workloads/types.py | 8 ++--
.../src/airflow/jobs/scheduler_job_runner.py | 11 ++---
airflow-core/src/airflow/models/callback.py | 12 ++++-
.../tests/unit/executors/test_base_executor.py | 25 +++++++++--
.../tests/unit/executors/test_local_executor.py | 2 +-
.../tests/unit/executors/test_workloads.py | 52 ++++++++++++++++++++++
8 files changed, 106 insertions(+), 22 deletions(-)
diff --git a/airflow-core/src/airflow/executors/base_executor.py
b/airflow-core/src/airflow/executors/base_executor.py
index 2f1f5e8adb9..eff6ff07714 100644
--- a/airflow-core/src/airflow/executors/base_executor.py
+++ b/airflow-core/src/airflow/executors/base_executor.py
@@ -38,6 +38,7 @@ from airflow.executors.workloads.task import ExecuteTask
from airflow.executors.workloads.types import state_class_for_key
from airflow.models import Log
from airflow.models.callback import CallbackKey
+from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.observability.metrics import stats_utils
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -78,7 +79,6 @@ if TYPE_CHECKING:
from airflow.executors.workloads import ExecutorWorkload
from airflow.executors.workloads.types import WorkloadKey, WorkloadState
from airflow.models.taskinstance import TaskInstance
- from airflow.models.taskinstancekey import TaskInstanceKey
# Event_buffer dict value type
# Tuple of: state, info
@@ -217,7 +217,7 @@ class BaseExecutor(LoggingMixin):
self.parallelism: int = parallelism
self.team_name: str | None = team_name
self.queued_tasks: dict[TaskInstanceKey, workloads.ExecuteTask] = {}
- self.queued_callbacks: dict[str, workloads.ExecuteCallback] = {}
+ self.queued_callbacks: dict[CallbackKey, workloads.ExecuteCallback] =
{}
self.running: set[WorkloadKey] = set()
self.event_buffer: dict[WorkloadKey, EventBufferValueType] = {}
self._task_event_logs: deque[Log] = deque()
@@ -265,7 +265,7 @@ class BaseExecutor(LoggingMixin):
f"Set supports_callbacks = True and implement callback
handling in _process_workloads(). "
f"See LocalExecutor or CeleryExecutor for reference
implementation."
)
- self.queued_callbacks[workload.callback.id] = workload
+ self.queued_callbacks[workload.key] = workload
else:
raise ValueError(
f"Un-handled workload type {type(workload).__name__!r} in
{type(self).__name__}. "
@@ -497,7 +497,7 @@ class BaseExecutor(LoggingMixin):
In case dag_ids is specified it will only return and flush events
for the given dag_ids. Otherwise, it returns and flushes all events.
- Note: Callback events (with string keys) are always included
regardless of dag_ids filter.
+ Note: Callback events (with CallbackKey keys) are always included
regardless of dag_ids filter.
:param dag_ids: the dag_ids to return events for; returns all if given
``None``.
:return: a dict of events
@@ -508,7 +508,9 @@ class BaseExecutor(LoggingMixin):
self.event_buffer = {}
else:
for key in list(self.event_buffer.keys()):
- if isinstance(key, CallbackKey) or key.dag_id in dag_ids:
+ if isinstance(key, CallbackKey) or (
+ isinstance(key, TaskInstanceKey) and key.dag_id in dag_ids
+ ):
cleared_events[key] = self.event_buffer.pop(key)
return cleared_events
diff --git a/airflow-core/src/airflow/executors/workloads/callback.py
b/airflow-core/src/airflow/executors/workloads/callback.py
index a78dbab43a5..04f26b8e787 100644
--- a/airflow-core/src/airflow/executors/workloads/callback.py
+++ b/airflow-core/src/airflow/executors/workloads/callback.py
@@ -64,8 +64,10 @@ class CallbackDTO(BaseModel):
@property
def key(self) -> CallbackKey:
- """Return callback ID as key (CallbackKey = str)."""
- return self.id
+ """Return callback ID as a CallbackKey instance."""
+ from airflow.models.callback import CallbackKey # circular import
+
+ return CallbackKey(id=self.id)
class ExecuteCallback(BaseDagBundleWorkload):
diff --git a/airflow-core/src/airflow/executors/workloads/types.py
b/airflow-core/src/airflow/executors/workloads/types.py
index 61f7bf037d2..09cd2c3b359 100644
--- a/airflow-core/src/airflow/executors/workloads/types.py
+++ b/airflow-core/src/airflow/executors/workloads/types.py
@@ -20,14 +20,12 @@ from __future__ import annotations
from typing import TYPE_CHECKING, TypeAlias
-from airflow.models.callback import ExecutorCallback
+from airflow.models.callback import CallbackKey, ExecutorCallback
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.state import CallbackState, TaskInstanceState
if TYPE_CHECKING:
- from airflow.models.callback import CallbackKey
-
# Type aliases for workload keys and states (used by executor layer)
WorkloadKey: TypeAlias = TaskInstanceKey | CallbackKey
WorkloadState: TypeAlias = TaskInstanceState | CallbackState
@@ -43,4 +41,6 @@ SchedulerWorkload: TypeAlias = TaskInstance | ExecutorCallback
def state_class_for_key(key: WorkloadKey) -> type[TaskInstanceState] |
type[CallbackState]:
if isinstance(key, TaskInstanceKey):
return TaskInstanceState
- return CallbackState
+ if isinstance(key, CallbackKey):
+ return CallbackState
+ raise TypeError(f"Unknown workload key type: {type(key)!r}")
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 5dcbdc21521..d04ee85c202 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -85,7 +85,7 @@ from airflow.models.asset import (
)
from airflow.models.asset_state import AssetStateModel
from airflow.models.backfill import Backfill, BackfillDagRun
-from airflow.models.callback import Callback, CallbackType, ExecutorCallback
+from airflow.models.callback import Callback, CallbackKey, CallbackType,
ExecutorCallback
from airflow.models.dag import DagModel
from airflow.models.dag_version import DagVersion
from airflow.models.dagbag import DBDagBag
@@ -1233,7 +1233,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
ti_primary_key_to_try_number_map: dict[tuple[str, str, str, int], int]
= {}
event_buffer = executor.get_event_buffer()
tis_with_right_state: list[TaskInstanceKey] = []
- callback_keys_with_events: list[str] = []
+ callback_keys_with_events: list[CallbackKey] = []
# Report execution - handle both task and callback events
for key, (state, _) in event_buffer.items():
@@ -1258,16 +1258,17 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
TaskInstanceState.RESTARTING,
):
tis_with_right_state.append(key)
- else:
- # Callback event (key is string UUID)
+ elif isinstance(key, CallbackKey):
cls.logger().info("Received executor event with state %s for
callback %s", state, key)
if state in (CallbackState.RUNNING, CallbackState.FAILED,
CallbackState.SUCCESS):
callback_keys_with_events.append(key)
+ else:
+ cls.logger().error("Unknown workload key type in event buffer:
%r", key)
# Handle callback state events
for callback_id in callback_keys_with_events:
state, info = event_buffer.pop(callback_id)
- callback = session.get(Callback, callback_id)
+ callback = session.get(Callback, str(callback_id))
if not callback:
# This should not normally happen - we just received an event
for this callback.
# Only possible if callback was deleted mid-execution (e.g.,
cascade delete from DagRun deletion).
diff --git a/airflow-core/src/airflow/models/callback.py
b/airflow-core/src/airflow/models/callback.py
index b42db309710..15f9662cdc8 100644
--- a/airflow-core/src/airflow/models/callback.py
+++ b/airflow-core/src/airflow/models/callback.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from importlib import import_module
@@ -38,7 +39,16 @@ from airflow.models.base import StringID
from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime
from airflow.utils.state import CallbackState
-CallbackKey = str # Callback keys are str(UUID)
+
+@dataclass(frozen=True, slots=True)
+class CallbackKey:
+ """Distinct key type for callbacks, preventing any bare string from
passing isinstance checks."""
+
+ id: str
+
+ def __str__(self) -> str:
+ return self.id
+
if TYPE_CHECKING:
from sqlalchemy.orm import Session
diff --git a/airflow-core/tests/unit/executors/test_base_executor.py
b/airflow-core/tests/unit/executors/test_base_executor.py
index 6b2b47d0102..b3894bdef29 100644
--- a/airflow-core/tests/unit/executors/test_base_executor.py
+++ b/airflow-core/tests/unit/executors/test_base_executor.py
@@ -36,7 +36,7 @@ from airflow.executors.base_executor import BaseExecutor,
RunningRetryAttemptTyp
from airflow.executors.local_executor import LocalExecutor
from airflow.executors.workloads.base import BundleInfo
from airflow.executors.workloads.callback import CallbackDTO
-from airflow.models.callback import CallbackFetchMethod
+from airflow.models.callback import CallbackFetchMethod, CallbackKey
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.sdk import BaseOperator
from airflow.sdk.execution_time.callback_supervisor import execute_callback
@@ -100,6 +100,23 @@ def test_get_event_buffer():
assert len(executor.event_buffer) == 0
+def test_get_event_buffer_always_includes_callback_keys():
+ """CallbackKey events are always returned regardless of the dag_ids
filter."""
+ executor = BaseExecutor()
+
+ date = timezone.utcnow()
+ ti_key = TaskInstanceKey("my_dag1", "my_task1", date, 1)
+ callback_key = CallbackKey(id="00000000-0000-0000-0000-000000000042")
+
+ executor.event_buffer[ti_key] = State.SUCCESS, None
+ executor.event_buffer[callback_key] = CallbackState.SUCCESS, None
+
+ # Filter for a dag that doesn't match the TI key. Callback should still be
included
+ result = executor.get_event_buffer(("other_dag",))
+ assert callback_key in result
+ assert ti_key not in result
+
+
def test_log_task_event_branches_on_key_type():
executor = BaseExecutor()
ti_key = TaskInstanceKey("my_dag", "my_task", timezone.utcnow(), 1)
@@ -107,7 +124,7 @@ def test_log_task_event_branches_on_key_type():
executor.log_task_event(event="task_event", extra="extra", ti_key=ti_key)
assert len(executor._task_event_logs) == 1
- callback_key = str(UUID("00000000-0000-0000-0000-000000000001"))
+ callback_key =
CallbackKey(id=str(UUID("00000000-0000-0000-0000-000000000001")))
executor.log_task_event(event="callback_event", extra="extra",
ti_key=callback_key)
assert len(executor._task_event_logs) == 1
@@ -123,7 +140,7 @@ def test_log_task_event_branches_on_key_type():
)
def test_state_methods_pick_callback_state_for_callback_key(method_name,
expected_state):
executor = BaseExecutor()
- callback_key = str(UUID("00000000-0000-0000-0000-000000000002"))
+ callback_key =
CallbackKey(id=str(UUID("00000000-0000-0000-0000-000000000002")))
getattr(executor, method_name)(callback_key)
@@ -627,7 +644,7 @@ class TestCallbackSupport:
executor.queue_workload(callback_workload, session)
assert len(executor.queued_callbacks) == 1
- assert callback_data.id in executor.queued_callbacks
+ assert callback_workload.key in executor.queued_callbacks
@pytest.mark.db_test
def test_get_workloads_prioritizes_callbacks(self, dag_maker, session):
diff --git a/airflow-core/tests/unit/executors/test_local_executor.py
b/airflow-core/tests/unit/executors/test_local_executor.py
index af135fc5f99..2c9e42d23aa 100644
--- a/airflow-core/tests/unit/executors/test_local_executor.py
+++ b/airflow-core/tests/unit/executors/test_local_executor.py
@@ -427,7 +427,7 @@ class TestLocalExecutorCallbackSupport:
executor.start()
try:
- executor.queued_callbacks[callback_data.id] = callback_workload
+ executor.queued_callbacks[callback_workload.key] =
callback_workload
executor._process_workloads([callback_workload])
assert len(executor.queued_callbacks) == 0
# We can't easily verify worker execution without running the
worker,
diff --git a/airflow-core/tests/unit/executors/test_workloads.py
b/airflow-core/tests/unit/executors/test_workloads.py
index 2c3ffbf53ea..6f027a4d3be 100644
--- a/airflow-core/tests/unit/executors/test_workloads.py
+++ b/airflow-core/tests/unit/executors/test_workloads.py
@@ -17,16 +17,21 @@
# under the License.
from __future__ import annotations
+import dataclasses
from pathlib import PurePosixPath
from uuid import uuid4
import jwt
+import pytest
from airflow.api_fastapi.auth.tokens import JWTGenerator
from airflow.executors import workloads
from airflow.executors.workloads import TaskInstance, TaskInstanceDTO, base as
workloads_base
from airflow.executors.workloads.base import BaseWorkloadSchema, BundleInfo
+from airflow.executors.workloads.callback import CallbackDTO,
CallbackFetchMethod
from airflow.executors.workloads.task import ExecuteTask
+from airflow.executors.workloads.types import state_class_for_key
+from airflow.models.callback import CallbackKey
def test_task_instance_alias_keeps_backwards_compat():
@@ -82,3 +87,50 @@ def test_generate_token_produces_workload_scope(monkeypatch):
def test_generate_token_without_generator():
"""generate_token should return empty string when no generator is
provided."""
assert BaseWorkloadSchema.generate_token("ti-123", None) == ""
+
+
+def test_callback_key_is_frozen_and_hashable():
+ """CallbackKey must be usable as a dict key (hashable) and immutable
(frozen)."""
+ cid = "some-uuid-value"
+
+ key = CallbackKey(id=cid)
+ assert hash(key) == hash(CallbackKey(id=cid))
+ assert key == CallbackKey(id=cid)
+ assert key != CallbackKey(id="other")
+
+ # Frozen: assignment raises
+ with pytest.raises(dataclasses.FrozenInstanceError):
+ key.id = "mutated" # type: ignore[misc]
+
+
+def test_callback_key_str_returns_id():
+ """str(CallbackKey) should return the raw id string."""
+ cid = "some-uuid-value"
+
+ key = CallbackKey(id=cid)
+ assert str(key) == cid
+
+
+def test_callback_key_is_not_a_string():
+ """CallbackKey must NOT pass isinstance(x, str)."""
+
+ key = CallbackKey(id="some-uuid-value")
+ assert not isinstance(key, str)
+
+
+def test_state_class_for_key_raises_on_unknown_type():
+ """state_class_for_key should raise TypeError for unrecognized key
types."""
+
+ with pytest.raises(TypeError, match="Unknown workload key type"):
+ state_class_for_key("bare-string-is-not-a-key") # type:
ignore[arg-type]
+
+
+def test_callback_dto_key_returns_callback_key_instance():
+ """CallbackDTO.key should return a CallbackKey, not a bare string."""
+ cid = "some-uuid-value"
+
+ callback = CallbackDTO(id=cid,
fetch_method=CallbackFetchMethod.IMPORT_PATH, data={})
+ key = callback.key
+ assert isinstance(key, CallbackKey)
+ assert key.id == cid
+ assert str(key) == cid