This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s2
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 62c719a12a51a2c460e2a16f5274a9c267110395
Author: tqchen <[email protected]>
AuthorDate: Mon Apr 21 20:18:46 2025 -0400

    Fix lower call packed
---
 .../test_tir_transform_lower_tvm_builtin.py        | 194 ++++++++-------------
 .../test_tir_transform_make_packed_api.py          |   2 +
 2 files changed, 79 insertions(+), 117 deletions(-)

diff --git a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
index 89e8b9e350..c63d2f8a41 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
+++ b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
@@ -18,7 +18,7 @@
 import tvm
 import tvm.testing
 
-from tvm import te
+from tvm.script import ir as I
 from tvm.script import tir as T
 
 import numpy as np
@@ -31,123 +31,83 @@ def my_matmul(a, b, c):
     c.copyfrom(np.dot(a.numpy(), b.numpy()))
 
 
-def check_packed_func(target="llvm"):
-    ib = tvm.tir.ir_builder.create()
-
-    m = n = k = 16
-
-    #
-    # Prepare buffer for a, b and c:
-    #
-    a = te.placeholder((m, k), name="a", dtype="float64")
-    b = te.placeholder((k, n), name="b", dtype="float64")
-    k = te.reduce_axis((0, k), name="k")
-    c = te.compute((m, n), lambda i, j: te.sum(a[i, k] * b[k, j], axis=k), name="c")
-
-    a_buffer = tvm.tir.decl_buffer(
-        a.shape, a.dtype, name="a_buffer", offset_factor=1, strides=[te.var("s1"), 1]
-    )
-    b_buffer = tvm.tir.decl_buffer(
-        b.shape, b.dtype, name="b_buffer", offset_factor=1, strides=[te.var("s2"), 1]
-    )
-    c_buffer = tvm.tir.decl_buffer(
-        c.shape, c.dtype, name="c_buffer", offset_factor=1, strides=[te.var("s3"), 1]
-    )
-
-    with ib.for_range(0, 10, "i", kind="parallel"):
-        ib.emit(tvm.tir.call_packed("tvm.test_matmul", a_buffer, b_buffer, c_buffer))
-
-    stmt = ib.get()
-
-    # Construct a valid IRModule to be lowered:
-    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([a_buffer, b_buffer, c_buffer], stmt))
-
-    target = tvm.target.Target(target, host="llvm")
-    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
-    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", "main"))(mod)
-    mod = tvm.tir.transform.MakePackedAPI()(mod)
-
-    # Do the lowering:
-    mod = tvm.tir.transform.LowerTVMBuiltin()(mod)
-
-    # Get the PrimFunc from module:
-    prim_func = mod.functions.items()[0][1]
-
-    node = prim_func.body
-
-    # Recursively visit PrimFunc until we meet the for-loop:
-    while True:
-        if isinstance(
-            node, (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt, tvm.tir.DeclBuffer)
+def test_lower_call_packed():
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(
+            A: T.Buffer((64, 64), "float32"),
+            B: T.Buffer((64, 64), "float32"),
+            C: T.Buffer((64, 64), "float32"),
+        ):
+            T.func_attr({"target": tvm.target.Target("llvm")})
+            T.attr("", "device_id", T.int32(0))
+            T.call_packed("tvm.test_matmul", A, B, C)
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def main(
+            A: T.Buffer((64, 64), "float32"),
+            B: T.Buffer((64, 64), "float32"),
+            C: T.Buffer((64, 64), "float32"),
         ):
-            node = node.body
-        elif isinstance(node, tvm.tir.SeqStmt):
-            node = node[0]
-        else:
-            break
-
-    # For-loop:
-    assert isinstance(node, tvm.tir.stmt.For)
-
-    #
-    # let stack_tcode = tir.tvm_stack_alloca("arg_tcode", 4)
-    #
-    alloca_tcode = node.body
-    assert isinstance(alloca_tcode, tvm.tir.LetStmt)
-
-    expected_value = tvm.tir.call_intrin(
-        "handle", tvm.ir.Op.get("tir.tvm_stack_alloca"), "arg_tcode", 4
-    )
-    expected_var = alloca_tcode.var
-    expected_stmt = tvm.tir.LetStmt(expected_var, expected_value, alloca_tcode.body)
-
-    tvm.ir.assert_structural_equal(alloca_tcode, expected_stmt, map_free_vars=True)
-
-    #
-    # let stack_value = tir.tvm_stack_alloca("arg_value", 4)
-    #
-    alloca_value = alloca_tcode.body.body
-    assert isinstance(alloca_value, tvm.tir.LetStmt)
-
-    expected_value = tvm.tir.call_intrin(
-        "handle", tvm.ir.Op.get("tir.tvm_stack_alloca"), "arg_value", 4
-    )
-    expected_var = alloca_value.var
-    expected_stmt = tvm.tir.LetStmt(expected_var, expected_value, alloca_value.body)
-
-    tvm.ir.assert_structural_equal(alloca_value, expected_stmt, map_free_vars=True)
-
-    #
-    # let stack_array = tir.tvm_stack_alloca("array", 3)
-    #
-    alloca_array = alloca_value.body
-    assert isinstance(alloca_array, tvm.tir.LetStmt)
-
-    expected_value = tvm.tir.call_intrin(
-        "handle", tvm.ir.Op.get("tir.tvm_stack_alloca"), "array", 3
-    )
-    expected_var = alloca_array.var
-    expected_stmt = tvm.tir.LetStmt(expected_var, expected_value, alloca_array.body)
-
-    tvm.ir.assert_structural_equal(alloca_array, expected_stmt, map_free_vars=True)
-
-    #
-    # let stack_shape = tir.tvm_stack_alloca("shape", 12)
-    #
-    alloca_shape = alloca_array.body
-    assert isinstance(alloca_shape, tvm.tir.LetStmt)
-
-    expected_value = tvm.tir.call_intrin(
-        "handle", tvm.ir.Op.get("tir.tvm_stack_alloca"), "shape", 12
-    )
-    expected_var = alloca_shape.var
-    expected_stmt = tvm.tir.LetStmt(expected_var, expected_value, alloca_shape.body)
-
-    tvm.ir.assert_structural_equal(alloca_shape, expected_stmt, map_free_vars=True)
-
-
-def test_lower_packed_func():
-    check_packed_func("llvm")
+            T.func_attr({"target": tvm.target.Target("llvm")})
+            stack_ffi_any: T.handle = T.tvm_stack_alloca("tvm_ffi_any", 4)
+            stack_array: T.handle = T.tvm_stack_alloca("array", 3)
+            stack_shape: T.handle("int64") = T.tvm_stack_alloca("shape", 6)
+            stack_shape_1 = T.decl_buffer((T.int64(6),), "int64", data=stack_shape)
+            stack_shape_1[0] = T.int64(64)
+            stack_shape_1[1] = T.int64(64)
+            T.tvm_struct_set(stack_array, 0, 1, A.data)
+            stack_shape_2 = T.Buffer((1,), "int64", data=stack_shape)
+            T.tvm_struct_set(stack_array, 0, 2, T.address_of(stack_shape_2[0]))
+            T.tvm_struct_set(stack_array, 0, 3, T.reinterpret("handle", T.uint64(0)))
+            T.tvm_struct_set(stack_array, 0, 4, 2)
+            T.tvm_struct_set(stack_array, 0, 5, T.uint8(2))
+            T.tvm_struct_set(stack_array, 0, 6, T.uint8(32))
+            T.tvm_struct_set(stack_array, 0, 7, T.uint16(1))
+            T.tvm_struct_set(stack_array, 0, 8, T.uint64(0))
+            T.tvm_struct_set(stack_array, 0, 9, 0)
+            T.tvm_struct_set(stack_array, 0, 10, 1)
+            stack_shape_1[2] = T.int64(64)
+            stack_shape_1[3] = T.int64(64)
+            T.tvm_struct_set(stack_array, 1, 1, B.data)
+            stack_shape_3 = T.Buffer((3,), "int64", data=stack_shape)
+            T.tvm_struct_set(stack_array, 1, 2, T.address_of(stack_shape_3[2]))
+            T.tvm_struct_set(stack_array, 1, 3, T.reinterpret("handle", T.uint64(0)))
+            T.tvm_struct_set(stack_array, 1, 4, 2)
+            T.tvm_struct_set(stack_array, 1, 5, T.uint8(2))
+            T.tvm_struct_set(stack_array, 1, 6, T.uint8(32))
+            T.tvm_struct_set(stack_array, 1, 7, T.uint16(1))
+            T.tvm_struct_set(stack_array, 1, 8, T.uint64(0))
+            T.tvm_struct_set(stack_array, 1, 9, 0)
+            T.tvm_struct_set(stack_array, 1, 10, 1)
+            stack_shape_1[4] = T.int64(64)
+            stack_shape_1[5] = T.int64(64)
+            T.tvm_struct_set(stack_array, 2, 1, C.data)
+            stack_shape_4 = T.Buffer((5,), "int64", data=stack_shape)
+            T.tvm_struct_set(stack_array, 2, 2, T.address_of(stack_shape_4[4]))
+            T.tvm_struct_set(stack_array, 2, 3, T.reinterpret("handle", T.uint64(0)))
+            T.tvm_struct_set(stack_array, 2, 4, 2)
+            T.tvm_struct_set(stack_array, 2, 5, T.uint8(2))
+            T.tvm_struct_set(stack_array, 2, 6, T.uint8(32))
+            T.tvm_struct_set(stack_array, 2, 7, T.uint16(1))
+            T.tvm_struct_set(stack_array, 2, 8, T.uint64(0))
+            T.tvm_struct_set(stack_array, 2, 9, 0)
+            T.tvm_struct_set(stack_array, 2, 10, 1)
+            T.tvm_struct_set(stack_ffi_any, 0, 13, 7)
+            T.tvm_struct_set(stack_ffi_any, 0, 14, T.tvm_struct_get(stack_array, 0, 0, "handle"))
+            T.tvm_struct_set(stack_ffi_any, 1, 13, 7)
+            T.tvm_struct_set(stack_ffi_any, 1, 14, T.tvm_struct_get(stack_array, 1, 0, "handle"))
+            T.tvm_struct_set(stack_ffi_any, 2, 13, 7)
+            T.tvm_struct_set(stack_ffi_any, 2, 14, T.tvm_struct_get(stack_array, 2, 0, "handle"))
+            T.tvm_struct_set(stack_ffi_any, 3, 13, 0)
+            T.tvm_struct_set(stack_ffi_any, 3, 14, T.int64(0))
+            T.call_packed_lowered("tvm.test_matmul", stack_ffi_any, 0, 3)
+
+    After = tvm.tir.transform.LowerTVMBuiltin()(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
 
 
 @tvm.testing.requires_llvm
diff --git a/tests/python/tir-transform/test_tir_transform_make_packed_api.py b/tests/python/tir-transform/test_tir_transform_make_packed_api.py
index cbd5f0b3e5..fdc67c76f9 100644
--- a/tests/python/tir-transform/test_tir_transform_make_packed_api.py
+++ b/tests/python/tir-transform/test_tir_transform_make_packed_api.py
@@ -60,6 +60,7 @@ def _find_compute_scope(func):
 
     return result
 
+
 @pytest.mark.parametrize("use_global_symbol", [True, False])
 def test_no_op_when_global_symbol_is_absent(use_global_symbol):
     func_attr = {"target": tvm.target.Target("llvm", host="llvm")}
@@ -272,6 +273,7 @@ def test_zero_arg_function():
                 T.tvm_struct_set(result, 0, 14, T.Cast("int64", T.int64(42)))
                 return 0
             return 0
+
     After = tvm.tir.transform.MakePackedAPI()(Before)
     tvm.ir.assert_structural_equal(Expected, After)
 

Reply via email to