This is an automated email from the ASF dual-hosted git repository.
jason810496 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 28e82d24c97 Enforce supervisor schema class name matches its `type` literal (#66899)
28e82d24c97 is described below
commit 28e82d24c979ec5a075e255674f5b3aa42121d77
Author: Jason(Zhe-You) Liu <[email protected]>
AuthorDate: Fri May 15 11:09:40 2026 +0800
Enforce supervisor schema class name matches its `type` literal (#66899)
* Add prek hook to keep supervisor schema class name in sync with its
`type` literal
* Refactor prek hooks as unit tests
---
task-sdk/src/airflow/sdk/execution_time/comms.py | 4 +-
.../task_sdk/execution_time/test_supervisor.py | 2 +-
.../test_supervisor_schemas_name_type_sync.py | 78 ++++++++++++++++++++++
3 files changed, 81 insertions(+), 3 deletions(-)
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index a30872a6a54..c56f5b23ab3 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -509,7 +509,7 @@ class XComResult(XComResponse):
class XComCountResponse(BaseModel):
len: int
- type: Literal["XComLengthResponse"] = "XComLengthResponse"
+ type: Literal["XComCountResponse"] = "XComCountResponse"
class XComSequenceIndexResult(BaseModel):
@@ -869,7 +869,7 @@ class GetXComCount(BaseModel):
dag_id: str
run_id: str
task_id: str
- type: Literal["GetNumberXComs"] = "GetNumberXComs"
+ type: Literal["GetXComCount"] = "GetXComCount"
class GetXComSequenceItem(BaseModel):
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index b54477b7769..51131f3c48d 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -2627,7 +2627,7 @@ REQUEST_TEST_CASES = [
),
RequestTestCase(
message=GetXComCount(key="test_key", dag_id="test_dag",
run_id="test_run", task_id="test_task"),
- expected_body={"len": 5, "type": "XComLengthResponse"},
+ expected_body={"len": 5, "type": "XComCountResponse"},
client_mock=ClientMock(
method_path="xcoms.head",
args=("test_dag", "test_run", "test_task", "test_key"),
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor_schemas_name_type_sync.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor_schemas_name_type_sync.py
new file mode 100644
index 00000000000..e1c3ef13118
--- /dev/null
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor_schemas_name_type_sync.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Guard the invariant that every supervisor-schema body's class
+``__name__`` equals the value of its ``type`` ``Literal`` discriminator.
+
+``CommsDecoder`` routes incoming wire frames against the ``type`` literal
+on each member of a discriminated union, but downstream consumers
+(registry lookups, snapshot codegen, debug output) identify the head
+class by ``__name__``. If the two strings drift, a frame decodes against
+one class then surfaces under a different name -- a silent contract
+break.
+"""
+
from __future__ import annotations

from typing import Annotated, Literal, get_args, get_origin

import pytest
from pydantic import BaseModel

from airflow.dag_processing.processor import ToDagProcessor, ToManager
from airflow.jobs.triggerer_job_runner import ToTriggerRunner, ToTriggerSupervisor
from airflow.sdk.execution_time.comms import ToSupervisor, ToTask
+
+
def _members_of_union(union: object) -> tuple[type[BaseModel], ...]:
    """
    Return the BaseModel classes composing an ``Annotated[A | B | ..., Field(...)]``.

    ``Annotated`` metadata is stripped from the union itself *and* from each of
    its members, so a member written as ``Annotated[A, ...]`` is still returned
    instead of being silently skipped. A bare ``BaseModel`` subclass is treated
    as a one-member union. Members that are not ``BaseModel`` subclasses are
    ignored.
    """

    def _strip_annotated(tp: object) -> object:
        # ``Annotated[X, meta...]`` -> ``X``; other objects pass through unchanged.
        return get_args(tp)[0] if get_origin(tp) is Annotated else tp

    def _is_model(tp: object) -> bool:
        # Guard with ``isinstance`` first: ``issubclass`` raises on non-classes.
        return isinstance(tp, type) and issubclass(tp, BaseModel)

    union = _strip_annotated(union)
    if _is_model(union):
        # Not a union at all; previously this fell through to ``get_args`` and
        # returned ``()``, which let callers pass vacuously.
        return (union,)
    candidates = (_strip_annotated(member) for member in get_args(union))
    return tuple(member for member in candidates if _is_model(member))
+
+
+# All six supervisor discriminated unions. Triggerer's two unions are
+# not part of the lang-SDK-facing registry, but the same name/type
+# invariant is required for ``CommsDecoder`` to round-trip them.
+SUPERVISOR_UNIONS = [
+ pytest.param(ToTask, id="ToTask"),
+ pytest.param(ToSupervisor, id="ToSupervisor"),
+ pytest.param(ToManager, id="ToManager"),
+ pytest.param(ToDagProcessor, id="ToDagProcessor"),
+ pytest.param(ToTriggerRunner, id="ToTriggerRunner"),
+ pytest.param(ToTriggerSupervisor, id="ToTriggerSupervisor"),
+]
+
+
@pytest.mark.parametrize("union", SUPERVISOR_UNIONS)
def test_class_name_matches_type_literal(union):
    """
    For every member, ``cls.__name__`` must equal its ``type`` Literal value.

    All violations in the union are collected first and reported in a single
    assertion, so one run lists every drifted class rather than only the first.
    """
    members = _members_of_union(union)
    # An empty result would turn the loop below into a no-op and let the test
    # pass vacuously; fail loudly instead if the union could not be decomposed.
    assert members, f"No BaseModel members found in {union!r}"

    mismatches: list[str] = []
    for member in members:
        field = member.model_fields.get("type")
        if field is None:
            mismatches.append(f"{member.__name__}: missing `type` field")
            continue
        annotation = field.annotation
        args = get_args(annotation)
        # ``get_args`` alone also yields one arg for non-Literal generics such
        # as ``list[str]``, so require the origin to be ``Literal`` explicitly.
        if get_origin(annotation) is not Literal or len(args) != 1:
            mismatches.append(
                f"{member.__name__}: `type` must be a single-value Literal, got {annotation!r}"
            )
            continue
        literal = args[0]
        if literal != member.__name__:
            mismatches.append(f"{member.__name__}: type literal = {literal!r}, expected {member.__name__!r}")

    assert not mismatches, "Class __name__ must equal its `type` Literal value:\n " + "\n ".join(mismatches)