This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c265cdae97 [TVMScript] Add `__name__` attr for parsed PrimFunc and
IRModule (#14786)
c265cdae97 is described below
commit c265cdae97a9449f6fae2d26db79088aec11e8cc
Author: Siyuan Feng <[email protected]>
AuthorDate: Sun May 7 06:29:46 2023 +0800
[TVMScript] Add `__name__` attr for parsed PrimFunc and IRModule (#14786)
This PR adds a `__name__` attribute to parsed PrimFunc and IRModule objects,
set to the name of the Python function or class they were parsed from.
---
python/tvm/script/parser/ir/entry.py | 4 +++-
python/tvm/script/parser/tir/entry.py | 4 +++-
tests/python/unittest/test_tvmscript_parser_ir.py | 1 +
tests/python/unittest/test_tvmscript_parser_tir.py | 16 ++++++++++++++--
4 files changed, 21 insertions(+), 4 deletions(-)
diff --git a/python/tvm/script/parser/ir/entry.py
b/python/tvm/script/parser/ir/entry.py
index 94fc3d2e2c..5878a1ce55 100644
--- a/python/tvm/script/parser/ir/entry.py
+++ b/python/tvm/script/parser/ir/entry.py
@@ -40,7 +40,9 @@ def ir_module(mod: Type) -> IRModule:
if not inspect.isclass(mod):
raise TypeError(f"Expect a class, but got: {mod}")
- return parse(mod, utils.inspect_class_capture(mod))
+ m = parse(mod, utils.inspect_class_capture(mod))
+ setattr(m, "__name__", mod.__name__)
+ return m
setattr(ir_module, "dispatch_token", "ir")
diff --git a/python/tvm/script/parser/tir/entry.py
b/python/tvm/script/parser/tir/entry.py
index 649f817411..d5bff7a856 100644
--- a/python/tvm/script/parser/tir/entry.py
+++ b/python/tvm/script/parser/tir/entry.py
@@ -42,7 +42,9 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
raise TypeError(f"Expect a function, but got: {func}")
if utils.is_defined_in_class(inspect.stack(), func):
return func
- return parse(func, utils.inspect_function_capture(func))
+ f = parse(func, utils.inspect_function_capture(func))
+ setattr(f, "__name__", func.__name__)
+ return f
setattr(prim_func, "dispatch_token", "tir")
diff --git a/tests/python/unittest/test_tvmscript_parser_ir.py
b/tests/python/unittest/test_tvmscript_parser_ir.py
index d3e758fbe1..d33594794f 100644
--- a/tests/python/unittest/test_tvmscript_parser_ir.py
+++ b/tests/python/unittest/test_tvmscript_parser_ir.py
@@ -29,6 +29,7 @@ def test_ir_base():
pass
assert isinstance(BlankIRModule, IRModule) and
len(BlankIRModule.functions.items()) == 0
+ assert BlankIRModule.__name__ == "BlankIRModule"
if __name__ == "__main__":
diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py
b/tests/python/unittest/test_tvmscript_parser_tir.py
index 20be6d1498..31bf5cc101 100644
--- a/tests/python/unittest/test_tvmscript_parser_tir.py
+++ b/tests/python/unittest/test_tvmscript_parser_tir.py
@@ -16,8 +16,6 @@
# under the License.
"""Unittests for tvm.script.parser.tir"""
-import pytest
-import inspect
import tvm.testing
from tvm.script.parser import tir as T
from tvm import ir, tir
@@ -59,5 +57,19 @@ def test_tir_ptr_proxy():
)
+def test_tir_func_name():
+ @T.prim_func
+ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, [128, 128])
+ B = T.match_buffer(b, [128, 128])
+ C = T.match_buffer(c, [128, 128])
+ for i, j, k in T.grid(128, 128, 128):
+ with T.block("update"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+ assert matmul.__name__ == "matmul"
+
+
if __name__ == "__main__":
tvm.testing.main()