This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 40dd376375 [Unity][TIR] Clear struct info when specializing PrimFunc (#16584)
40dd376375 is described below

commit 40dd376375b8fdb469477c999b09d8dbb6ba8762
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sat Mar 9 07:13:45 2024 -0600

    [Unity][TIR] Clear struct info when specializing PrimFunc (#16584)
    
    In rare cases, a `PrimFunc` may be annotated with `StructInfo` to
    indicate that it is an impure function with specific shapes for its
    parameters.  If struct info is present, specializing the `PrimFunc`
    invalidates it, so it (and the cached `checked_type_`) is now cleared.
---
 src/tir/ir/specialize.cc                     |  2 ++
 tests/python/tir-base/test_tir_specialize.py | 38 ++++++++++++++++++++++++++++
 2 files changed, 40 insertions(+)

diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc
index 5964f02932..8095b3141f 100644
--- a/src/tir/ir/specialize.cc
+++ b/src/tir/ir/specialize.cc
@@ -109,6 +109,8 @@ class PrimFuncSpecializer : public StmtExprMutator {
       f_ptr->params = std::move(params);
       f_ptr->buffer_map = std::move(buffer_map);
       f_ptr->body = std::move(body);
+      f_ptr->struct_info_ = NullOpt;
+      f_ptr->checked_type_ = Type(nullptr);
     }
     return f;
   }
diff --git a/tests/python/tir-base/test_tir_specialize.py b/tests/python/tir-base/test_tir_specialize.py
index f695b85225..fd2843f743 100644
--- a/tests/python/tir-base/test_tir_specialize.py
+++ b/tests/python/tir-base/test_tir_specialize.py
@@ -16,6 +16,8 @@
 # under the License.
 # pylint: disable=missing-function-docstring, missing-module-docstring
 
+import pytest
+
 import tvm
 from tvm.script import tir as T
 from tvm.tir.schedule.testing import 
assert_structural_equal_ignore_global_symbol
@@ -324,5 +326,41 @@ def test_specialize_buffer_var_to_expr():
     tvm.ir.assert_structural_equal(expected, after)
 
 
+def test_specialization_removes_struct_info():
+    """Reset struct info in specialization
+
+    While a PrimFunc usually doesn't have a `relax.StructInfo`, the
+    field can be populated in some edge cases.  If that PrimFunc is
+    specialized, the struct info should be reset.
+    """
+
+    @T.prim_func(private=True)
+    def before(n: T.int32) -> T.int32:
+        T.ret(n * 10)
+
+    @T.prim_func(private=True)
+    def expected() -> T.int32:
+        T.ret(50)
+
+    sinfo = tvm.relax.FuncStructInfo(
+        [tvm.relax.PrimStructInfo("int32")], tvm.relax.PrimStructInfo("int32")
+    )
+    tvm.relax.expr._update_struct_info(before, sinfo)
+
+    n = before.params[0]
+    param_map = {n: 5}
+    after = before.specialize(param_map)
+
+    tvm.ir.assert_structural_equal(expected, after)
+    assert before.struct_info is not None
+
+    # PrimFuncs do not expose the `struct_info_` field.  Checking the
+    # `struct_info` field when it isn't set raises an exception.  This
+    # is the desired behavior, since the struct info before
+    # specialization is no longer valid.
+    with pytest.raises(tvm.TVMError):
+        after.struct_info
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to