This is an automated email from the ASF dual-hosted git repository.

jason810496 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 28e82d24c97 Enforce supervisor schema class name matches its `type` literal (#66899)
28e82d24c97 is described below

commit 28e82d24c979ec5a075e255674f5b3aa42121d77
Author: Jason(Zhe-You) Liu <[email protected]>
AuthorDate: Fri May 15 11:09:40 2026 +0800

    Enforce supervisor schema class name matches its `type` literal (#66899)
    
    * Add prek hook to keep supervisor schema class name in sync with its `type` literal
    
    * Refactor prek hooks as unit tests
---
 task-sdk/src/airflow/sdk/execution_time/comms.py   |  4 +-
 .../task_sdk/execution_time/test_supervisor.py     |  2 +-
 .../test_supervisor_schemas_name_type_sync.py      | 78 ++++++++++++++++++++++
 3 files changed, 81 insertions(+), 3 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index a30872a6a54..c56f5b23ab3 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -509,7 +509,7 @@ class XComResult(XComResponse):
 
 class XComCountResponse(BaseModel):
     len: int
-    type: Literal["XComLengthResponse"] = "XComLengthResponse"
+    type: Literal["XComCountResponse"] = "XComCountResponse"
 
 
 class XComSequenceIndexResult(BaseModel):
@@ -869,7 +869,7 @@ class GetXComCount(BaseModel):
     dag_id: str
     run_id: str
     task_id: str
-    type: Literal["GetNumberXComs"] = "GetNumberXComs"
+    type: Literal["GetXComCount"] = "GetXComCount"
 
 
 class GetXComSequenceItem(BaseModel):
 class GetXComSequenceItem(BaseModel):
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index b54477b7769..51131f3c48d 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -2627,7 +2627,7 @@ REQUEST_TEST_CASES = [
     ),
     RequestTestCase(
         message=GetXComCount(key="test_key", dag_id="test_dag", run_id="test_run", task_id="test_task"),
-        expected_body={"len": 5, "type": "XComLengthResponse"},
+        expected_body={"len": 5, "type": "XComCountResponse"},
         client_mock=ClientMock(
             method_path="xcoms.head",
             args=("test_dag", "test_run", "test_task", "test_key"),
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor_schemas_name_type_sync.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor_schemas_name_type_sync.py
new file mode 100644
index 00000000000..e1c3ef13118
--- /dev/null
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor_schemas_name_type_sync.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Guard the invariant that every supervisor-schema body's class
+``__name__`` equals the value of its ``type`` ``Literal`` discriminator.
+
+``CommsDecoder`` routes incoming wire frames against the ``type`` literal
+on each member of a discriminated union, but downstream consumers
+(registry lookups, snapshot codegen, debug output) identify the head
+class by ``__name__``. If the two strings drift, a frame decodes against
+one class then surfaces under a different name -- a silent contract
+break.
+"""
+
+from __future__ import annotations
+
+from typing import Annotated, get_args, get_origin
+
+import pytest
+from pydantic import BaseModel
+
+from airflow.dag_processing.processor import ToDagProcessor, ToManager
+from airflow.jobs.triggerer_job_runner import ToTriggerRunner, ToTriggerSupervisor
+from airflow.sdk.execution_time.comms import ToSupervisor, ToTask
+
+
+def _members_of_union(union: object) -> tuple[type[BaseModel], ...]:
+    """Return the BaseModel classes composing an ``Annotated[A | B | ..., Field(...)]``."""
+    if get_origin(union) is Annotated:
+        union = get_args(union)[0]
+    return tuple(m for m in get_args(union) if isinstance(m, type) and issubclass(m, BaseModel))
+
+
+# All six supervisor discriminated unions. Triggerer's two unions are
+# not part of the lang-SDK-facing registry, but the same name/type
+# invariant is required for ``CommsDecoder`` to round-trip them.
+SUPERVISOR_UNIONS = [
+    pytest.param(ToTask, id="ToTask"),
+    pytest.param(ToSupervisor, id="ToSupervisor"),
+    pytest.param(ToManager, id="ToManager"),
+    pytest.param(ToDagProcessor, id="ToDagProcessor"),
+    pytest.param(ToTriggerRunner, id="ToTriggerRunner"),
+    pytest.param(ToTriggerSupervisor, id="ToTriggerSupervisor"),
+]
+
+
+@pytest.mark.parametrize("union", SUPERVISOR_UNIONS)
+def test_class_name_matches_type_literal(union):
+    """For every member, ``cls.__name__`` must equal its ``type`` Literal value."""
+    mismatches: list[str] = []
+    for member in _members_of_union(union):
+        field = member.model_fields.get("type")
+        if field is None:
+            mismatches.append(f"{member.__name__}: missing `type` field")
+            continue
+        args = get_args(field.annotation)
+        if len(args) != 1:
+            mismatches.append(f"{member.__name__}: `type` must be a single-value Literal, got {args!r}")
+            continue
+        literal = args[0]
+        if literal != member.__name__:
+            mismatches.append(f"{member.__name__}: type literal = {literal!r}, expected {member.__name__!r}")
+
+    assert not mismatches, "Class __name__ must equal its `type` Literal value:\n  " + "\n  ".join(mismatches)

Reply via email to