This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new a64d1f1cc3 [TIR] Make T.reinterpret nop when dtype is the same (#16879) a64d1f1cc3 is described below commit a64d1f1cc37da7f202d943c2bea7eb747e624599 Author: Wuwei Lin <wu...@apache.org> AuthorDate: Sun Apr 14 08:21:30 2024 -0700 [TIR] Make T.reinterpret nop when dtype is the same (#16879) * [TIR] Make T.reinterpret nop when dtype is the same * fix scalable vec handling --- python/tvm/tir/op.py | 4 ++-- src/tir/op/op.cc | 8 ++++++-- tests/python/codegen/test_target_codegen_cuda.py | 2 +- .../python/tvmscript/test_tvmscript_parser_tir.py | 22 ++++++++++++++++++++++ 4 files changed, 31 insertions(+), 5 deletions(-) diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 8816880e7b..6b72e63f29 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1789,7 +1789,7 @@ def infinity(dtype: str, span: Optional[Span] = None) -> Any: return _ffi_api.infinity(dtype, span) # type: ignore -def reinterpret(dtype, value) -> Any: +def reinterpret(dtype, value, span: Optional[Span] = None) -> Any: """infinity value of dtype Parameters @@ -1808,7 +1808,7 @@ def reinterpret(dtype, value) -> Any: value : tvm.Expr The reinterpret cast value of dtype. """ - return call_intrin(dtype, "tir.reinterpret", value) + return _ffi_api.reinterpret(dtype, value, span) # type: ignore def exp(x): diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 7f47e66062..b613639786 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -409,8 +409,10 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) { // reinterpret PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) { if (value.dtype() == t) return value; - ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes()) - << "Bitcast requires size match " << t << " vs " << value.dtype(); + if (!t.is_scalable_vector() && !value.dtype().is_scalable_vector()) { + ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes()) + << "Bitcast requires size match " << t << " vs " << value.dtype(); + } return tir::Call(t, tir::builtin::reinterpret(), {value}, span); } @@ -1083,6 +1085,8 @@ TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc); TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast); +TVM_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret); + // operator overloading, smarter than make #define REGISTER_MAKE_BINARY_OP(Node, Func) \ TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \ diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py index 23ba0fc3ce..112c521d06 100644 --- a/tests/python/codegen/test_target_codegen_cuda.py +++ b/tests/python/codegen/test_target_codegen_cuda.py @@ -1120,7 +1120,7 @@ def test_invalid_reinterpret(): @T.prim_func def func(A: T.Buffer((4,), "uint32"), B: T.Buffer((4,), "uint8")) -> None: for tx in T.thread_binding(4, "threadIdx.x"): - B[tx] = T.reinterpret("uint8", A[tx]) + B[tx] = T.call_intrin("uint8", "tir.reinterpret", A[tx]) with pytest.raises(tvm.error.TVMError): tvm.build(func, target="cuda") diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index 465ffa5cb6..530746a6fc 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -449,5 +449,27 @@ def test_inferred_sinfo_with_dynamic_buffer(): tvm.ir.assert_structural_equal(func.struct_info, expected) +def test_reinterpret_nop(): + """Test builtin reinterpret op""" + + @T.prim_func + def func(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float32")) -> None: + T.func_attr({"global_symbol": "main"}) + for i in T.serial(0, 32): + with T.block(): + vi = T.axis.remap("S", [i]) + B[vi] = T.reinterpret("float32", A[vi]) + + @T.prim_func + def expected(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float32")) -> None: + T.func_attr({"global_symbol": "main"}) + for i in T.serial(0, 32): + with T.block(): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + + tvm.ir.assert_structural_equal(func, expected) + + if __name__ == "__main__": tvm.testing.main()