masahi commented on code in PR #15318:
URL: https://github.com/apache/tvm/pull/15318#discussion_r1263481377


##########
tests/python/relax/test_codegen_cutlass.py:
##########
@@ -1488,6 +1491,153 @@ def split_transform_deploy_mod(mod):
         tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
+def test_fp16A_int8B_gemm():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def decode(
+            A: T.Buffer((T.int64(64), T.int64(64)), "int8"),
+            B: T.Buffer((T.int64(64),), "float16"),
+            decode_1: T.Buffer((T.int64(64), T.int64(64)), "float16"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i, j in T.grid(T.int64(64), T.int64(64)):
+                with T.block("decode"):
+                    v_i, v_j = T.axis.remap("SS", [i, j])
+                    T.reads(A[v_i, v_j], B[v_j])
+                    T.writes(decode_1[v_i, v_j])
+                    decode_1[v_i, v_j] = T.Cast("float16", A[v_i, v_j]) * 
B[v_j]
+
+        @T.prim_func
+        def encode(
+            A: T.Buffer((T.int64(64), T.int64(64)), "float16"),
+            w_gathered: T.Buffer((T.int64(64), T.int64(64)), "int8"),
+            compute: T.Buffer((T.int64(64),), "float16"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            max_abs_value = T.alloc_buffer((T.int64(64),), "float16")
+            scale = T.alloc_buffer((T.int64(64),))
+            for i, k in T.grid(T.int64(64), T.int64(64)):
+                with T.block("max_abs_value"):
+                    v_i, v_k = T.axis.remap("SR", [i, k])
+                    T.reads(A[v_i, v_k])
+                    T.writes(max_abs_value[v_i])
+                    with T.init():
+                        max_abs_value[v_i] = T.float16(-65504)
+                    max_abs_value[v_i] = T.max(max_abs_value[v_i], 
T.fabs(A[v_i, v_k]))
+            for i in range(T.int64(64)):
+                with T.block("scale"):
+                    v_i = T.axis.spatial(T.int64(64), i)
+                    T.reads(max_abs_value[v_i])
+                    T.writes(scale[v_i])
+                    scale[v_i] = T.max(
+                        T.Cast("float32", max_abs_value[v_i]), 
T.float32(0.0001)
+                    ) * T.float32(0.0078125)
+            for j, i in T.grid(T.int64(64), T.int64(64)):
+                with T.block("w_gathered"):
+                    v_j, v_i = T.axis.remap("SS", [j, i])
+                    T.reads(A[v_i, v_j], scale[v_i])
+                    T.writes(w_gathered[v_j, v_i])
+                    w_gathered[v_j, v_i] = T.Cast(
+                        "int8",
+                        T.min(
+                            T.max(
+                                T.round(T.Cast("float32", A[v_i, v_j]) / 
scale[v_i]),
+                                T.float32(-128),
+                            ),
+                            T.float32(127),
+                        ),
+                    )
+            for i0 in range(T.int64(64)):
+                with T.block("compute"):
+                    v_i0 = T.axis.spatial(T.int64(64), i0)
+                    T.reads(scale[v_i0])
+                    T.writes(compute[v_i0])
+                    compute[v_i0] = T.Cast("float16", scale[v_i0])
+
+        @R.function
+        def main(
+            x: R.Tensor((64, 64), dtype="float16"),
+            y: R.Tensor((64, 64), dtype="float16"),
+            bias: R.Tensor((64, 64), dtype="float16"),
+        ) -> R.Tensor((64, 64), dtype="float16"):
+            R.func_attr({"num_input": 1})
+            cls = Module
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.encode,
+                    (y,),
+                    out_sinfo=[R.Tensor((64, 64), dtype="int8"), 
R.Tensor((64,), dtype="float16")],
+                )
+                lv1: R.Tensor((64, 64), dtype="int8") = lv[0]
+                lv2: R.Tensor((64, 64), dtype="int8") = R.call_pure_packed(
+                    "cutlass.ft_preprocess_weight",
+                    lv1,
+                    R.prim_value(80),
+                    R.prim_value(0),
+                    sinfo_args=(R.Tensor((64, 64), dtype="int8"),),
+                )
+                lv3: R.Tensor((64,), dtype="float16") = lv[1]
+                lv4: R.Tensor((64, 64), dtype="int8") = 
R.builtin.stop_lift_params(lv2)
+                lv5: R.Tensor((64,), dtype="float16") = 
R.builtin.stop_lift_params(lv3)
+                lv6 = R.call_tir(
+                    cls.decode, (lv4, lv5), out_sinfo=R.Tensor((64, 64), 
dtype="float16")
+                )
+                lv1_1: R.Tensor((64, 64), dtype="float16") = R.matmul(x, lv6, 
out_dtype="float16")
+                lv2_1: R.Tensor((64, 128), dtype="float16") = R.add(lv1_1, 
bias)
+                lv2_2: R.Tensor((64, 128), dtype="float16") = R.nn.gelu(lv2_1)

Review Comment:
   This test demonstrates that we are not currently supporting a bias shape like this, nor gelu activation, being offloaded to FT.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to