This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s2 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 62c719a12a51a2c460e2a16f5274a9c267110395 Author: tqchen <[email protected]> AuthorDate: Mon Apr 21 20:18:46 2025 -0400 Fix lower call packed --- .../test_tir_transform_lower_tvm_builtin.py | 194 ++++++++------------- .../test_tir_transform_make_packed_api.py | 2 + 2 files changed, 79 insertions(+), 117 deletions(-) diff --git a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py index 89e8b9e350..c63d2f8a41 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py @@ -18,7 +18,7 @@ import tvm import tvm.testing -from tvm import te +from tvm.script import ir as I from tvm.script import tir as T import numpy as np @@ -31,123 +31,83 @@ def my_matmul(a, b, c): c.copyfrom(np.dot(a.numpy(), b.numpy())) -def check_packed_func(target="llvm"): - ib = tvm.tir.ir_builder.create() - - m = n = k = 16 - - # - # Prepare buffer for a, b and c: - # - a = te.placeholder((m, k), name="a", dtype="float64") - b = te.placeholder((k, n), name="b", dtype="float64") - k = te.reduce_axis((0, k), name="k") - c = te.compute((m, n), lambda i, j: te.sum(a[i, k] * b[k, j], axis=k), name="c") - - a_buffer = tvm.tir.decl_buffer( - a.shape, a.dtype, name="a_buffer", offset_factor=1, strides=[te.var("s1"), 1] - ) - b_buffer = tvm.tir.decl_buffer( - b.shape, b.dtype, name="b_buffer", offset_factor=1, strides=[te.var("s2"), 1] - ) - c_buffer = tvm.tir.decl_buffer( - c.shape, c.dtype, name="c_buffer", offset_factor=1, strides=[te.var("s3"), 1] - ) - - with ib.for_range(0, 10, "i", kind="parallel"): - ib.emit(tvm.tir.call_packed("tvm.test_matmul", a_buffer, b_buffer, c_buffer)) - - stmt = ib.get() - - # Construct a valid IRModule to be lowered: - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([a_buffer, b_buffer, c_buffer], stmt)) - - target = tvm.target.Target(target, host="llvm") - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod) - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod) - mod = tvm.tir.transform.MakePackedAPI()(mod) - - # Do the lowering: - mod = tvm.tir.transform.LowerTVMBuiltin()(mod) - - # Get the PrimFunc from module: - prim_func = mod.functions.items()[0][1] - - node = prim_func.body - - # Recursively visit PrimFunc until we meet the for-loop: - while True: - if isinstance( - node, (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt, tvm.tir.DeclBuffer) +def test_lower_call_packed(): + @I.ir_module + class Before: + @T.prim_func + def main( + A: T.Buffer((64, 64), "float32"), + B: T.Buffer((64, 64), "float32"), + C: T.Buffer((64, 64), "float32"), + ): + T.func_attr({"target": tvm.target.Target("llvm")}) + T.attr("", "device_id", T.int32(0)) + T.call_packed("tvm.test_matmul", A, B, C) + + @I.ir_module + class Expected: + @T.prim_func + def main( + A: T.Buffer((64, 64), "float32"), + B: T.Buffer((64, 64), "float32"), + C: T.Buffer((64, 64), "float32"), ): - node = node.body - elif isinstance(node, tvm.tir.SeqStmt): - node = node[0] - else: - break - - # For-loop: - assert isinstance(node, tvm.tir.stmt.For) - - # - # let stack_tcode = tir.tvm_stack_alloca("arg_tcode", 4) - # - alloca_tcode = node.body - assert isinstance(alloca_tcode, tvm.tir.LetStmt) - - expected_value = tvm.tir.call_intrin( - "handle", tvm.ir.Op.get("tir.tvm_stack_alloca"), "arg_tcode", 4 - ) - expected_var = alloca_tcode.var - expected_stmt = tvm.tir.LetStmt(expected_var, expected_value, alloca_tcode.body) - - tvm.ir.assert_structural_equal(alloca_tcode, expected_stmt, map_free_vars=True) - - # - # let stack_value = tir.tvm_stack_alloca("arg_value", 4) - # - alloca_value = alloca_tcode.body.body - assert isinstance(alloca_value, tvm.tir.LetStmt) - - expected_value = tvm.tir.call_intrin( - "handle", tvm.ir.Op.get("tir.tvm_stack_alloca"), "arg_value", 4 - ) - expected_var = alloca_value.var - expected_stmt = tvm.tir.LetStmt(expected_var, expected_value, alloca_value.body) - - tvm.ir.assert_structural_equal(alloca_value, expected_stmt, map_free_vars=True) - - # - # let stack_array = tir.tvm_stack_alloca("array", 3) - # - alloca_array = alloca_value.body - assert isinstance(alloca_array, tvm.tir.LetStmt) - - expected_value = tvm.tir.call_intrin( - "handle", tvm.ir.Op.get("tir.tvm_stack_alloca"), "array", 3 - ) - expected_var = alloca_array.var - expected_stmt = tvm.tir.LetStmt(expected_var, expected_value, alloca_array.body) - - tvm.ir.assert_structural_equal(alloca_array, expected_stmt, map_free_vars=True) - - # - # let stack_shape = tir.tvm_stack_alloca("shape", 12) - # - alloca_shape = alloca_array.body - assert isinstance(alloca_shape, tvm.tir.LetStmt) - - expected_value = tvm.tir.call_intrin( - "handle", tvm.ir.Op.get("tir.tvm_stack_alloca"), "shape", 12 - ) - expected_var = alloca_shape.var - expected_stmt = tvm.tir.LetStmt(expected_var, expected_value, alloca_shape.body) - - tvm.ir.assert_structural_equal(alloca_shape, expected_stmt, map_free_vars=True) - - -def test_lower_packed_func(): - check_packed_func("llvm") + T.func_attr({"target": tvm.target.Target("llvm")}) + stack_ffi_any: T.handle = T.tvm_stack_alloca("tvm_ffi_any", 4) + stack_array: T.handle = T.tvm_stack_alloca("array", 3) + stack_shape: T.handle("int64") = T.tvm_stack_alloca("shape", 6) + stack_shape_1 = T.decl_buffer((T.int64(6),), "int64", data=stack_shape) + stack_shape_1[0] = T.int64(64) + stack_shape_1[1] = T.int64(64) + T.tvm_struct_set(stack_array, 0, 1, A.data) + stack_shape_2 = T.Buffer((1,), "int64", data=stack_shape) + T.tvm_struct_set(stack_array, 0, 2, T.address_of(stack_shape_2[0])) + T.tvm_struct_set(stack_array, 0, 3, T.reinterpret("handle", T.uint64(0))) + T.tvm_struct_set(stack_array, 0, 4, 2) + T.tvm_struct_set(stack_array, 0, 5, T.uint8(2)) + T.tvm_struct_set(stack_array, 0, 6, T.uint8(32)) + T.tvm_struct_set(stack_array, 0, 7, T.uint16(1)) + T.tvm_struct_set(stack_array, 0, 8, T.uint64(0)) + T.tvm_struct_set(stack_array, 0, 9, 0) + T.tvm_struct_set(stack_array, 0, 10, 1) + stack_shape_1[2] = T.int64(64) + stack_shape_1[3] = T.int64(64) + T.tvm_struct_set(stack_array, 1, 1, B.data) + stack_shape_3 = T.Buffer((3,), "int64", data=stack_shape) + T.tvm_struct_set(stack_array, 1, 2, T.address_of(stack_shape_3[2])) + T.tvm_struct_set(stack_array, 1, 3, T.reinterpret("handle", T.uint64(0))) + T.tvm_struct_set(stack_array, 1, 4, 2) + T.tvm_struct_set(stack_array, 1, 5, T.uint8(2)) + T.tvm_struct_set(stack_array, 1, 6, T.uint8(32)) + T.tvm_struct_set(stack_array, 1, 7, T.uint16(1)) + T.tvm_struct_set(stack_array, 1, 8, T.uint64(0)) + T.tvm_struct_set(stack_array, 1, 9, 0) + T.tvm_struct_set(stack_array, 1, 10, 1) + stack_shape_1[4] = T.int64(64) + stack_shape_1[5] = T.int64(64) + T.tvm_struct_set(stack_array, 2, 1, C.data) + stack_shape_4 = T.Buffer((5,), "int64", data=stack_shape) + T.tvm_struct_set(stack_array, 2, 2, T.address_of(stack_shape_4[4])) + T.tvm_struct_set(stack_array, 2, 3, T.reinterpret("handle", T.uint64(0))) + T.tvm_struct_set(stack_array, 2, 4, 2) + T.tvm_struct_set(stack_array, 2, 5, T.uint8(2)) + T.tvm_struct_set(stack_array, 2, 6, T.uint8(32)) + T.tvm_struct_set(stack_array, 2, 7, T.uint16(1)) + T.tvm_struct_set(stack_array, 2, 8, T.uint64(0)) + T.tvm_struct_set(stack_array, 2, 9, 0) + T.tvm_struct_set(stack_array, 2, 10, 1) + T.tvm_struct_set(stack_ffi_any, 0, 13, 7) + T.tvm_struct_set(stack_ffi_any, 0, 14, T.tvm_struct_get(stack_array, 0, 0, "handle")) + T.tvm_struct_set(stack_ffi_any, 1, 13, 7) + T.tvm_struct_set(stack_ffi_any, 1, 14, T.tvm_struct_get(stack_array, 1, 0, "handle")) + T.tvm_struct_set(stack_ffi_any, 2, 13, 7) + T.tvm_struct_set(stack_ffi_any, 2, 14, T.tvm_struct_get(stack_array, 2, 0, "handle")) + T.tvm_struct_set(stack_ffi_any, 3, 13, 0) + T.tvm_struct_set(stack_ffi_any, 3, 14, T.int64(0)) + T.call_packed_lowered("tvm.test_matmul", stack_ffi_any, 0, 3) + + After = tvm.tir.transform.LowerTVMBuiltin()(Before) + tvm.ir.assert_structural_equal(After, Expected) @tvm.testing.requires_llvm diff --git a/tests/python/tir-transform/test_tir_transform_make_packed_api.py b/tests/python/tir-transform/test_tir_transform_make_packed_api.py index cbd5f0b3e5..fdc67c76f9 100644 --- a/tests/python/tir-transform/test_tir_transform_make_packed_api.py +++ b/tests/python/tir-transform/test_tir_transform_make_packed_api.py @@ -60,6 +60,7 @@ def _find_compute_scope(func): return result + @pytest.mark.parametrize("use_global_symbol", [True, False]) def test_no_op_when_global_symbol_is_absent(use_global_symbol): func_attr = {"target": tvm.target.Target("llvm", host="llvm")} @@ -272,6 +273,7 @@ def test_zero_arg_function(): T.tvm_struct_set(result, 0, 14, T.Cast("int64", T.int64(42))) return 0 return 0 + After = tvm.tir.transform.MakePackedAPI()(Before) tvm.ir.assert_structural_equal(Expected, After)
