This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 61f80814e6 [TVMScript] Fix PEP 563 closure variable resolution (#18856)
61f80814e6 is described below
commit 61f80814e655c97be48e47cd19b09ea5a8636f4b
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Feb 28 22:38:50 2026 -0500
[TVMScript] Fix PEP 563 closure variable resolution (#18856)
With `from __future__ import annotations`, Python stores annotations as
strings
and does not capture annotation-only variables in `__closure__`. This
broke
TVMScript when buffer shapes/dtypes referenced closure variables.
Fix: wrap `extra_vars` in a `collections.ChainMap` with snapshots of all
live
caller-frame locals (from `inspect.stack()`) as fallback layers in both
`tir/entry.py` (`prim_func`) and `ir/entry.py` (`ir_module`). The
`ir_module`
function now also captures `outer_stack = inspect.stack()` at its entry
point,
mirroring the existing pattern in `prim_func`. Lookup falls back to
frame locals
only on cache miss, preserving existing behavior for non-PEP-563 code.
Add `tests/python/tvmscript/test_tvmscript_pep563_closure.py` (requires
`from __future__ import annotations` at the top) covering closure
variables in
buffer shapes, dtypes, nested scopes, ir_module, and mixed
annotation+body use.
---
python/tvm/script/parser/core/utils.py | 82 +++++++++++
python/tvm/script/parser/ir/entry.py | 8 +-
python/tvm/script/parser/tir/entry.py | 4 +-
.../tvmscript/test_tvmscript_pep563_closure.py | 158 +++++++++++++++++++++
4 files changed, 250 insertions(+), 2 deletions(-)
diff --git a/python/tvm/script/parser/core/utils.py
b/python/tvm/script/parser/core/utils.py
index 85190b96d9..fc8a928e05 100644
--- a/python/tvm/script/parser/core/utils.py
+++ b/python/tvm/script/parser/core/utils.py
@@ -89,6 +89,88 @@ def inspect_class_capture(cls: type) -> dict[str, Any]:
return result
+def _collect_annotation_names(source_obj: type | Callable) -> set[str]:
+ """Parse source AST to find names used in function annotations.
+
+ Returns the set of ``ast.Name`` identifiers found inside argument
+ annotations and return annotations of any function definitions in
+ *source_obj*.
+ """
+ import ast
+ import textwrap
+
+ try:
+ source = textwrap.dedent(inspect.getsource(source_obj))
+ tree = ast.parse(source)
+ except (OSError, TypeError):
+ return set()
+
+ names: set[str] = set()
+ for node in ast.walk(tree):
+ if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef):
+ for arg in node.args.args + node.args.posonlyargs +
node.args.kwonlyargs:
+ if arg.annotation:
+ for n in ast.walk(arg.annotation):
+ if isinstance(n, ast.Name):
+ names.add(n.id)
+ if node.returns:
+ for n in ast.walk(node.returns):
+ if isinstance(n, ast.Name):
+ names.add(n.id)
+ return names
+
+
+def _has_string_annotations(source_obj: type | Callable) -> bool:
+ """Check if *source_obj* has stringified annotations (PEP 563)."""
+ if inspect.isclass(source_obj):
+ return any(
+ isinstance(a, str)
+ for v in source_obj.__dict__.values()
+ if inspect.isfunction(v)
+ for a in v.__annotations__.values()
+ )
+ return any(isinstance(a, str) for a in getattr(source_obj,
"__annotations__", {}).values())
+
+
+def _get_enclosing_scope_names(qualname: str) -> set[str]:
+ """Extract lexically enclosing scope names from ``__qualname__``.
+
+ For ``outer.<locals>.inner.<locals>.func`` this returns ``{"outer",
"inner"}``.
+ """
+ parts = qualname.split(".")
+ return {p for p in parts[:-1] if p != "<locals>"}
+
+
+def resolve_closure_vars(
+ source_obj: type | Callable, extra_vars: dict[str, Any], outer_stack: list
+) -> None:
+ """Resolve closure variables hidden by PEP 563.
+
+ With ``from __future__ import annotations``, variables used only in
+ annotations are not captured in ``__closure__``. This function parses
+ the source AST to find names used in function annotations, then looks
+ them up in lexically enclosing scope frames identified via
+ ``__qualname__``.
+
+ Only triggered when annotations are actually strings (PEP 563 active).
+ Only annotation-referenced names are added, and only from enclosing
+ scopes — not from arbitrary caller frames.
+
+ Works for both classes (``@I.ir_module``) and functions (``@T.prim_func``).
+ """
+ if not _has_string_annotations(source_obj):
+ return
+ ann_names = _collect_annotation_names(source_obj)
+ enclosing = _get_enclosing_scope_names(source_obj.__qualname__)
+ for name in ann_names:
+ if name not in extra_vars:
+ for frame_info in outer_stack[1:]:
+ if frame_info.frame.f_code.co_name in enclosing:
+ if name in frame_info.frame.f_locals:
+ extra_vars[name] = frame_info.frame.f_locals[name]
+ break
+
+
def is_defined_in_class(frames: list[FrameType], obj: Any) -> bool:
"""Check whether a object is defined in a class scope.
diff --git a/python/tvm/script/parser/ir/entry.py
b/python/tvm/script/parser/ir/entry.py
index 8f7a5be663..b0685e3db0 100644
--- a/python/tvm/script/parser/ir/entry.py
+++ b/python/tvm/script/parser/ir/entry.py
@@ -46,6 +46,9 @@ def ir_module(mod: type | None = None, check_well_formed:
bool = True) -> IRModu
The parsed ir module.
"""
+ # Capture stack outside wrapper (wrapper adds to the stack)
+ outer_stack = inspect.stack()
+
def decorator_wrapper(mod):
if not inspect.isclass(mod):
raise TypeError(f"Expect a class, but got: {mod}")
@@ -53,7 +56,10 @@ def ir_module(mod: type | None = None, check_well_formed:
bool = True) -> IRModu
# Check BasePyModule inheritance
base_py_module_inherited = any(base.__name__ == "BasePyModule" for
base in mod.__bases__)
- m = parse(mod, utils.inspect_class_capture(mod),
check_well_formed=check_well_formed)
+ extra_vars = utils.inspect_class_capture(mod)
+ # Resolve closure variables hidden by PEP 563 (annotation-only names)
+ utils.resolve_closure_vars(mod, extra_vars, outer_stack)
+ m = parse(mod, extra_vars, check_well_formed=check_well_formed)
if base_py_module_inherited:
# Lazy import: tvm.relax cannot be imported at module level in
tvm.script.parser
diff --git a/python/tvm/script/parser/tir/entry.py
b/python/tvm/script/parser/tir/entry.py
index da09851e67..d0486b0d9f 100644
--- a/python/tvm/script/parser/tir/entry.py
+++ b/python/tvm/script/parser/tir/entry.py
@@ -63,7 +63,9 @@ def prim_func(
raise TypeError(f"Expect a function, but got: {func}")
if utils.is_defined_in_class(outer_stack, func):
return func
- f = parse(func, utils.inspect_function_capture(func),
check_well_formed=check_well_formed)
+ extra_vars = utils.inspect_function_capture(func)
+ utils.resolve_closure_vars(func, extra_vars, outer_stack)
+ f = parse(func, extra_vars, check_well_formed=check_well_formed)
setattr(f, "__name__", func.__name__)
return f
diff --git a/tests/python/tvmscript/test_tvmscript_pep563_closure.py
b/tests/python/tvmscript/test_tvmscript_pep563_closure.py
new file mode 100644
index 0000000000..a5d26d7f16
--- /dev/null
+++ b/tests/python/tvmscript/test_tvmscript_pep563_closure.py
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test TVMScript with PEP 563 (from __future__ import annotations).
+
+IMPORTANT: The `from __future__ import annotations` import below is the
+test condition itself, because we need to test compatibility with it.
+"""
+
+from __future__ import annotations
+
+import tvm
+import tvm.testing
+from tvm.script import ir as I
+from tvm.script import tir as T
+
+
+def _normalize(func):
+ """Strip the global_symbol so function names do not affect structural
equality."""
+ return func.with_attr("global_symbol", "")
+
+
+def test_prim_func_closure_shape():
+    """Closure variable used in Buffer shape annotation.
+
+    ``M`` appears only in the parameter annotation of ``func``. Under PEP 563
+    it is therefore not captured in ``func.__closure__``, and the parser must
+    recover it from the enclosing ``f`` frame, which is live while the
+    decorator runs.
+    """
+
+    def f(M=16):
+        @T.prim_func
+        def func(A: T.Buffer((M,), "float32")):
+            T.evaluate(0)
+
+        return func
+
+    @T.prim_func
+    def expected_16(A: T.Buffer((16,), "float32")):
+        T.evaluate(0)
+
+    @T.prim_func
+    def expected_32(A: T.Buffer((32,), "float32")):
+        T.evaluate(0)
+
+    # Two different values: the shape must follow M, not a fixed constant.
+    tvm.ir.assert_structural_equal(_normalize(f(16)), _normalize(expected_16))
+    tvm.ir.assert_structural_equal(_normalize(f(32)), _normalize(expected_32))
+
+
+def test_prim_func_closure_dtype():
+    """Closure variable used as Buffer dtype.
+
+    ``dtype`` is referenced only in the annotation of ``func``, so it must be
+    recovered from the enclosing ``f`` frame rather than from ``__closure__``.
+    """
+
+    def f(dtype="float32"):
+        @T.prim_func
+        def func(A: T.Buffer((16,), dtype)):
+            T.evaluate(0)
+
+        return func
+
+    @T.prim_func
+    def expected_f32(A: T.Buffer((16,), "float32")):
+        T.evaluate(0)
+
+    @T.prim_func
+    def expected_f16(A: T.Buffer((16,), "float16")):
+        T.evaluate(0)
+
+    # Both the default and a non-default dtype must round-trip.
+    tvm.ir.assert_structural_equal(_normalize(f("float32")), _normalize(expected_f32))
+    tvm.ir.assert_structural_equal(_normalize(f("float16")), _normalize(expected_f16))
+
+
+def test_prim_func_nested_closure():
+    """Variables from several enclosing scopes that are active on the call stack.
+
+    With PEP 563, closure-only variables are missing from __closure__ unless
+    they appear in the function body. The fallback walks the live call stack
+    and reads locals only from frames whose names appear in ``func``'s
+    ``__qualname__`` (here ``outer`` and ``middle``). This works because those
+    enclosing frames are still active when the decorator runs: outer calls
+    middle, which applies the decorator. ``N`` is found in ``middle``'s frame
+    and ``M`` in ``outer``'s.
+    """
+
+    def outer(M=16):
+        def middle(N=8):
+            @T.prim_func
+            def func(A: T.Buffer((M, N), "float32")):
+                T.evaluate(0)
+
+            return func
+
+        return middle()
+
+    @T.prim_func
+    def expected_16_8(A: T.Buffer((16, 8), "float32")):
+        T.evaluate(0)
+
+    @T.prim_func
+    def expected_32_8(A: T.Buffer((32, 8), "float32")):
+        T.evaluate(0)
+
+    tvm.ir.assert_structural_equal(_normalize(outer(16)), _normalize(expected_16_8))
+    tvm.ir.assert_structural_equal(_normalize(outer(32)), _normalize(expected_32_8))
+
+
+def test_ir_module_closure():
+    """Closure variable in @I.ir_module class method.
+
+    ``M`` is referenced only in the annotation of ``Mod.main``. The
+    ``ir_module`` path must scan the class's methods and resolve ``M`` from
+    the enclosing ``f`` frame.
+    """
+
+    def f(M=16):
+        @I.ir_module
+        class Mod:
+            @T.prim_func
+            def main(A: T.Buffer((M,), "float32")):
+                T.evaluate(0)
+
+        return Mod
+
+    @T.prim_func
+    def expected_16(A: T.Buffer((16,), "float32")):
+        T.evaluate(0)
+
+    @T.prim_func
+    def expected_32(A: T.Buffer((32,), "float32")):
+        T.evaluate(0)
+
+    tvm.ir.assert_structural_equal(_normalize(f(16)["main"]), _normalize(expected_16))
+    tvm.ir.assert_structural_equal(_normalize(f(32)["main"]), _normalize(expected_32))
+
+
+def test_mixed_closure_usage():
+    """Closure var used in both annotation AND body -- regression check.
+
+    Because ``M`` is also used in the body, Python captures it in
+    ``__closure__`` as usual. This checks that the annotation fallback does
+    not interfere with the normal capture path.
+    """
+
+    def f(M=16):
+        @T.prim_func
+        def func(A: T.Buffer((M,), "float32")):
+            T.evaluate(M)
+
+        return func
+
+    @T.prim_func
+    def expected_16(A: T.Buffer((16,), "float32")):
+        T.evaluate(16)
+
+    @T.prim_func
+    def expected_32(A: T.Buffer((32,), "float32")):
+        T.evaluate(32)
+
+    tvm.ir.assert_structural_equal(_normalize(f(16)), _normalize(expected_16))
+    tvm.ir.assert_structural_equal(_normalize(f(32)), _normalize(expected_32))
+
+
+if __name__ == "__main__":
+    # Allow running this test file directly as a script.
+    tvm.testing.main()