This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a64d1f1cc3 [TIR] Make T.reinterpret nop when dtype is the same (#16879)
a64d1f1cc3 is described below

commit a64d1f1cc37da7f202d943c2bea7eb747e624599
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Sun Apr 14 08:21:30 2024 -0700

    [TIR] Make T.reinterpret nop when dtype is the same (#16879)
    
    * [TIR] Make T.reinterpret nop when dtype is the same
    
    * fix scalable vec handling
---
 python/tvm/tir/op.py                               |  4 ++--
 src/tir/op/op.cc                                   |  8 ++++++--
 tests/python/codegen/test_target_codegen_cuda.py   |  2 +-
 .../python/tvmscript/test_tvmscript_parser_tir.py  | 22 ++++++++++++++++++++++
 4 files changed, 31 insertions(+), 5 deletions(-)

diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 8816880e7b..6b72e63f29 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -1789,7 +1789,7 @@ def infinity(dtype: str, span: Optional[Span] = None) -> Any:
     return _ffi_api.infinity(dtype, span)  # type: ignore
 
 
-def reinterpret(dtype, value) -> Any:
+def reinterpret(dtype, value, span: Optional[Span] = None) -> Any:
     """infinity value of dtype
 
     Parameters
@@ -1808,7 +1808,7 @@ def reinterpret(dtype, value) -> Any:
     value : tvm.Expr
         The reinterpret cast value of dtype.
     """
-    return call_intrin(dtype, "tir.reinterpret", value)
+    return _ffi_api.reinterpret(dtype, value, span)  # type: ignore
 
 
 def exp(x):
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 7f47e66062..b613639786 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -409,8 +409,10 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) {
 // reinterpret
 PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) {
   if (value.dtype() == t) return value;
-  ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes())
-      << "Bitcast requires size match " << t << " vs " << value.dtype();
+  if (!t.is_scalable_vector() && !value.dtype().is_scalable_vector()) {
+    ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes())
+        << "Bitcast requires size match " << t << " vs " << value.dtype();
+  }
   return tir::Call(t, tir::builtin::reinterpret(), {value}, span);
 }
 
@@ -1083,6 +1085,8 @@ TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc);
 
 TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast);
 
+TVM_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret);
+
 // operator overloading, smarter than make
 #define REGISTER_MAKE_BINARY_OP(Node, Func)                                                \
   TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index 23ba0fc3ce..112c521d06 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -1120,7 +1120,7 @@ def test_invalid_reinterpret():
     @T.prim_func
     def func(A: T.Buffer((4,), "uint32"), B: T.Buffer((4,), "uint8")) -> None:
         for tx in T.thread_binding(4, "threadIdx.x"):
-            B[tx] = T.reinterpret("uint8", A[tx])
+            B[tx] = T.call_intrin("uint8", "tir.reinterpret", A[tx])
 
     with pytest.raises(tvm.error.TVMError):
         tvm.build(func, target="cuda")
diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py
index 465ffa5cb6..530746a6fc 100644
--- a/tests/python/tvmscript/test_tvmscript_parser_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py
@@ -449,5 +449,27 @@ def test_inferred_sinfo_with_dynamic_buffer():
     tvm.ir.assert_structural_equal(func.struct_info, expected)
 
 
+def test_reinterpret_nop():
+    """Test builtin reinterpret op"""
+
+    @T.prim_func
+    def func(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float32")) -> None:
+        T.func_attr({"global_symbol": "main"})
+        for i in T.serial(0, 32):
+            with T.block():
+                vi = T.axis.remap("S", [i])
+                B[vi] = T.reinterpret("float32", A[vi])
+
+    @T.prim_func
+    def expected(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float32")) -> None:
+        T.func_attr({"global_symbol": "main"})
+        for i in T.serial(0, 32):
+            with T.block():
+                vi = T.axis.remap("S", [i])
+                B[vi] = A[vi]
+
+    tvm.ir.assert_structural_equal(func, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to