This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 40dd376375 [Unity][TIR] Clear struct info when specializing PrimFunc (#16584)
40dd376375 is described below
commit 40dd376375b8fdb469477c999b09d8dbb6ba8762
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sat Mar 9 07:13:45 2024 -0600
[Unity][TIR] Clear struct info when specializing PrimFunc (#16584)
In rare cases, a `PrimFunc` may be annotated with `StructInfo`, to
indicate that it is an impure function with specific shapes for the
parameters. If struct info is present, it is invalidated when
specializing a `PrimFunc`, and should be cleared.
---
src/tir/ir/specialize.cc | 2 ++
tests/python/tir-base/test_tir_specialize.py | 38 ++++++++++++++++++++++++++++
2 files changed, 40 insertions(+)
diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc
index 5964f02932..8095b3141f 100644
--- a/src/tir/ir/specialize.cc
+++ b/src/tir/ir/specialize.cc
@@ -109,6 +109,8 @@ class PrimFuncSpecializer : public StmtExprMutator {
f_ptr->params = std::move(params);
f_ptr->buffer_map = std::move(buffer_map);
f_ptr->body = std::move(body);
+ f_ptr->struct_info_ = NullOpt;
+ f_ptr->checked_type_ = Type(nullptr);
}
return f;
}
diff --git a/tests/python/tir-base/test_tir_specialize.py b/tests/python/tir-base/test_tir_specialize.py
index f695b85225..fd2843f743 100644
--- a/tests/python/tir-base/test_tir_specialize.py
+++ b/tests/python/tir-base/test_tir_specialize.py
@@ -16,6 +16,8 @@
# under the License.
# pylint: disable=missing-function-docstring, missing-module-docstring
+import pytest
+
import tvm
from tvm.script import tir as T
from tvm.tir.schedule.testing import assert_structural_equal_ignore_global_symbol
@@ -324,5 +326,41 @@ def test_specialize_buffer_var_to_expr():
tvm.ir.assert_structural_equal(expected, after)
+def test_specialization_removes_struct_info():
+ """Reset struct info in specialization
+
+ While a PrimFunc usually doesn't have a `relax.StructInfo`, the
+ field can be populated in some edge cases. If that PrimFunc is
+ specialized, the struct info should be reset.
+ """
+
+ @T.prim_func(private=True)
+ def before(n: T.int32) -> T.int32:
+ T.ret(n * 10)
+
+ @T.prim_func(private=True)
+ def expected() -> T.int32:
+ T.ret(50)
+
+ sinfo = tvm.relax.FuncStructInfo(
+ [tvm.relax.PrimStructInfo("int32")], tvm.relax.PrimStructInfo("int32")
+ )
+ tvm.relax.expr._update_struct_info(before, sinfo)
+
+ n = before.params[0]
+ param_map = {n: 5}
+ after = before.specialize(param_map)
+
+ tvm.ir.assert_structural_equal(expected, after)
+ assert before.struct_info is not None
+
+ # PrimFuncs do not expose the `struct_info_` field. Checking the
+ # `struct_info` field when it isn't set raises an exception. This
+ # is the desired behavior, since the struct info before
+ # specialization is no longer valid.
+ with pytest.raises(tvm.TVMError):
+ after.struct_info
+
+
if __name__ == "__main__":
tvm.testing.main()