This is an automated email from the ASF dual-hosted git repository.

o-nikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2970152b190 Fix CallbackKey type for more accurate type checking (#66973)
2970152b190 is described below

commit 2970152b190664d4ccc825f5fe008396ca7b0343
Author: D. Ferruzzi <[email protected]>
AuthorDate: Wed May 20 09:59:34 2026 -0700

    Fix CallbackKey type for more accurate type checking (#66973)
    
    * Fix CallbackType for more accurate type checking
    
    * additional cleanup
---
 .../src/airflow/executors/base_executor.py         | 12 ++---
 .../src/airflow/executors/workloads/callback.py    |  6 ++-
 .../src/airflow/executors/workloads/types.py       |  8 ++--
 .../src/airflow/jobs/scheduler_job_runner.py       | 11 ++---
 airflow-core/src/airflow/models/callback.py        | 12 ++++-
 .../tests/unit/executors/test_base_executor.py     | 25 +++++++++--
 .../tests/unit/executors/test_local_executor.py    |  2 +-
 .../tests/unit/executors/test_workloads.py         | 52 ++++++++++++++++++++++
 8 files changed, 106 insertions(+), 22 deletions(-)

diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py
index 2f1f5e8adb9..eff6ff07714 100644
--- a/airflow-core/src/airflow/executors/base_executor.py
+++ b/airflow-core/src/airflow/executors/base_executor.py
@@ -38,6 +38,7 @@ from airflow.executors.workloads.task import ExecuteTask
 from airflow.executors.workloads.types import state_class_for_key
 from airflow.models import Log
 from airflow.models.callback import CallbackKey
+from airflow.models.taskinstancekey import TaskInstanceKey
 from airflow.observability.metrics import stats_utils
 from airflow.utils.log.logging_mixin import LoggingMixin
 
@@ -78,7 +79,6 @@ if TYPE_CHECKING:
     from airflow.executors.workloads import ExecutorWorkload
     from airflow.executors.workloads.types import WorkloadKey, WorkloadState
     from airflow.models.taskinstance import TaskInstance
-    from airflow.models.taskinstancekey import TaskInstanceKey
 
     # Event_buffer dict value type
     # Tuple of: state, info
@@ -217,7 +217,7 @@ class BaseExecutor(LoggingMixin):
         self.parallelism: int = parallelism
         self.team_name: str | None = team_name
         self.queued_tasks: dict[TaskInstanceKey, workloads.ExecuteTask] = {}
-        self.queued_callbacks: dict[str, workloads.ExecuteCallback] = {}
+        self.queued_callbacks: dict[CallbackKey, workloads.ExecuteCallback] = 
{}
         self.running: set[WorkloadKey] = set()
         self.event_buffer: dict[WorkloadKey, EventBufferValueType] = {}
         self._task_event_logs: deque[Log] = deque()
@@ -265,7 +265,7 @@ class BaseExecutor(LoggingMixin):
                     f"Set supports_callbacks = True and implement callback 
handling in _process_workloads(). "
                     f"See LocalExecutor or CeleryExecutor for reference 
implementation."
                 )
-            self.queued_callbacks[workload.callback.id] = workload
+            self.queued_callbacks[workload.key] = workload
         else:
             raise ValueError(
                 f"Un-handled workload type {type(workload).__name__!r} in 
{type(self).__name__}. "
@@ -497,7 +497,7 @@ class BaseExecutor(LoggingMixin):
 
         In case dag_ids is specified it will only return and flush events
         for the given dag_ids. Otherwise, it returns and flushes all events.
-        Note: Callback events (with string keys) are always included 
regardless of dag_ids filter.
+        Note: Callback events (with CallbackKey keys) are always included 
regardless of dag_ids filter.
 
         :param dag_ids: the dag_ids to return events for; returns all if given 
``None``.
         :return: a dict of events
@@ -508,7 +508,9 @@ class BaseExecutor(LoggingMixin):
             self.event_buffer = {}
         else:
             for key in list(self.event_buffer.keys()):
-                if isinstance(key, CallbackKey) or key.dag_id in dag_ids:
+                if isinstance(key, CallbackKey) or (
+                    isinstance(key, TaskInstanceKey) and key.dag_id in dag_ids
+                ):
                     cleared_events[key] = self.event_buffer.pop(key)
 
         return cleared_events
diff --git a/airflow-core/src/airflow/executors/workloads/callback.py b/airflow-core/src/airflow/executors/workloads/callback.py
index a78dbab43a5..04f26b8e787 100644
--- a/airflow-core/src/airflow/executors/workloads/callback.py
+++ b/airflow-core/src/airflow/executors/workloads/callback.py
@@ -64,8 +64,10 @@ class CallbackDTO(BaseModel):
 
     @property
     def key(self) -> CallbackKey:
-        """Return callback ID as key (CallbackKey = str)."""
-        return self.id
+        """Return callback ID as a CallbackKey instance."""
+        from airflow.models.callback import CallbackKey  # circular import
+
+        return CallbackKey(id=self.id)
 
 
 class ExecuteCallback(BaseDagBundleWorkload):
diff --git a/airflow-core/src/airflow/executors/workloads/types.py b/airflow-core/src/airflow/executors/workloads/types.py
index 61f7bf037d2..09cd2c3b359 100644
--- a/airflow-core/src/airflow/executors/workloads/types.py
+++ b/airflow-core/src/airflow/executors/workloads/types.py
@@ -20,14 +20,12 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, TypeAlias
 
-from airflow.models.callback import ExecutorCallback
+from airflow.models.callback import CallbackKey, ExecutorCallback
 from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskinstancekey import TaskInstanceKey
 from airflow.utils.state import CallbackState, TaskInstanceState
 
 if TYPE_CHECKING:
-    from airflow.models.callback import CallbackKey
-
     # Type aliases for workload keys and states (used by executor layer)
     WorkloadKey: TypeAlias = TaskInstanceKey | CallbackKey
     WorkloadState: TypeAlias = TaskInstanceState | CallbackState
@@ -43,4 +41,6 @@ SchedulerWorkload: TypeAlias = TaskInstance | ExecutorCallback
 def state_class_for_key(key: WorkloadKey) -> type[TaskInstanceState] | 
type[CallbackState]:
     if isinstance(key, TaskInstanceKey):
         return TaskInstanceState
-    return CallbackState
+    if isinstance(key, CallbackKey):
+        return CallbackState
+    raise TypeError(f"Unknown workload key type: {type(key)!r}")
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 5dcbdc21521..d04ee85c202 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -85,7 +85,7 @@ from airflow.models.asset import (
 )
 from airflow.models.asset_state import AssetStateModel
 from airflow.models.backfill import Backfill, BackfillDagRun
-from airflow.models.callback import Callback, CallbackType, ExecutorCallback
+from airflow.models.callback import Callback, CallbackKey, CallbackType, 
ExecutorCallback
 from airflow.models.dag import DagModel
 from airflow.models.dag_version import DagVersion
 from airflow.models.dagbag import DBDagBag
@@ -1233,7 +1233,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         ti_primary_key_to_try_number_map: dict[tuple[str, str, str, int], int] 
= {}
         event_buffer = executor.get_event_buffer()
         tis_with_right_state: list[TaskInstanceKey] = []
-        callback_keys_with_events: list[str] = []
+        callback_keys_with_events: list[CallbackKey] = []
 
         # Report execution - handle both task and callback events
         for key, (state, _) in event_buffer.items():
@@ -1258,16 +1258,17 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                     TaskInstanceState.RESTARTING,
                 ):
                     tis_with_right_state.append(key)
-            else:
-                # Callback event (key is string UUID)
+            elif isinstance(key, CallbackKey):
                 cls.logger().info("Received executor event with state %s for 
callback %s", state, key)
                 if state in (CallbackState.RUNNING, CallbackState.FAILED, 
CallbackState.SUCCESS):
                     callback_keys_with_events.append(key)
+            else:
+                cls.logger().error("Unknown workload key type in event buffer: 
%r", key)
 
         # Handle callback state events
         for callback_id in callback_keys_with_events:
             state, info = event_buffer.pop(callback_id)
-            callback = session.get(Callback, callback_id)
+            callback = session.get(Callback, str(callback_id))
             if not callback:
                 # This should not normally happen - we just received an event 
for this callback.
                 # Only possible if callback was deleted mid-execution (e.g., 
cascade delete from DagRun deletion).
diff --git a/airflow-core/src/airflow/models/callback.py b/airflow-core/src/airflow/models/callback.py
index b42db309710..15f9662cdc8 100644
--- a/airflow-core/src/airflow/models/callback.py
+++ b/airflow-core/src/airflow/models/callback.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
 from importlib import import_module
@@ -38,7 +39,16 @@ from airflow.models.base import StringID
 from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime
 from airflow.utils.state import CallbackState
 
-CallbackKey = str  # Callback keys are str(UUID)
+
+@dataclass(frozen=True, slots=True)
+class CallbackKey:
+    """Distinct key type for callbacks, preventing any bare string from 
passing isinstance checks."""
+
+    id: str
+
+    def __str__(self) -> str:
+        return self.id
+
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py
index 6b2b47d0102..b3894bdef29 100644
--- a/airflow-core/tests/unit/executors/test_base_executor.py
+++ b/airflow-core/tests/unit/executors/test_base_executor.py
@@ -36,7 +36,7 @@ from airflow.executors.base_executor import BaseExecutor, 
RunningRetryAttemptTyp
 from airflow.executors.local_executor import LocalExecutor
 from airflow.executors.workloads.base import BundleInfo
 from airflow.executors.workloads.callback import CallbackDTO
-from airflow.models.callback import CallbackFetchMethod
+from airflow.models.callback import CallbackFetchMethod, CallbackKey
 from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
 from airflow.sdk import BaseOperator
 from airflow.sdk.execution_time.callback_supervisor import execute_callback
@@ -100,6 +100,23 @@ def test_get_event_buffer():
     assert len(executor.event_buffer) == 0
 
 
+def test_get_event_buffer_always_includes_callback_keys():
+    """CallbackKey events are always returned regardless of the dag_ids 
filter."""
+    executor = BaseExecutor()
+
+    date = timezone.utcnow()
+    ti_key = TaskInstanceKey("my_dag1", "my_task1", date, 1)
+    callback_key = CallbackKey(id="00000000-0000-0000-0000-000000000042")
+
+    executor.event_buffer[ti_key] = State.SUCCESS, None
+    executor.event_buffer[callback_key] = CallbackState.SUCCESS, None
+
+    # Filter for a dag that doesn't match the TI key. Callback should still be 
included
+    result = executor.get_event_buffer(("other_dag",))
+    assert callback_key in result
+    assert ti_key not in result
+
+
 def test_log_task_event_branches_on_key_type():
     executor = BaseExecutor()
     ti_key = TaskInstanceKey("my_dag", "my_task", timezone.utcnow(), 1)
@@ -107,7 +124,7 @@ def test_log_task_event_branches_on_key_type():
     executor.log_task_event(event="task_event", extra="extra", ti_key=ti_key)
     assert len(executor._task_event_logs) == 1
 
-    callback_key = str(UUID("00000000-0000-0000-0000-000000000001"))
+    callback_key = 
CallbackKey(id=str(UUID("00000000-0000-0000-0000-000000000001")))
     executor.log_task_event(event="callback_event", extra="extra", 
ti_key=callback_key)
     assert len(executor._task_event_logs) == 1
 
@@ -123,7 +140,7 @@ def test_log_task_event_branches_on_key_type():
 )
 def test_state_methods_pick_callback_state_for_callback_key(method_name, 
expected_state):
     executor = BaseExecutor()
-    callback_key = str(UUID("00000000-0000-0000-0000-000000000002"))
+    callback_key = 
CallbackKey(id=str(UUID("00000000-0000-0000-0000-000000000002")))
 
     getattr(executor, method_name)(callback_key)
 
@@ -627,7 +644,7 @@ class TestCallbackSupport:
         executor.queue_workload(callback_workload, session)
 
         assert len(executor.queued_callbacks) == 1
-        assert callback_data.id in executor.queued_callbacks
+        assert callback_workload.key in executor.queued_callbacks
 
     @pytest.mark.db_test
     def test_get_workloads_prioritizes_callbacks(self, dag_maker, session):
diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py
index af135fc5f99..2c9e42d23aa 100644
--- a/airflow-core/tests/unit/executors/test_local_executor.py
+++ b/airflow-core/tests/unit/executors/test_local_executor.py
@@ -427,7 +427,7 @@ class TestLocalExecutorCallbackSupport:
         executor.start()
 
         try:
-            executor.queued_callbacks[callback_data.id] = callback_workload
+            executor.queued_callbacks[callback_workload.key] = 
callback_workload
             executor._process_workloads([callback_workload])
             assert len(executor.queued_callbacks) == 0
             # We can't easily verify worker execution without running the 
worker,
diff --git a/airflow-core/tests/unit/executors/test_workloads.py b/airflow-core/tests/unit/executors/test_workloads.py
index 2c3ffbf53ea..6f027a4d3be 100644
--- a/airflow-core/tests/unit/executors/test_workloads.py
+++ b/airflow-core/tests/unit/executors/test_workloads.py
@@ -17,16 +17,21 @@
 # under the License.
 from __future__ import annotations
 
+import dataclasses
 from pathlib import PurePosixPath
 from uuid import uuid4
 
 import jwt
+import pytest
 
 from airflow.api_fastapi.auth.tokens import JWTGenerator
 from airflow.executors import workloads
 from airflow.executors.workloads import TaskInstance, TaskInstanceDTO, base as 
workloads_base
 from airflow.executors.workloads.base import BaseWorkloadSchema, BundleInfo
+from airflow.executors.workloads.callback import CallbackDTO, 
CallbackFetchMethod
 from airflow.executors.workloads.task import ExecuteTask
+from airflow.executors.workloads.types import state_class_for_key
+from airflow.models.callback import CallbackKey
 
 
 def test_task_instance_alias_keeps_backwards_compat():
@@ -82,3 +87,50 @@ def test_generate_token_produces_workload_scope(monkeypatch):
 def test_generate_token_without_generator():
     """generate_token should return empty string when no generator is 
provided."""
     assert BaseWorkloadSchema.generate_token("ti-123", None) == ""
+
+
+def test_callback_key_is_frozen_and_hashable():
+    """CallbackKey must be usable as a dict key (hashable) and immutable 
(frozen)."""
+    cid = "some-uuid-value"
+
+    key = CallbackKey(id=cid)
+    assert hash(key) == hash(CallbackKey(id=cid))
+    assert key == CallbackKey(id=cid)
+    assert key != CallbackKey(id="other")
+
+    # Frozen: assignment raises
+    with pytest.raises(dataclasses.FrozenInstanceError):
+        key.id = "mutated"  # type: ignore[misc]
+
+
+def test_callback_key_str_returns_id():
+    """str(CallbackKey) should return the raw id string."""
+    cid = "some-uuid-value"
+
+    key = CallbackKey(id=cid)
+    assert str(key) == cid
+
+
+def test_callback_key_is_not_a_string():
+    """CallbackKey must NOT pass isinstance(x, str)."""
+
+    key = CallbackKey(id="some-uuid-value")
+    assert not isinstance(key, str)
+
+
+def test_state_class_for_key_raises_on_unknown_type():
+    """state_class_for_key should raise TypeError for unrecognized key 
types."""
+
+    with pytest.raises(TypeError, match="Unknown workload key type"):
+        state_class_for_key("bare-string-is-not-a-key")  # type: 
ignore[arg-type]
+
+
+def test_callback_dto_key_returns_callback_key_instance():
+    """CallbackDTO.key should return a CallbackKey, not a bare string."""
+    cid = "some-uuid-value"
+
+    callback = CallbackDTO(id=cid, 
fetch_method=CallbackFetchMethod.IMPORT_PATH, data={})
+    key = callback.key
+    assert isinstance(key, CallbackKey)
+    assert key.id == cid
+    assert str(key) == cid

Reply via email to