This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b50ee5fce2 [TIR][Schedule] Fix type checker to support subscripted
generics in Python 3.14+ (#18639)
b50ee5fce2 is described below
commit b50ee5fce20aa98e1f1bdecd97b14490c24e6809
Author: Haejoon Kim <[email protected]>
AuthorDate: Tue Jan 6 18:42:54 2026 +0900
[TIR][Schedule] Fix type checker to support subscripted generics in Python
3.14+ (#18639)
This PR fixes the type annotation checker in
`tvm.tir.schedule._type_checker` to correctly handle subscripted
generics (e.g., `Union[str, int]`, `List[str]`, `Tuple[str, int]`) in
Python 3.14+.
## Background
In Python 3.14, the runtime representation of subscripted generics such as
`Union[...]` has changed:
- `Union[str, int]` is now an instance of `typing.Union` rather than of the
private alias classes (`typing._GenericAlias` / `typing._SpecialGenericAlias`)
that the checker tests for
- These objects expose an `__origin__` attribute directly
- The existing type checker failed to recognize these new
representations, causing the dispatcher to fall through to "atomic"
instead of correctly identifying them as "union", "list", etc.
## Changes
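Added a check for an `__origin__` attribute at the beginning of
`_Subtype._origin` to handle Python 3.14's new generic type representations.
This is intended to be backward compatible: alias objects on older Python
versions also expose `__origin__`, and the new check returns the same
`type_.__origin__` value that the existing `_SpecialGenericAlias` branch
returns.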
## Tests
Added parametrized tests verifying that `_dispatcher` correctly handles
subscripted generics:
- `Union[str, int]` → identified as "union"
- `List[str]` → identified as "list"
- `Dict[str, int]` → identified as "dict"
- `Tuple[str, int]` → identified as "tuple"
- `Union[List[str], Dict[str, int]]` → identified as "union" with nested
generics
---
python/tvm/tir/schedule/_type_checker.py | 4 +++
.../python/testing/test_type_annotation_checker.py | 34 ++++++++++++++++++++++
2 files changed, 38 insertions(+)
diff --git a/python/tvm/tir/schedule/_type_checker.py
b/python/tvm/tir/schedule/_type_checker.py
index 5c51b1b09f..148016fb2d 100644
--- a/python/tvm/tir/schedule/_type_checker.py
+++ b/python/tvm/tir/schedule/_type_checker.py
@@ -47,6 +47,10 @@ if hasattr(typing, "_GenericAlias"):
class _Subtype:
@staticmethod
def _origin(type_: Any) -> Any:
+ # In Python 3.14+, check if the type has __origin__ attribute
directly
+ if hasattr(type_, "__origin__"):
+ return type_.__origin__
+
if hasattr(typing, "_SpecialGenericAlias"):
if isinstance(type_, typing._SpecialGenericAlias): # type:
ignore # pylint: disable=protected-access
return type_.__origin__
diff --git a/tests/python/testing/test_type_annotation_checker.py
b/tests/python/testing/test_type_annotation_checker.py
index 42ce1e1039..71bc9ba98b 100644
--- a/tests/python/testing/test_type_annotation_checker.py
+++ b/tests/python/testing/test_type_annotation_checker.py
@@ -187,5 +187,39 @@ def test_not_matches(type_annotation, case):
func(case)
+@pytest.mark.parametrize(
+ ["type_annotation", "expected_key", "expected_subtypes"],
+ [
+ pytest.param(Union[str, int], "union", [str, int], id="Union[str,
int]"),
+ pytest.param(List[str], "list", [str], id="List[str]"),
+ pytest.param(Dict[str, int], "dict", [str, int], id="Dict[str, int]"),
+ pytest.param(Tuple[str, int], "tuple", (str, int), id="Tuple[str,
int]"),
+ pytest.param(
+ Union[List[str], Dict[str, int]],
+ "union",
+ [List[str], Dict[str, int]],
+ id="Union[List[str], Dict[str, int]]",
+ ),
+ ],
+)
+def test_subscripted_generics(type_annotation, expected_key,
expected_subtypes):
+ """Test that _dispatcher correctly handles subscripted generics in Python
3.14+.
+
+ In Python 3.14, Union and other generic types have a different internal
representation.
+ This test ensures that the dispatcher correctly identifies these types.
+ """
+ from tvm.tir.schedule._type_checker import _dispatcher
+
+ key, subtypes = _dispatcher(type_annotation)
+ assert key == expected_key, f"Expected '{expected_key}' but got '{key}'"
+
+ if isinstance(expected_subtypes, tuple):
+ assert (
+ tuple(subtypes) == expected_subtypes
+ ), f"Expected {expected_subtypes} but got {subtypes}"
+ else:
+ assert subtypes == expected_subtypes, f"Expected {expected_subtypes}
but got {subtypes}"
+
+
if __name__ == "__main__":
tvm.testing.main()