This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new d5725a8443 [Unity][CUTLASS] Support batched matmul + residual fusion (#14613)
d5725a8443 is described below

commit d5725a84430c139e2992ba5cbee557fddcead724
Author: masahi <masahi...@gmail.com>
AuthorDate: Sat Apr 15 05:02:08 2023 +0900

    [Unity][CUTLASS] Support batched matmul + residual fusion (#14613)
    
    support batched matmul + residual fusion
---
 python/tvm/contrib/cutlass/gemm_operation.py | 21 ++++++++++-----------
 tests/python/relax/test_codegen_cutlass.py   |  2 ++
 2 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py
index b820ead016..60ee106919 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -297,12 +297,11 @@ def instantiate_gemm_template(attrs):
     """
 
     # See cutlass/gemm/kernel/gemm_with_fused_epilogue.h
-    # Batched GEMM + residual fusion is not supported for now.
     argument_template_residual = """
   typename ${kernel}::Arguments arguments{
-    cutlass::gemm::GemmUniversalMode::kGemm,
+    cutlass::gemm::GemmUniversalMode::${gemm_universal_mode},
     problem_size,
-    1, // batch_count,
+    ${split_k_slices_or_batch}, // batch_count
     {${alpha_beta}},
     static_cast<ElementInputA*>(ptr_a),
     static_cast<ElementInputB*>(ptr_b),
@@ -310,10 +309,10 @@ def instantiate_gemm_template(attrs):
     static_cast<ElementOutput*>(ptr_out),
     static_cast<ElementOutput*>(ptr_bias),
     nullptr, // ptr_Tensor
-    0, // batch_stride_A,
-    0, // batch_stride_B,
-    0, // batch_stride_C,
-    0, // batch_stride_D,
+    ${batch_stride_A}
+    ${batch_stride_B}
+    ${batch_stride_C}
+    ${batch_stride_D}
     0, // batch_stride_Vector,
     0, // batch_stride_Tensor,
     ${lda},
@@ -388,13 +387,13 @@ def instantiate_gemm_template(attrs):
         aux_map["alpha_beta"] = "alpha, beta"
 
     for key in ["batch_stride_A", "batch_stride_B", "batch_stride_C"]:
-        if not batched:
+        if not batched and not has_residual_block:
             aux_map[key] = ""
         else:
-            aux_map[key] = attrs[key] + ","
+            aux_map[key] = attrs.get(key, "0") + ","
 
     aux_map["batch_stride_D"] = aux_map["batch_stride_C"]
-    if has_bias and batched:
+    if has_bias and batched and not has_residual_block:
         aux_map["batch_stride_C"] = "0,"
 
     if batched:
@@ -403,9 +402,9 @@ def instantiate_gemm_template(attrs):
         attrs["split_k_slices_or_batch"] = 1
 
     if has_residual_block:
-        assert not batched, "Residual fusion is supported only for non-batched GEMM for now."
         template = substitute_template(template, {"argument": argument_template_residual})
         aux_map["residual_decl"] = "void* ptr_residual = (void*)(${residual_arg}->data);\n"
+        aux_map["gemm_universal_mode"] = "kBatched" if batched else "kGemm"
     else:
         template = substitute_template(template, {"argument": argument_template_default})
         aux_map["residual_decl"] = ""
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 9288db3eb5..4309627bf0 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -317,6 +317,8 @@ def test_cutlass_partition_conv2d_residual_blocked():
         # Residual
         ((32, 8), (8, 8), False, "bias", "add"),
         ((4, 16), (16, 16), True, "relu", "add_relu"),
+        ((8, 32, 8), (8, 8, 8), False, "bias", "add"),
+        ((5, 3, 32, 8), (8, 8), True, "relu", "add"),
         # Residual fusion without bias - this is supported via the matmul + bias pattern
         # where bias == residual input
         ((4, 16), (16, 16), False, "none", "add"),

Reply via email to