This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git

The following commit(s) were added to refs/heads/unity by this push:
     new d5725a8443 [Unity][CUTLASS] Support batched matmul + residual fusion (#14613)
d5725a8443 is described below

commit d5725a84430c139e2992ba5cbee557fddcead724
Author: masahi <masahi...@gmail.com>
AuthorDate: Sat Apr 15 05:02:08 2023 +0900

    [Unity][CUTLASS] Support batched matmul + residual fusion (#14613)

    support batched matmul + residual fusion
---
 python/tvm/contrib/cutlass/gemm_operation.py | 21 ++++++++++-----------
 tests/python/relax/test_codegen_cutlass.py   |  2 ++
 2 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py
index b820ead016..60ee106919 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -297,12 +297,11 @@ def instantiate_gemm_template(attrs):
 """

     # See cutlass/gemm/kernel/gemm_with_fused_epilogue.h
-    # Batched GEMM + residual fusion is not supported for now.
     argument_template_residual = """
     typename ${kernel}::Arguments arguments{
-    cutlass::gemm::GemmUniversalMode::kGemm,
+    cutlass::gemm::GemmUniversalMode::${gemm_universal_mode},
     problem_size,
-    1, // batch_count,
+    ${split_k_slices_or_batch}, // batch_count
     {${alpha_beta}},
     static_cast<ElementInputA*>(ptr_a),
     static_cast<ElementInputB*>(ptr_b),
@@ -310,10 +309,10 @@ def instantiate_gemm_template(attrs):
     static_cast<ElementOutput*>(ptr_out),
     static_cast<ElementOutput*>(ptr_bias),
     nullptr, // ptr_Tensor
-    0, // batch_stride_A,
-    0, // batch_stride_B,
-    0, // batch_stride_C,
-    0, // batch_stride_D,
+    ${batch_stride_A}
+    ${batch_stride_B}
+    ${batch_stride_C}
+    ${batch_stride_D}
     0, // batch_stride_Vector,
     0, // batch_stride_Tensor,
     ${lda},
@@ -388,13 +387,13 @@ def instantiate_gemm_template(attrs):
         aux_map["alpha_beta"] = "alpha, beta"

     for key in ["batch_stride_A", "batch_stride_B", "batch_stride_C"]:
-        if not batched:
+        if not batched and not has_residual_block:
             aux_map[key] = ""
         else:
-            aux_map[key] = attrs[key] + ","
+            aux_map[key] = attrs.get(key, "0") + ","

     aux_map["batch_stride_D"] = aux_map["batch_stride_C"]
-    if has_bias and batched:
+    if has_bias and batched and not has_residual_block:
         aux_map["batch_stride_C"] = "0,"

     if batched:
@@ -403,9 +402,9 @@ def instantiate_gemm_template(attrs):
         attrs["split_k_slices_or_batch"] = 1

     if has_residual_block:
-        assert not batched, "Residual fusion is supported only for non-batched GEMM for now."
         template = substitute_template(template, {"argument": argument_template_residual})
         aux_map["residual_decl"] = "void* ptr_residual = (void*)(${residual_arg}->data);\n"
+        aux_map["gemm_universal_mode"] = "kBatched" if batched else "kGemm"
     else:
         template = substitute_template(template, {"argument": argument_template_default})
         aux_map["residual_decl"] = ""
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 9288db3eb5..4309627bf0 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -317,6 +317,8 @@ def test_cutlass_partition_conv2d_residual_blocked():
         # Residual
         ((32, 8), (8, 8), False, "bias", "add"),
         ((4, 16), (16, 16), True, "relu", "add_relu"),
+        ((8, 32, 8), (8, 8, 8), False, "bias", "add"),
+        ((5, 3, 32, 8), (8, 8), True, "relu", "add"),
         # Residual fusion without bias - this is supported via the matmul + bias pattern
         # where bias == residual input
         ((4, 16), (16, 16), False, "none", "add"),