This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 046c1ba996 [Relax] Move TIR backend to gpu_generic (#17749)
046c1ba996 is described below

commit 046c1ba996ec0382ef65f450cfcc3e0475dd035b
Author: Siyuan Feng <[email protected]>
AuthorDate: Sun Mar 16 21:41:22 2025 +0800

    [Relax] Move TIR backend to gpu_generic (#17749)
    
    * [Relax] Move TIR backend to gpu_generic
    
    This PR moves the TIR backend code to the gpu_generic backend directory 
structure.
    The main changes include:
    - Relocating TIR implementation files from relax/backend_tir to 
relax/backend/gpu_generic/tir
    - Updating import paths in dispatch_sampling.py and dispatch_sort_scan.py
    - Removing unused files (pattern.py, cutlass.py and related tests)
    
    * remove tir folder
---
 python/tvm/relax/backend/dispatch_sampling.py      |   2 +-
 python/tvm/relax/backend/dispatch_sort_scan.py     |   2 +-
 python/tvm/relax/backend/gpu_generic/__init__.py   |   4 +-
 .../{backend_tir => backend/gpu_generic}/cumsum.py |   0
 .../gpu_generic}/sampling.py                       |   0
 python/tvm/relax/backend_tir/__init__.py           |  22 -
 python/tvm/relax/backend_tir/contrib/__init__.py   |  20 -
 python/tvm/relax/backend_tir/contrib/cutlass.py    | 720 ---------------------
 python/tvm/relax/backend_tir/pattern.py            | 576 -----------------
 tests/python/relax/test_codegen_tir_cutlass.py     | 702 --------------------
 10 files changed, 5 insertions(+), 2043 deletions(-)

diff --git a/python/tvm/relax/backend/dispatch_sampling.py 
b/python/tvm/relax/backend/dispatch_sampling.py
index 68d162fdf1..528529c723 100644
--- a/python/tvm/relax/backend/dispatch_sampling.py
+++ b/python/tvm/relax/backend/dispatch_sampling.py
@@ -36,7 +36,7 @@ class SamplingDispatcher(BackendDispatcher):
             return super().visit_call_(call)
 
         if call.op.name == "relax.multinomial_from_uniform":
-            from tvm.relax.backend_tir import (  # pylint: 
disable=import-outside-toplevel
+            from tvm.relax.backend.gpu_generic import (  # pylint: 
disable=import-outside-toplevel
                 generic_get_sample_index,
                 gpu_multinomial_from_uniform,
             )
diff --git a/python/tvm/relax/backend/dispatch_sort_scan.py 
b/python/tvm/relax/backend/dispatch_sort_scan.py
index b5a94619c2..9f7cbaee9a 100644
--- a/python/tvm/relax/backend/dispatch_sort_scan.py
+++ b/python/tvm/relax/backend/dispatch_sort_scan.py
@@ -141,7 +141,7 @@ class SortScanDispatcher(BackendDispatcher):
                 and call.op.name == "relax.cumsum"
                 and call.attrs.exclusive == 0
             ):
-                from tvm.relax.backend_tir import (  # pylint: 
disable=import-outside-toplevel
+                from tvm.relax.backend.gpu_generic import (  # pylint: 
disable=import-outside-toplevel
                     gpu_2d_continuous_cumsum,
                 )
 
diff --git a/python/tvm/relax/backend/gpu_generic/__init__.py 
b/python/tvm/relax/backend/gpu_generic/__init__.py
index ea2d2a2afb..d7c316d28c 100644
--- a/python/tvm/relax/backend/gpu_generic/__init__.py
+++ b/python/tvm/relax/backend/gpu_generic/__init__.py
@@ -15,10 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 """The Relax Metal backend compilation pipeline and other passes."""
+from .cumsum import gpu_2d_continuous_cumsum
 from .pipeline import (
+    dataflow_lower_passes,
     finalize_passes,
     get_default_pipeline,
     legalize_passes,
-    dataflow_lower_passes,
     library_dispatch_passes,
 )
+from .sampling import generic_get_sample_index, gpu_multinomial_from_uniform
diff --git a/python/tvm/relax/backend_tir/cumsum.py 
b/python/tvm/relax/backend/gpu_generic/cumsum.py
similarity index 100%
rename from python/tvm/relax/backend_tir/cumsum.py
rename to python/tvm/relax/backend/gpu_generic/cumsum.py
diff --git a/python/tvm/relax/backend_tir/sampling.py 
b/python/tvm/relax/backend/gpu_generic/sampling.py
similarity index 100%
rename from python/tvm/relax/backend_tir/sampling.py
rename to python/tvm/relax/backend/gpu_generic/sampling.py
diff --git a/python/tvm/relax/backend_tir/__init__.py 
b/python/tvm/relax/backend_tir/__init__.py
deleted file mode 100644
index b64bdcda6b..0000000000
--- a/python/tvm/relax/backend_tir/__init__.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Relax backends, tir based"""
-
-from . import contrib
-from .cumsum import gpu_2d_continuous_cumsum
-from .pattern import get_tir_pattern
-from .sampling import gpu_multinomial_from_uniform, generic_get_sample_index
diff --git a/python/tvm/relax/backend_tir/contrib/__init__.py 
b/python/tvm/relax/backend_tir/contrib/__init__.py
deleted file mode 100644
index 9274f22374..0000000000
--- a/python/tvm/relax/backend_tir/contrib/__init__.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""External backend codegen modules for Relax, tir based."""
-
-from .cutlass import cutlass_fcodegen
diff --git a/python/tvm/relax/backend_tir/contrib/cutlass.py 
b/python/tvm/relax/backend_tir/contrib/cutlass.py
deleted file mode 100644
index 0dbe31c468..0000000000
--- a/python/tvm/relax/backend_tir/contrib/cutlass.py
+++ /dev/null
@@ -1,720 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: 
disable=invalid-name,comparison-with-callable,unused-variable,missing-function-docstring
-"""codegen for cutlass"""
-import operator
-from functools import reduce
-from typing import List, Dict, Any
-
-from tvm.contrib.cutlass.build import _get_cutlass_path, 
_get_cutlass_compile_options
-from tvm.contrib.nvcc import get_target_compute_version
-from tvm.contrib.cutlass.library import LayoutType, ConvKind
-from tvm.contrib.cutlass.gen_tensor_op import instantiate_template
-from tvm.contrib.cutlass.gen_gemm import CutlassGemmProfiler
-from tvm.contrib.cutlass.gen_conv2d import CutlassConv2DProfiler
-from ..pattern import (
-    MatchResult,
-    matmul_rrr_fp16,
-    bias_row_2d_fp16,
-    bias_row_1d_fp16,
-    batch_bias_row_2d_fp16,
-    batch_bias_row_1d_fp16,
-    relu_fp16,
-    erf_3d_fp32,
-    batch_matmul_rrr_2d_fp16,
-    batch_matmul_rrr_3d_fp16,
-    conv2d_nhwc_fp16,
-    padding_2d_nhwc_fp16,
-    copy_4d_fp16,
-    bias_add_nhwc_2d_fp16,
-    bias_add_nhwc_1d_fp16,
-    elem_add_4d_fp16,
-    elem_mul_3d_fp16,
-    scalar_add_3d_fp16,
-    scalar_mul_3d_fp16,
-    cast_3d_fp16,
-    cast_3d_fp32,
-)
-
-#### helper functions ####
-# list representing the anchor ops
-# in the future more layouts/dtypes can be supported
-MATMUL_LIST = [matmul_rrr_fp16]
-MATMUL_BIAS_LIST = [bias_row_2d_fp16, bias_row_1d_fp16]
-BATCH_MATMUL_LIST = [batch_matmul_rrr_2d_fp16, batch_matmul_rrr_3d_fp16]
-BATCH_MATMUL_BIAS_LIST = [batch_bias_row_2d_fp16, batch_bias_row_1d_fp16]
-CONV2D_LIST = [conv2d_nhwc_fp16]
-
-# attributes for anchor ops used in code generation
-OP_PATTERN_ATTR_LIST = {
-    matmul_rrr_fp16: {
-        "arg0_dtype": "float16",
-        "arg1_dtype": "float16",
-        "ret_dtype": "float16",
-    },
-    batch_matmul_rrr_2d_fp16: {
-        "arg0_dtype": "float16",
-        "arg1_dtype": "float16",
-        "ret_dtype": "float16",
-    },
-    batch_matmul_rrr_3d_fp16: {
-        "arg0_dtype": "float16",
-        "arg1_dtype": "float16",
-        "ret_dtype": "float16",
-    },
-    conv2d_nhwc_fp16: {
-        "arg0_dtype": "float16",
-        "arg1_dtype": "float16",
-        "ret_dtype": "float16",
-        # in the future we can add layout here
-    },
-}
-
-
-def _get_cutlass_code(attr):
-    pattern = attr["op_type"]
-    if pattern.startswith("cutlass.matmul"):
-        return cutlass_codegen_gemm(attr)
-    elif pattern.startswith("cutlass.conv2d"):
-        return cutlass_codegen_conv2d(attr)
-    else:
-        raise ValueError("op not supported")
-
-
-def _final_code(code, headers, func_args):
-    res = ""
-    res += "#define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h>\n"
-    res += "#include <tvm/runtime/c_runtime_api.h>\n"
-    res += "#include <tvm/runtime/packed_func.h>\n"
-    res += "#include <dlpack/dlpack.h>\n"
-    res += "#include <cuda_fp16.h>\n"
-    res += "#include <cutlass/cutlass.h>\n"
-    res += "#include <cutlass/coord.h>\n"
-    res += "#include <cutlass/tensor_ref.h>\n"
-    res += "#include <cutlass/util/host_tensor.h>\n"
-
-    for header in headers:
-        res += "#include <" + header + ">\n"
-    res += "namespace {\n"
-    res += "using namespace tvm;\n"
-    res += "using namespace tvm::runtime;\n"
-    res += "void _cutlass_kernel("
-    for arg in func_args:
-        res += "NDArray " + arg + ", "
-    res += "NDArray out0) {"
-    res += code
-    res += "}\n"
-    res += "}  // namespace\n"
-    res += "TVM_DLL_EXPORT_TYPED_FUNC({global_symbol}, _cutlass_kernel);\n"
-    return res
-
-
-#### cutlass patterns ####
-def matmul_bias_relu(match_results, attr, get_code=True):
-    if len(match_results) < 3:
-        return None
-    attr = matmul_bias(match_results[:2], attr, get_code=False)
-    if attr is None or match_results[2].pattern != relu_fp16:
-        return None
-    m_bias, n_bias = match_results[1].symbol_values
-    m_relu, n_relu = match_results[2].symbol_values
-    A_bias, B_bias, C_bias = match_results[1].matched_buffers
-    A_relu, B_relu = match_results[2].matched_buffers
-    if m_bias == m_relu and n_bias == n_relu and C_bias == A_relu:
-        attr["op_type"] = "cutlass.matmul_bias_relu"
-        return [_get_cutlass_code(attr=attr), 3, attr["args"]] if get_code 
else attr
-    return None
-
-
-def matmul_bias(match_results, attr, get_code=True):
-    if len(match_results) < 2:
-        return None
-    attr = matmul(match_results[:1], attr, get_code=False)
-    if attr is None or match_results[1].pattern not in MATMUL_BIAS_LIST:
-        return None
-    m_matmul, n_matmul, k_matmul = match_results[0].symbol_values
-    m_bias, n_bias = match_results[1].symbol_values
-    A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers
-    A_bias, B_bias, C_bias = match_results[1].matched_buffers
-    if m_matmul == m_bias and n_matmul == n_bias and C_matmul == A_bias:
-        attr["op_type"] = "cutlass.matmul_bias"
-        attr["bias_arg_idx"] = 2
-        attr["args"].append(B_bias)
-        return [_get_cutlass_code(attr=attr), 2, attr["args"]] if get_code 
else attr
-    return None
-
-
-def matmul(match_results, attr, get_code=True):
-    if len(match_results) < 1:
-        return None
-    if match_results[0].pattern in MATMUL_LIST:
-        # matmul
-        attr["op_type"] = "cutlass.matmul"
-        return [_get_cutlass_code(attr=attr), 1, attr["args"]] if get_code 
else attr
-    return None
-
-
-def batch_matmul_bias_gelu(match_results, attr, get_code=True):
-    if len(match_results) < 9:
-        return None
-    attr = batch_matmul_bias(match_results[:2], attr, get_code=False)  # 
batch_matmul, batch_bias
-    if (
-        attr is None
-        or match_results[2].pattern != scalar_mul_3d_fp16
-        or match_results[3].pattern != cast_3d_fp32
-        or match_results[4].pattern != erf_3d_fp32
-        or match_results[5].pattern != cast_3d_fp16
-        or match_results[6].pattern != scalar_mul_3d_fp16
-        or match_results[7].pattern != scalar_add_3d_fp16
-        or match_results[8].pattern != elem_mul_3d_fp16
-    ):
-        return None
-
-    def shape_match_3d(shape1, shape2):
-        if len(shape1) < 3 or len(shape2) < 3:
-            return False
-        return shape1[0] == shape2[0] and shape1[1] == shape2[1] and shape1[2] 
== shape2[2]
-
-    for i in range(1, 8):
-        if not shape_match_3d(match_results[i].symbol_values, match_results[i 
+ 1].symbol_values):
-            return None
-
-    if not (
-        match_results[1].matched_buffers[-1] == 
match_results[2].matched_buffers[0]
-        and match_results[2].matched_buffers[-1] == 
match_results[3].matched_buffers[0]
-        and match_results[3].matched_buffers[-1] == 
match_results[4].matched_buffers[0]
-        and match_results[4].matched_buffers[-1] == 
match_results[5].matched_buffers[0]
-        and match_results[5].matched_buffers[-1] == 
match_results[6].matched_buffers[0]
-        and match_results[6].matched_buffers[-1] == 
match_results[7].matched_buffers[0]
-        and match_results[1].matched_buffers[-1] == 
match_results[8].matched_buffers[0]
-        and match_results[7].matched_buffers[-1] == 
match_results[8].matched_buffers[1]
-    ):
-        return None
-
-    if (
-        abs(float(match_results[2].symbol_values[-1] - 0.5**0.5)) > 1e-5
-        or abs(float(match_results[6].symbol_values[-1] - 0.5)) > 1e-5
-        or abs(float(match_results[7].symbol_values[-1] - 0.5)) > 1e-5
-    ):
-        return None
-
-    attr["op_type"] = "cutlass.matmul_bias_gelu"
-    return [_get_cutlass_code(attr=attr), 9, attr["args"]] if get_code else 
attr
-
-
-def batch_matmul_bias_residual_mul(match_results, attr, get_code=True):
-    if len(match_results) < 3:
-        return None
-    attr = batch_matmul_bias(match_results[:2], attr, get_code=False)  # 
batch_matmul, batch_bias
-    if attr is None or match_results[2].pattern != elem_mul_3d_fp16:
-        return None
-    (
-        b_bias,
-        m_bias,
-        n_bias,
-    ) = match_results[1].symbol_values
-    (
-        b_mul,
-        m_mul,
-        n_mul,
-    ) = match_results[2].symbol_values
-    A_bias, B_bias, C_bias = match_results[1].matched_buffers
-    A_mul, B_mul, C_mul = match_results[2].matched_buffers
-    if b_bias == b_mul and m_bias == m_mul and n_bias == n_mul and C_bias == 
A_mul:
-        attr["op_type"] = "cutlass.matmul_bias_residual_multiply"
-        attr["residual_arg_idx"] = 3
-        return [_get_cutlass_code(attr=attr), 3, attr["args"]] if get_code 
else attr
-    return None
-
-
-def batch_matmul_bias(match_results, attr, get_code=True):
-    if len(match_results) < 2:
-        return None
-    attr = batch_matmul(match_results[:1], attr, get_code=False)
-    if attr is None or match_results[1].pattern not in BATCH_MATMUL_BIAS_LIST:
-        return None
-    (
-        b_matmul,
-        m_matmul,
-        n_matmul,
-        k_matmul,
-    ) = match_results[0].symbol_values
-    (
-        b_bias,
-        m_bias,
-        n_bias,
-    ) = match_results[1].symbol_values
-    A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers
-    A_bias, B_bias, C_bias = match_results[1].matched_buffers
-    if b_matmul == b_bias and m_matmul == m_bias and n_matmul == n_bias and 
C_matmul == A_bias:
-        attr["op_type"] = "cutlass.matmul_bias"
-        attr["bias_arg_idx"] = 2
-        attr["args"].append(B_bias)
-        return [_get_cutlass_code(attr=attr), 2, attr["args"]] if get_code 
else attr
-    return None
-
-
-def batch_matmul(match_results, attr, get_code=True):
-    if len(match_results) < 1:
-        return None
-    if match_results[0].pattern in BATCH_MATMUL_LIST:
-        attr["op_type"] = "cutlass.matmul"
-        return [_get_cutlass_code(attr=attr), 1, attr["args"]] if get_code 
else attr
-    return None
-
-
-def conv2d_bias_residual_add(match_results, attr, get_code=True):
-    if len(match_results) < 4:
-        return None
-    attr = conv2d_bias(match_results[:3], attr, get_code=False)
-    if attr is None or match_results[3].pattern != elem_add_4d_fp16:
-        return None
-    N_bias, H_bias, W_bias, C_bias = match_results[2].symbol_values
-    in1_bias, in2_bias, out_bias = match_results[2].matched_buffers
-    N_add, H_add, W_add, C_add = match_results[3].symbol_values
-    in1_add, in2_add, out_add = match_results[3].matched_buffers
-    if (
-        N_bias == N_add
-        and H_bias == H_add
-        and W_bias == W_add
-        and C_bias == C_add
-        and out_bias in [in1_add, in2_add]
-    ):
-        attr["op_type"] = "cutlass.conv2d_bias_residual_add"
-        attr["residual_arg_idx"] = 3
-        attr["args"].append(in2_add if out_bias == in1_add else in1_add)
-        return [_get_cutlass_code(attr=attr), 4, attr["args"]] if get_code 
else attr
-    return None
-
-
-def conv2d_bias(match_results, attr, get_code=True):
-    if len(match_results) < 3:
-        return None
-    attr = conv2d(match_results[:2], attr, get_code=False)
-    if attr is None or (
-        match_results[2].pattern not in [bias_add_nhwc_2d_fp16, 
bias_add_nhwc_1d_fp16]
-    ):
-        return None
-    (N_conv, pH_conv, pW_conv, H_conv, W_conv, C_conv, O_conv,) = 
match_results[
-        1
-    ].symbol_values[:7]
-    A_pad_conv, B_conv, out_conv = match_results[1].matched_buffers
-    N_bias, H_bias, W_bias, C_bias = match_results[2].symbol_values
-    A_bias, B_bias, out_bias = match_results[2].matched_buffers
-    if (
-        N_bias == N_conv
-        and H_bias == H_conv
-        and W_bias == W_conv
-        and C_bias == O_conv
-        and out_conv == A_bias
-    ):
-        attr["op_type"] = "cutlass.conv2d_bias"
-        attr["bias_arg_idx"] = 2
-        attr["args"].append(B_bias)
-        return [_get_cutlass_code(attr=attr), 3, attr["args"]] if get_code 
else attr
-    return None
-
-
-def conv2d(match_results, attr, get_code=True):
-    if len(match_results) < 2:
-        return None
-    if (
-        match_results[0].pattern in [padding_2d_nhwc_fp16, copy_4d_fp16]
-        and match_results[1].pattern == conv2d_nhwc_fp16
-    ):
-        if match_results[0].pattern == padding_2d_nhwc_fp16:
-            (
-                N_pad,
-                H_pad,
-                W_pad,
-                C_pad,
-                pH_pad,
-                pW_pad,
-                lH_pad,
-                lW_pad,
-                rH_pad,
-                rW_pad,
-            ) = match_results[0].symbol_values
-        else:
-            (
-                N_pad,
-                H_pad,
-                W_pad,
-                C_pad,
-            ) = match_results[0].symbol_values
-            pH_pad = rH_pad = H_pad
-            pW_pad = rW_pad = W_pad
-            lH_pad = lW_pad = 0
-        (
-            N_conv,
-            pH_conv,
-            pW_conv,
-            H_conv,
-            W_conv,
-            C_conv,
-            O_conv,
-            KH_conv,
-            KW_conv,
-            stride_h_conv,
-            stride_w_conv,
-            dilation_h_conv,
-            dilation_w_conv,
-        ) = match_results[1].symbol_values
-        A, A_pad = match_results[0].matched_buffers
-        A_pad_conv, B_conv, out_conv = match_results[1].matched_buffers
-        if (
-            N_pad == N_conv
-            and pH_pad == pH_conv
-            and pW_pad == pW_conv
-            and C_pad == C_conv
-            and A_pad == A_pad_conv
-        ):
-            if (
-                lH_pad == pH_pad - rH_pad
-                and lW_pad == pW_pad - rW_pad
-                and lH_pad + H_pad == rH_pad
-                and lW_pad + W_pad == rW_pad
-            ):
-                padding = (lH_pad, lW_pad)
-                strides = (stride_h_conv, stride_w_conv)
-                dilation = (dilation_h_conv, dilation_w_conv)
-                attr["padding"] = padding
-                attr["strides"] = strides
-                attr["dilation"] = dilation
-                attr["op_type"] = "cutlass.conv2d"
-                return [_get_cutlass_code(attr=attr), 2, attr["args"]] if 
get_code else attr
-    return None
-
-
-### cutlass codegen functions ###
-def compile_options(target, threads=-1, use_fast_math=False):
-    compute_version = 
int("".join(get_target_compute_version(target).split(".")))
-    kwargs = _get_cutlass_compile_options(compute_version, threads, 
use_fast_math)
-    kwargs["options"].remove("-c")
-    return kwargs
-
-
-def cutlass_fcodegen(sm=80, bin_dir="./bin"):
-    gemm_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), bin_dir)
-    conv2d_profiler = CutlassConv2DProfiler(sm, _get_cutlass_path(), bin_dir)
-
-    def cutlass_codegen_with_match_results(match_results: List[MatchResult]):
-        """generate cutlass code with match results"""
-        nonlocal gemm_profiler
-        nonlocal conv2d_profiler
-
-        assert len(match_results) > 0
-
-        # add shape into attr
-        if match_results[0].pattern in MATMUL_LIST:
-            A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers
-            attr: Dict[Any, Any] = 
OP_PATTERN_ATTR_LIST[match_results[0].pattern]
-            attr["args"] = [A_matmul, B_matmul]
-            attr["arg0_shape"] = A_matmul.shape
-            attr["arg1_shape"] = B_matmul.shape
-            attr["ret_shape"] = C_matmul.shape
-            attr["lhs_arg_idx"] = 0
-            attr["rhs_arg_idx"] = 1
-        elif match_results[0].pattern in BATCH_MATMUL_LIST:
-            A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers
-            attr = OP_PATTERN_ATTR_LIST[match_results[0].pattern]
-            attr["args"] = [A_matmul, B_matmul]
-            attr["arg0_shape"] = A_matmul.shape
-            attr["arg1_shape"] = B_matmul.shape
-            attr["ret_shape"] = C_matmul.shape
-            attr["lhs_arg_idx"] = 0
-            attr["rhs_arg_idx"] = 1
-        elif len(match_results) >= 1 and match_results[1].pattern in 
CONV2D_LIST:
-            A_input = match_results[0].matched_buffers[0]
-            A_conv2d, B_conv2d, C_conv2d = match_results[1].matched_buffers
-            attr = OP_PATTERN_ATTR_LIST[match_results[1].pattern]
-            attr["args"] = [A_input, B_conv2d]
-            attr["arg0_shape"] = A_input.shape
-            attr["arg1_shape"] = B_conv2d.shape
-            attr["ret_shape"] = C_conv2d.shape
-            attr["lhs_arg_idx"] = 0
-            attr["rhs_arg_idx"] = 1
-        else:
-            return ["", 0]
-
-        # add profiler into attr
-        attr["gemm_profiler"] = gemm_profiler
-        attr["conv2d_profiler"] = conv2d_profiler
-
-        cutlass_patterns = [
-            # 9
-            batch_matmul_bias_gelu,
-            # 4
-            conv2d_bias_residual_add,
-            # 3
-            batch_matmul_bias_residual_mul,
-            matmul_bias_relu,
-            conv2d_bias,
-            # 2
-            matmul_bias,
-            batch_matmul_bias,
-            conv2d,
-            # 1
-            matmul,
-            batch_matmul,
-        ]
-        for pattern in cutlass_patterns:
-            res = pattern(match_results, attr)
-            if res is not None:
-                return res
-
-        return ["", 0]
-
-    return cutlass_codegen_with_match_results
-
-
-def cutlass_codegen_gemm(attrs):
-    """cutlass codegen for gemm"""
-    gemm_profiler = attrs["gemm_profiler"]
-    op_type = attrs["op_type"]
-    lhs_shape = attrs["arg0_shape"]
-    rhs_shape = attrs["arg1_shape"]
-    MM = lhs_shape[-2]
-    KK = lhs_shape[-1]
-    if "transposed" in op_type:
-        NN = rhs_shape[-2]
-        ldb = "K"
-        layout_b = LayoutType.ColumnMajor
-    else:
-        NN = rhs_shape[-1]
-        ldb = "N"
-        layout_b = LayoutType.RowMajor
-
-    lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
-    rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
-    if lhs_batches == 1 and rhs_batches == 1:
-        # Regular matmul
-        is_batched = False
-        batch_attrs = {}
-    else:
-        is_batched = True
-        batch_attrs = {
-            # If both lhs_batches and rhs_batches are greater than 1,
-            # they must be equal. This is checked by 
is_shape_valid_for_cutlass_matmul.
-            "batch": lhs_batches if rhs_batches == 1 else rhs_batches,
-            "batch_stride_A": 0 if lhs_batches == 1 else MM * KK,
-            "batch_stride_B": 0 if rhs_batches == 1 else KK * NN,
-            "batch_stride_C": MM * NN,
-        }
-    op_name, op_def, _ = gemm_profiler.profile(
-        op_type,
-        MM,
-        NN,
-        KK,
-        attrs["ret_dtype"],
-        attrs["arg0_dtype"],
-        attrs["arg1_dtype"],
-        False,
-        batched=is_batched,
-        find_first_valid=False,
-        use_multiprocessing=True,
-        layout_b=layout_b,
-    )
-    attrs["cutlass_op_name"] = op_name
-    attrs["cutlass_op_def"] = op_def
-    attrs["lda"] = "K"
-    attrs["ldb"] = ldb
-    attrs["ldc"] = "N"
-    attrs.update(batch_attrs)
-    del attrs["gemm_profiler"]
-    del attrs["conv2d_profiler"]
-
-    nargs = 2
-    if "bias_arg_idx" in attrs:
-        nargs += 1
-    if "residual_arg_idx" in attrs:
-        nargs += 1
-    func_args = ["inp" + str(i) for i in range(nargs)]
-
-    # A temporary solution to handle batch matmul residual cases
-    # TODO(@bohan): remove this after initialize_template supports bmm residual
-    if op_type in [
-        "cutlass.matmul_bias_residual_multiply",
-    ]:
-
-        def _convert_dtype_str(dtype):
-            if isinstance(dtype, list):
-                arr = []
-                for t in dtype:
-                    arr.append(_convert_dtype_str(t))
-                return arr
-            elif isinstance(dtype, str):
-                if dtype == "float16":
-                    return "cutlass::half_t"
-                elif dtype == "float32":
-                    return "float"
-            raise ValueError("dtype not supported")
-
-        typea, typeb, typec = _convert_dtype_str(
-            [attrs["arg0_dtype"], attrs["arg1_dtype"], attrs["ret_dtype"]]
-        )
-
-        text = f"""
-#define CUTLASS_ENABLE_CUBLAS 1
-#define CUTLASS_NAMESPACE cutlass
-#define CUTLASS_ENABLE_TENSOR_CORE_MMA 1
-#define NDEBUG
-#include <cutlass/cutlass.h>
-#include <cutlass/tensor_ref.h>
-#include <cutlass/util/host_tensor.h>
-#include <cutlass/gemm/device/gemm.h>
-#include <cutlass/gemm/device/gemm_batched.h>
-#include <cutlass/layout/matrix.h>
-#include <cutlass/numeric_types.h>
-#include "cutlass/epilogue/thread/activation.h"
-#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
-#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
-#include <fstream>
-#include <iostream>
-#include <sstream>
-#include <vector>
-#define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h>
-#include <tvm/runtime/logging.h>
-#include <tvm/runtime/ndarray.h>
-#include <tvm/runtime/packed_func.h>
-namespace {{
-using namespace tvm;
-using namespace tvm::runtime;
-void _BHGEMM(NDArray A, NDArray B, NDArray Bias, NDArray D, NDArray C) {{
-    // A: [Batch, M, K], B: [1, K, N]/[K, N], Bias: [1, N]/[N], D: [Batch, M, 
N], C: [Batch, M, N]
-    CHECK_EQ(A->ndim, 3);
-    int bdim = B->ndim;
-    int bias_dim = Bias->ndim;
-    CHECK_EQ(C->ndim, 3);
-    CHECK_EQ(A->shape[2], B->shape[bdim - 2]);
-    CHECK_EQ(Bias->shape[bias_dim - 1], B->shape[bdim - 1]);
-    CHECK_EQ(D->ndim, 3);
-    CHECK_EQ(D->shape[0], A->shape[0]);
-    CHECK_EQ(D->shape[1], A->shape[1]);
-    CHECK_EQ(D->shape[2], B->shape[bdim - 1]);
-    CHECK_EQ(C->shape[0], A->shape[0]);
-    CHECK_EQ(C->shape[1], A->shape[1]);
-    CHECK_EQ(C->shape[2], B->shape[bdim - 1]);
-    int64_t M = A->shape[0] * A->shape[1];
-    int64_t N = B->shape[bdim - 1];
-    int64_t K = A->shape[2];
-    int64_t input_a_batch_stride = M * K;
-    int64_t input_a_stride = K;
-    int64_t input_a_offset = 0; // default to 0
-    int64_t input_b_batch_stride = K * N;
-    int64_t input_b_stride = N;
-    int64_t input_b_offset = 0; // default to 0
-    int64_t output_stride = N;
-    int64_t output_offset = 0;
-    int64_t a_size = 1;
-    a_size *= A->shape[0];
-    a_size *= A->shape[1];
-    a_size *= A->shape[2];
-
-    int64_t b_size = 1;
-    b_size *= B->shape[bias_dim - 2];
-    b_size *= B->shape[bias_dim - 1];
-
-    int64_t c_size = 1;
-    c_size *= C->shape[0];
-    c_size *= C->shape[1];
-    c_size *= C->shape[2];
-
-    // Define the GEMM operation
-    {op_def}
-    using kernel = Operation_{op_name};
-    using ElementComputeEpilogue = typename kernel::ElementAccumulator;
-    typename kernel::Arguments arguments({{
-        cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode
-        {{M, N, K}}, // GemmCoord problem_size
-        1, // int batch_count
-        {{ElementComputeEpilogue(1), ElementComputeEpilogue(1)}}, // typename 
EpilogueOutputOp::Params epilogue
-        ({typea}*)(A->data) + input_a_offset, // void const * ptr_A
-        ({typeb}*)(B->data) + input_b_offset, // void const * ptr_B
-        ({typec}*)(D->data), // void const * ptr_C1
-        ({typec}*)(C->data) + output_offset, // void * ptr_D
-        ({typea}*)(Bias->data), // void * ptr_Vector
-        nullptr, // void * ptr_Tensor
-        input_a_batch_stride, // int64_t batch_stride_A
-        input_b_batch_stride, // int64_t batch_stride_B
-        0, // int64_t batch_stride_C1
-        0, // int64_t batch_stride_D
-        0, // int64_t batch_stride_Vector
-        0, // int64_t batch_stride_Tensor
-        input_a_stride, // typename LayoutA::Stride::Index lda
-        input_b_stride, // typename LayoutB::Stride::Index ldb
-        N, // typename LayoutC::Stride::Index ldc1
-        output_stride, // typename LayoutC::Stride::Index ldd
-        0, // typename LayoutC::Stride::Index ldr
-        0, // typename LayoutC::Stride::Index ldt
-    }});
-    kernel gemm_op;
-    size_t workspace_size = gemm_op.get_workspace_size(arguments);
-    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
-    cutlass::Status status = gemm_op.can_implement(arguments);
-    CHECK(status == cutlass::Status::kSuccess);
-    status = gemm_op.initialize(arguments, workspace.get());
-    CHECK(status == cutlass::Status::kSuccess);
-    status = gemm_op();
-    CHECK(status == cutlass::Status::kSuccess);
-    return;
-}}
-}}  // namespace
-TVM_DLL_EXPORT_TYPED_FUNC({{global_symbol}}, _BHGEMM);
-      """
-        return text
-
-    code = instantiate_template(op_type, attrs, func_args)
-    return _final_code(code.code, code.headers, func_args)
-
-
-def cutlass_codegen_conv2d(attrs):
-    """cutlass codegen for conv2d"""
-    # cutlass backend only supports nhwc for now
-    conv2d_profiler = attrs["conv2d_profiler"]
-    op_type = attrs["op_type"]
-    conv_kind = ConvKind.Fprop
-    op_name, op_def, _ = conv2d_profiler.profile(
-        op_type=attrs["op_type"],
-        d_shape=attrs["arg0_shape"],
-        w_shape=attrs["arg1_shape"],
-        padding=attrs["padding"],
-        stride=attrs["strides"],
-        dilation=attrs["dilation"],
-        out_dtype=attrs["ret_dtype"],
-        data_dtype=attrs["arg0_dtype"],
-        weight_dtype=attrs["arg1_dtype"],
-        use_3xtf32=False,
-        conv_kind=conv_kind,
-        split_k_slices=[1],
-        profile_all_alignments=True,
-        find_first_valid=False,
-        use_multiprocessing=True,
-    )
-    attrs["cutlass_op_def"] = op_def
-    attrs["cutlass_op_name"] = op_name
-    del attrs["gemm_profiler"]
-    del attrs["conv2d_profiler"]
-
-    nargs = 2
-    if "bias_arg_idx" in attrs:
-        nargs += 1
-    if "residual_arg_idx" in attrs:
-        nargs += 1
-    func_args = ["inp" + str(i) for i in range(nargs)]
-    code = instantiate_template(op_type, attrs, func_args)
-    return _final_code(code.code, code.headers, func_args)
diff --git a/python/tvm/relax/backend_tir/pattern.py 
b/python/tvm/relax/backend_tir/pattern.py
deleted file mode 100644
index 10f7a3b162..0000000000
--- a/python/tvm/relax/backend_tir/pattern.py
+++ /dev/null
@@ -1,576 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=invalid-name,missing-function-docstring,chained-comparison
-"""TIR Patterns"""
-from typing import List
-
-import tvm
-from tvm.runtime import Object
-import tvm._ffi
-
-from tvm.script import tir as T
-
-
-@tvm._ffi.register_object("relax.MatchResult")
-class MatchResult(Object):
-    """The match result of a TIR pattern."""
-
-    def __init__(self, pattern, symbol_values, matched_buffers):
-        self.__init_handle_by_constructor__(
-            tvm._ffi.MatchResult, pattern, symbol_values, matched_buffers
-        )
-
-
[email protected]_func
-def matmul_rrr_fp16(
-    var_rxplaceholder: T.handle,
-    var_rxplaceholder_1: T.handle,
-    var_matmul: T.handle,
-    M: T.int64,
-    N: T.int64,
-    K: T.int64,
-) -> None:
-    # function attr dict
-    T.func_attr({"tir.noalias": True})
-    rxplaceholder = T.match_buffer(var_rxplaceholder, [M, K], dtype="float16")
-    rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [K, N], 
dtype="float16")
-    matmul = T.match_buffer(var_matmul, [M, N], dtype="float16")
-    # body
-    # with T.block("root")
-    for i0, i1, i2 in T.grid(M, N, K):
-        with T.block("matmul"):
-            i0_1, i1_1, k = T.axis.remap("SSR", [i0, i1, i2])
-            T.reads(rxplaceholder[i0_1, k], rxplaceholder_1[k, i1_1])
-            T.writes(matmul[i0_1, i1_1])
-            with T.init():
-                matmul[i0_1, i1_1] = T.float16(0)
-            matmul[i0_1, i1_1] = (
-                matmul[i0_1, i1_1] + rxplaceholder[i0_1, k] * 
rxplaceholder_1[k, i1_1]
-            )
-
-
[email protected]_func
-def bias_row_2d_fp16(
-    var_rxplaceholder: T.handle,
-    var_rxplaceholder_1: T.handle,
-    var_T_add: T.handle,
-    M: T.int64,
-    N: T.int64,
-) -> None:
-    # function attr dict
-    T.func_attr({"tir.noalias": True})
-    rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16")
-    rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [T.int64(1), N], 
dtype="float16")
-    T_add = T.match_buffer(var_T_add, [M, N], dtype="float16")
-    # body
-    # with T.block("root")
-    for i0, i1 in T.grid(M, N):
-        with T.block("T_add"):
-            ax0, ax1 = T.axis.remap("SS", [i0, i1])
-            T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[T.int64(0), ax1])
-            T.writes(T_add[ax0, ax1])
-            T_add[ax0, ax1] = rxplaceholder[ax0, ax1] + 
rxplaceholder_1[T.int64(0), ax1]
-
-
[email protected]_func
-def bias_row_1d_fp16(
-    var_rxplaceholder: T.handle,
-    var_rxplaceholder_1: T.handle,
-    var_T_add: T.handle,
-    M: T.int64,
-    N: T.int64,
-) -> None:
-    # function attr dict
-    T.func_attr({"tir.noalias": True})
-    rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16")
-    rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [N], dtype="float16")
-    T_add = T.match_buffer(var_T_add, [M, N], dtype="float16")
-    # body
-    # with T.block("root")
-    for i0, i1 in T.grid(M, N):
-        with T.block("T_add"):
-            ax0, ax1 = T.axis.remap("SS", [i0, i1])
-            T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax1])
-            T.writes(T_add[ax0, ax1])
-            T_add[ax0, ax1] = rxplaceholder[ax0, ax1] + rxplaceholder_1[ax1]
-
-
[email protected]_func
-def batch_bias_row_2d_fp16(
-    var_rxplaceholder: T.handle,
-    var_rxplaceholder_1: T.handle,
-    var_T_add: T.handle,
-    batch: T.int64,
-    M: T.int64,
-    N: T.int64,
-) -> None:
-    # function attr dict
-    T.func_attr({"tir.noalias": True})
-    rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, N], 
dtype="float16")
-    rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [T.int64(1), N], 
dtype="float16")
-    T_add = T.match_buffer(var_T_add, [batch, M, N], dtype="float16")
-    # body
-    # with T.block("root")
-    for i0, i1, i2 in T.grid(batch, M, N):
-        with T.block("T_add"):
-            ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
-            T.reads(rxplaceholder[ax0, ax1, ax2], rxplaceholder_1[T.int64(0), 
ax2])
-            T.writes(T_add[ax0, ax1, ax2])
-            T_add[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] + 
rxplaceholder_1[T.int64(0), ax2]
-
-
[email protected]_func
-def batch_bias_row_1d_fp16(
-    var_rxplaceholder: T.handle,
-    var_rxplaceholder_1: T.handle,
-    var_T_add: T.handle,
-    batch: T.int64,
-    M: T.int64,
-    N: T.int64,
-) -> None:
-    # function attr dict
-    T.func_attr({"tir.noalias": True})
-    rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, N], 
dtype="float16")
-    rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [N], dtype="float16")
-    T_add = T.match_buffer(var_T_add, [batch, M, N], dtype="float16")
-    # body
-    # with T.block("root")
-    for i0, i1, i2 in T.grid(batch, M, N):
-        with T.block("T_add"):
-            ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
-            T.reads(rxplaceholder[ax0, ax1, ax2], rxplaceholder_1[ax2])
-            T.writes(T_add[ax0, ax1, ax2])
-            T_add[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] + 
rxplaceholder_1[ax2]
-
-
[email protected]_func
-def relu_fp16(var_rxplaceholder: T.handle, var_compute: T.handle, M: T.int64, 
N: T.int64) -> None:
-    # function attr dict
-    T.func_attr({"tir.noalias": True})
-    rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16")
-    compute = T.match_buffer(var_compute, [M, N], dtype="float16")
-    # body
-    # with T.block("root")
-    for i0, i1 in T.grid(M, N):
-        with T.block("compute"):
-            i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
-            T.reads(rxplaceholder[i0_1, i1_1])
-            T.writes(compute[i0_1, i1_1])
-            compute[i0_1, i1_1] = T.max(rxplaceholder[i0_1, i1_1], 
T.float16(0))
-
-
[email protected]_func
-def batch_matmul_rrr_2d_fp16(
-    var_rxplaceholder: T.handle,
-    var_rxplaceholder_1: T.handle,
-    var_matmul: T.handle,
-    batch: T.int64,
-    M: T.int64,
-    N: T.int64,
-    K: T.int64,
-) -> None:
-    # function attr dict
-    T.func_attr({"tir.noalias": True})
-    rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, K], 
dtype="float16")
-    rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [K, N], 
dtype="float16")
-    matmul = T.match_buffer(var_matmul, [batch, M, N], dtype="float16")
-    # body
-    # with T.block("root")
-    for i0, i1, i2, i3 in T.grid(batch, M, N, K):
-        with T.block("matmul"):
-            i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3])
-            T.reads(rxplaceholder[i0_1, i1_1, k], rxplaceholder_1[k, i2_1])
-            T.writes(matmul[i0_1, i1_1, i2_1])
-            with T.init():
-                matmul[i0_1, i1_1, i2_1] = T.float16(0)
-            matmul[i0_1, i1_1, i2_1] = (
-                matmul[i0_1, i1_1, i2_1] + rxplaceholder[i0_1, i1_1, k] * 
rxplaceholder_1[k, i2_1]
-            )
-
-
[email protected]_func
-def batch_matmul_rrr_3d_fp16(
-    var_rxplaceholder: T.handle,
-    var_rxplaceholder_1: T.handle,
-    var_matmul: T.handle,
-    batch: T.int64,
-    M: T.int64,
-    N: T.int64,
-    K: T.int64,
-) -> None:
-    # function attr dict
-    T.func_attr({"tir.noalias": True})
-    rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, K], 
dtype="float16")
-    rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [batch, K, N], 
dtype="float16")
-    matmul = T.match_buffer(var_matmul, [batch, M, N], dtype="float16")
-    # body
-    # with T.block("root")
-    for i0, i1, i2, i3 in T.grid(batch, M, N, K):
-        with T.block("matmul"):
-            i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3])
-            T.reads(rxplaceholder[i0_1, i1_1, k], rxplaceholder_1[i0_1, k, 
i2_1])
-            T.writes(matmul[i0_1, i1_1, i2_1])
-            with T.init():
-                matmul[i0_1, i1_1, i2_1] = T.float16(0)
-            matmul[i0_1, i1_1, i2_1] = (
-                matmul[i0_1, i1_1, i2_1]
-                + rxplaceholder[i0_1, i1_1, k] * rxplaceholder_1[i0_1, k, i2_1]
-            )
-
-
[email protected]_func
-def copy_4d_fp16(
-    A_handle: T.handle,
-    B_handle: T.handle,
-    N: T.int64,
-    H: T.int64,
-    W: T.int64,
-    C: T.int64,
-) -> None:
-    A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
-    B = T.match_buffer(B_handle, [N, H, W, C], dtype="float16")
-    # body
-    # with T.block("root")
-    for n, h, w, c in T.grid(N, H, W, C):
-        with T.block("copy"):
-            vn, vh, vw, vc = T.axis.remap("SSSS", [n, h, w, c])
-            T.reads(A[vn, vh, vw, vc])
-            T.writes(B[vn, vh, vw, vc])
-            B[vn, vh, vw, vc] = A[vn, vh, vw, vc]
-
-
[email protected]_func
-def padding_2d_nhwc_fp16(
-    A_handle: T.handle,
-    B_handle: T.handle,
-    N: T.int64,
-    H: T.int64,
-    W: T.int64,
-    C: T.int64,
-    pH: T.int64,
-    pW: T.int64,
-    lH: T.int64,
-    lW: T.int64,
-    rH: T.int64,
-    rW: T.int64,
-) -> None:
-    A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
-    B = T.match_buffer(B_handle, [N, pH, pW, C], dtype="float16")
-    # body
-    # with T.block("root")
-    for v, v_1, v_2, v_3 in T.grid(N, pH, pW, C):
-        with T.block("copy"):
-            v_4, v_5, v_6, v_7 = T.axis.remap("SSSS", [v, v_1, v_2, v_3])
-            T.reads(A[v_4, v_5 - lH, v_6 - lW, v_7])
-            T.writes(B[v_4, v_5, v_6, v_7])
-            B[v_4, v_5, v_6, v_7] = T.if_then_else(
-                lH <= v_5 and v_5 < rH and lW <= v_6 and v_6 < rW,
-                A[v_4, v_5 - lH, v_6 - lW, v_7],
-                T.float16(0),
-                dtype="float16",
-            )
-
-
[email protected]_func
-def conv2d_nhwc_fp16(
-    A_handle: T.handle,
-    B_handle: T.handle,
-    out_handle: T.handle,
-    N: T.int64,
-    pH: T.int64,
-    pW: T.int64,
-    H: T.int64,
-    W: T.int64,
-    C: T.int64,
-    O: T.int64,
-    KH: T.int64,
-    KW: T.int64,
-    StrideH: T.int64,
-    StrideW: T.int64,
-    DilateH: T.int64,
-    DilateW: T.int64,
-) -> None:
-    A = T.match_buffer(A_handle, [N, pH, pW, C], dtype="float16")
-    B = T.match_buffer(B_handle, [O, KH, KW, C], dtype="float16")
-    out = T.match_buffer(out_handle, [N, H, W, O], dtype="float16")
-    # body
-    # with T.block("root")
-    for v, v_1, v_2, v_3, v_4, v_5, v_6 in T.grid(N, H, W, O, KH, KW, C):
-        with T.block("conv"):
-            v_7, v_8, v_9, v_10, v_11, v_12, v_13 = T.axis.remap(
-                "SSSSRRR", [v, v_1, v_2, v_3, v_4, v_5, v_6]
-            )
-            T.reads(
-                A[v_7, v_11 * DilateH + v_8 * StrideH, v_12 * DilateW + v_9 * 
StrideW, v_13],
-                B[v_10, v_11, v_12, v_13],
-            )
-            T.writes(out[v_7, v_8, v_9, v_10])
-            with T.init():
-                out[v_7, v_8, v_9, v_10] = T.float16(0)
-            out[v_7, v_8, v_9, v_10] = (
-                out[v_7, v_8, v_9, v_10]
-                + A[v_7, v_11 * DilateH + v_8 * StrideH, v_12 * DilateW + v_9 
* StrideW, v_13]
-                * B[v_10, v_11, v_12, v_13]
-            )
-
-
[email protected]_func
-def bias_add_nhwc_2d_fp16(
-    A_handle: T.handle,
-    B_handle: T.handle,
-    out_handle: T.handle,
-    N: T.int64,
-    H: T.int64,
-    W: T.int64,
-    C: T.int64,
-):
-    A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
-    B = T.match_buffer(B_handle, [1, 1, 1, C], dtype="float16")
-    out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16")
-    for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C):
-        with T.block("T_add"):
-            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, 
ax3])
-            T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), 
T.int64(0), v_ax3])
-            T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3])
-            out[v_ax0, v_ax1, v_ax2, v_ax3] = (
-                A[v_ax0, v_ax1, v_ax2, v_ax3] + B[v_ax0, T.int64(0), 
T.int64(0), v_ax3]
-            )
-
-
[email protected]_func
-def bias_add_nhwc_1d_fp16(
-    A_handle: T.handle,
-    B_handle: T.handle,
-    out_handle: T.handle,
-    N: T.int64,
-    H: T.int64,
-    W: T.int64,
-    C: T.int64,
-):
-    A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
-    B = T.match_buffer(B_handle, [1, 1, 1, C], dtype="float16")
-    out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16")
-    for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C):
-        with T.block("T_add"):
-            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, 
ax3])
-            T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[T.int64(0), T.int64(0), 
T.int64(0), v_ax3])
-            T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3])
-            out[v_ax0, v_ax1, v_ax2, v_ax3] = (
-                A[v_ax0, v_ax1, v_ax2, v_ax3] + B[T.int64(0), T.int64(0), 
T.int64(0), v_ax3]
-            )
-
-
[email protected]_func
-def elem_add_2d_fp16(
-    in0_handle: T.handle,
-    in1_handle: T.handle,
-    out_handle: T.handle,
-    N: T.int64,
-    M: T.int64,
-):
-    in0 = T.match_buffer(in0_handle, [N, M], dtype="float16")
-    in1 = T.match_buffer(in1_handle, [N, M], dtype="float16")
-    out = T.match_buffer(out_handle, [N, M], dtype="float16")
-    for ax0, ax1 in T.grid(N, M):
-        with T.block("T_add"):
-            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-            T.reads(in0[v_ax0, v_ax1], in1[v_ax0, v_ax1])
-            T.writes(out[v_ax0, v_ax1])
-            out[v_ax0, v_ax1] = in0[v_ax0, v_ax1] + in1[v_ax0, v_ax1]
-
-
[email protected]_func
-def elem_add_3d_fp16(
-    in0_handle: T.handle,
-    in1_handle: T.handle,
-    out_handle: T.handle,
-    B: T.int64,
-    N: T.int64,
-    M: T.int64,
-):
-    in0 = T.match_buffer(in0_handle, [B, N, M], dtype="float16")
-    in1 = T.match_buffer(in1_handle, [B, N, M], dtype="float16")
-    out = T.match_buffer(out_handle, [B, N, M], dtype="float16")
-    for ax0, ax1, ax2 in T.grid(B, N, M):
-        with T.block("T_add"):
-            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-            T.reads(in0[v_ax0, v_ax1, v_ax2], in1[v_ax0, v_ax1, v_ax2])
-            T.writes(out[v_ax0, v_ax1, v_ax2])
-            out[v_ax0, v_ax1, v_ax2] = in0[v_ax0, v_ax1, v_ax2] + in1[v_ax0, 
v_ax1, v_ax2]
-
-
[email protected]_func
-def elem_add_4d_fp16(
-    A_handle: T.handle,
-    B_handle: T.handle,
-    out_handle: T.handle,
-    N: T.int64,
-    H: T.int64,
-    W: T.int64,
-    C: T.int64,
-):
-    A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
-    B = T.match_buffer(B_handle, [N, H, W, C], dtype="float16")
-    out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16")
-    for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C):
-        with T.block("T_add"):
-            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, 
ax3])
-            T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, v_ax1, v_ax2, 
v_ax3])
-            T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3])
-            out[v_ax0, v_ax1, v_ax2, v_ax3] = (
-                A[v_ax0, v_ax1, v_ax2, v_ax3] + B[v_ax0, v_ax1, v_ax2, v_ax3]
-            )
-
-
[email protected]_func
-def scalar_mul_3d_fp16(
-    inp0_handle: T.handle,
-    out_handle: T.handle,
-    D1: T.int64,
-    D2: T.int64,
-    D3: T.int64,
-    scalar: T.float16,
-):
-    inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16")
-    out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16")
-    for ax0, ax1, ax2 in T.grid(D1, D2, D3):
-        with T.block("T_mul"):
-            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-            T.reads(inp0[v_ax0, v_ax1, v_ax2])
-            T.writes(out[v_ax0, v_ax1, v_ax2])
-            out[v_ax0, v_ax1, v_ax2] = inp0[v_ax0, v_ax1, v_ax2] * scalar
-
-
[email protected]_func
-def erf_3d_fp32(
-    inp0_handle: T.handle,
-    out_handle: T.handle,
-    D1: T.int64,
-    D2: T.int64,
-    D3: T.int64,
-):
-    inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float32")
-    out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float32")
-    for ax0, ax1, ax2 in T.grid(D1, D2, D3):
-        with T.block("T_erf"):
-            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-            T.reads(inp0[v_ax0, v_ax1, v_ax2])
-            T.writes(out[v_ax0, v_ax1, v_ax2])
-            out[v_ax0, v_ax1, v_ax2] = T.erf(inp0[v_ax0, v_ax1, v_ax2])
-
-
[email protected]_func
-def scalar_add_3d_fp16(
-    inp0_handle: T.handle,
-    out_handle: T.handle,
-    D1: T.int64,
-    D2: T.int64,
-    D3: T.int64,
-    scalar: T.float16,
-):
-    inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16")
-    out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16")
-    for ax0, ax1, ax2 in T.grid(D1, D2, D3):
-        with T.block("T_add"):
-            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-            T.reads(inp0[v_ax0, v_ax1, v_ax2])
-            T.writes(out[v_ax0, v_ax1, v_ax2])
-            out[v_ax0, v_ax1, v_ax2] = scalar + inp0[v_ax0, v_ax1, v_ax2]
-
-
[email protected]_func
-def elem_mul_3d_fp16(
-    inp0_handle: T.handle,
-    inp1_handle: T.handle,
-    out_handle: T.handle,
-    D1: T.int64,
-    D2: T.int64,
-    D3: T.int64,
-):
-    inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16")
-    inp1 = T.match_buffer(inp1_handle, [D1, D2, D3], dtype="float16")
-    out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16")
-    for ax0, ax1, ax2 in T.grid(D1, D2, D3):
-        with T.block("T_mul"):
-            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-            T.reads(inp0[v_ax0, v_ax1, v_ax2], inp1[v_ax0, v_ax1, v_ax2])
-            T.writes(out[v_ax0, v_ax1, v_ax2])
-            out[v_ax0, v_ax1, v_ax2] = inp0[v_ax0, v_ax1, v_ax2] * inp1[v_ax0, 
v_ax1, v_ax2]
-
-
[email protected]_func
-def cast_3d_fp16(
-    inp0_handle: T.handle,
-    out_handle: T.handle,
-    D1: T.int64,
-    D2: T.int64,
-    D3: T.int64,
-):
-    inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float32")
-    out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16")
-    for ax0, ax1, ax2 in T.grid(D1, D2, D3):
-        with T.block("T_cast"):
-            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-            T.reads(inp0[v_ax0, v_ax1, v_ax2])
-            T.writes(out[v_ax0, v_ax1, v_ax2])
-            out[v_ax0, v_ax1, v_ax2] = T.Cast("float16", inp0[v_ax0, v_ax1, 
v_ax2])
-
-
[email protected]_func
-def cast_3d_fp32(
-    inp0_handle: T.handle,
-    out_handle: T.handle,
-    D1: T.int64,
-    D2: T.int64,
-    D3: T.int64,
-):
-    inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16")
-    out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float32")
-    for ax0, ax1, ax2 in T.grid(D1, D2, D3):
-        with T.block("T_cast"):
-            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-            T.reads(inp0[v_ax0, v_ax1, v_ax2])
-            T.writes(out[v_ax0, v_ax1, v_ax2])
-            out[v_ax0, v_ax1, v_ax2] = T.Cast("float32", inp0[v_ax0, v_ax1, 
v_ax2])
-
-
-def get_tir_pattern() -> List[tvm.tir.PrimFunc]:
-    """Get the tir patterns for backend dispatch."""
-    return [
-        matmul_rrr_fp16,
-        bias_row_2d_fp16,
-        bias_row_1d_fp16,
-        batch_bias_row_2d_fp16,
-        batch_bias_row_1d_fp16,
-        relu_fp16,
-        erf_3d_fp32,
-        batch_matmul_rrr_2d_fp16,
-        batch_matmul_rrr_3d_fp16,
-        copy_4d_fp16,
-        padding_2d_nhwc_fp16,
-        conv2d_nhwc_fp16,
-        bias_add_nhwc_2d_fp16,
-        bias_add_nhwc_1d_fp16,
-        elem_add_2d_fp16,
-        elem_add_3d_fp16,
-        elem_add_4d_fp16,
-        elem_mul_3d_fp16,
-        scalar_add_3d_fp16,
-        scalar_mul_3d_fp16,
-        cast_3d_fp16,
-        cast_3d_fp32,
-    ]
diff --git a/tests/python/relax/test_codegen_tir_cutlass.py 
b/tests/python/relax/test_codegen_tir_cutlass.py
deleted file mode 100644
index 9670f15986..0000000000
--- a/tests/python/relax/test_codegen_tir_cutlass.py
+++ /dev/null
@@ -1,702 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from __future__ import annotations
-import tempfile
-
-from tvm import relax, runtime
-import tvm
-import tvm.testing
-from tvm import relax
-import scipy
-from scipy.special import erf
-import numpy as np
-from tvm.target import Target
-from tvm.relax.vm_build import build as relax_build
-from tvm.script.ir_builder import relax as R
-from tvm.script.ir_builder import ir as I
-from tvm.script.ir_builder import tir as T
-from tvm.script.ir_builder import IRBuilder
-
-from tvm.relax.backend_tir import get_tir_pattern
-from tvm.relax.backend_tir.contrib.cutlass import cutlass_fcodegen, 
compile_options
-
-A_TYPE = "float16"
-B_TYPE = "float16"
-C_TYPE = "float16"
-
-target = Target("cuda")
-
-
-def f_run(rt_mod: runtime.Module, device: runtime.ndarray.Device, *input):
-    vm = relax.vm.VirtualMachine(rt_mod=rt_mod, device=device)
-    return vm["main"](*input)
-
-
-def build(mod):
-    mod = relax.transform.LegalizeOps()(mod)
-    mod = relax.transform.AnnotateTIROpPattern()(mod)
-    mod = relax.transform.FuseOps()(mod)
-    mod = relax.transform.FuseTIR()(mod)
-    mod = relax.transform.SplitCallTIRByPattern(get_tir_pattern(), 
cutlass_fcodegen())(mod)
-    mod = relax.transform.DeadCodeElimination()(mod)
-    executable = tvm.compile(mod, target)
-    return executable.jit(**compile_options(target))
-
-
-def build_and_run_reference(mod, inputs_np):
-    dev = tvm.device("llvm", 0)
-    ex = tvm.compile(mod, "llvm")
-    vm = relax.VirtualMachine(ex, dev)
-    f = vm["main"]
-    inputs = [tvm.nd.array(inp, dev) for inp in inputs_np]
-    return f(*inputs).numpy()
-
-
-def constructGEMM(M, N, K):
-    with IRBuilder() as ib:  # pylint: disable=invalid-name
-        with I.ir_module() as frame:
-            with R.function():
-                R.func_name("main")
-                A = R.arg(
-                    "A", relax.TensorStructInfo((M, K), A_TYPE)
-                )  # pylint: disable=invalid-name
-                B = R.arg(
-                    "B", relax.TensorStructInfo((K, N), B_TYPE)
-                )  # pylint: disable=invalid-name
-                with R.dataflow() as df:
-                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
-                    R.output(C)
-                (C,) = df.output_vars
-                R.func_ret_value(C)
-    relax_mod = ib.get()
-    return relax_mod
-
-
[email protected]_cutlass
-def test_cutlass_dense():
-    m, n, k = 128, 64, 256
-    executable = build(constructGEMM(m, n, k))
-    dev = tvm.cuda()
-    A = np.random.randn(m, k).astype("float16")
-    B = np.random.randn(k, n).astype("float16")
-    A_tvm = tvm.nd.array(A, dev)
-    B_tvm = tvm.nd.array(B, dev)
-    result = f_run(executable, dev, A_tvm, B_tvm)
-    np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2)
-
-
-def constructGEMM_bias(M, N, K):
-    with IRBuilder() as ib:  # pylint: disable=invalid-name
-        with I.ir_module() as frame:
-            with R.function():
-                R.func_name("main")
-                A = R.arg(
-                    "A", relax.TensorStructInfo((M, K), A_TYPE)
-                )  # pylint: disable=invalid-name
-                B = R.arg(
-                    "B", relax.TensorStructInfo((K, N), B_TYPE)
-                )  # pylint: disable=invalid-name
-                bias = R.arg(
-                    "bias", relax.TensorStructInfo((1, N), A_TYPE)
-                )  # pylint: disable=invalid-name
-                with R.dataflow() as df:
-                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
-                    D = R.emit(R.add(C, bias))
-                    R.output(D)
-                (D,) = df.output_vars
-                R.func_ret_value(D)
-    relax_mod = ib.get()
-    return relax_mod
-
-
-def constructGEMM_bias2(M, N, K):
-    with IRBuilder() as ib:  # pylint: disable=invalid-name
-        with I.ir_module() as frame:
-            with R.function():
-                R.func_name("main")
-                A = R.arg(
-                    "A", relax.TensorStructInfo((M, K), A_TYPE)
-                )  # pylint: disable=invalid-name
-                B = R.arg(
-                    "B", relax.TensorStructInfo((K, N), B_TYPE)
-                )  # pylint: disable=invalid-name
-                bias = R.arg(
-                    "bias", relax.TensorStructInfo((N,), A_TYPE)
-                )  # pylint: disable=invalid-name
-                with R.dataflow() as df:
-                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
-                    D = R.emit(R.add(C, bias))
-                    R.output(D)
-                (D,) = df.output_vars
-                R.func_ret_value(D)
-    relax_mod = ib.get()
-    return relax_mod
-
-
[email protected]_cutlass
-def test_cutlass_dense_bias():
-    m, n, k = 128, 64, 256
-    executable = build(constructGEMM_bias(m, n, k))
-    dev = tvm.cuda()
-    A = np.random.randn(m, k).astype("float16")
-    B = np.random.randn(k, n).astype("float16")
-    bias = np.random.randn(1, n).astype("float16")
-    A_tvm = tvm.nd.array(A, dev)
-    B_tvm = tvm.nd.array(B, dev)
-    bias_tvm = tvm.nd.array(bias, dev)
-    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
-    np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, 
atol=5e-2)
-
-
[email protected]_cutlass
-def test_cutlass_dense_bias2():
-    m, n, k = 128, 64, 256
-    executable = build(constructGEMM_bias2(m, n, k))
-    dev = tvm.cuda()
-    A = np.random.randn(m, k).astype("float16")
-    B = np.random.randn(k, n).astype("float16")
-    bias = np.random.randn(n).astype("float16")
-    A_tvm = tvm.nd.array(A, dev)
-    B_tvm = tvm.nd.array(B, dev)
-    bias_tvm = tvm.nd.array(bias, dev)
-    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
-    np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, 
atol=5e-2)
-
-
-def constructGEMM_bias_relu(M, N, K):
-    with IRBuilder() as ib:  # pylint: disable=invalid-name
-        with I.ir_module() as frame:
-            with R.function():
-                R.func_name("main")
-                A = R.arg(
-                    "A", relax.TensorStructInfo((M, K), A_TYPE)
-                )  # pylint: disable=invalid-name
-                B = R.arg(
-                    "B", relax.TensorStructInfo((K, N), B_TYPE)
-                )  # pylint: disable=invalid-name
-                bias = R.arg(
-                    "bias", relax.TensorStructInfo((1, N), A_TYPE)
-                )  # pylint: disable=invalid-name
-                with R.dataflow() as df:
-                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
-                    D = R.emit(R.add(C, bias))
-                    E = R.emit(R.nn.relu(D))
-                    R.output(E)
-                (E,) = df.output_vars
-                R.func_ret_value(E)
-    relax_mod = ib.get()
-    return relax_mod
-
-
[email protected]_cutlass
-def test_cutlass_dense_bias_relu():
-    m, n, k = 128, 64, 256
-    executable = build(constructGEMM_bias_relu(m, n, k))
-    dev = tvm.cuda()
-    A = np.random.randn(m, k).astype("float16")
-    B = np.random.randn(k, n).astype("float16")
-    bias = np.random.randn(1, n).astype("float16")
-    A_tvm = tvm.nd.array(A, dev)
-    B_tvm = tvm.nd.array(B, dev)
-    bias_tvm = tvm.nd.array(bias, dev)
-    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
-    np.testing.assert_allclose(result.numpy(), np.maximum(A @ B + bias, 0), 
rtol=5e-2, atol=5e-2)
-
-
-def constructBatchGEMM(batch, M, N, K):
-    with IRBuilder() as ib:  # pylint: disable=invalid-name
-        with I.ir_module() as frame:
-            with R.function():
-                R.func_name("main")
-                A = R.arg(
-                    "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
-                )  # pylint: disable=invalid-name
-                B = R.arg(
-                    "B", relax.TensorStructInfo((K, N), B_TYPE)
-                )  # pylint: disable=invalid-name
-                with R.dataflow() as df:
-                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
-                    R.output(C)
-                (C,) = df.output_vars
-                R.func_ret_value(C)
-    relax_mod = ib.get()
-    return relax_mod
-
-
[email protected]_cutlass
-def test_cutlass_batch_dense():
-    b, m, n, k = 2, 128, 256, 64
-    executable = build(constructBatchGEMM(b, m, n, k))
-    dev = tvm.cuda()
-    A = np.random.randn(b, m, k).astype("float16")
-    B = np.random.randn(k, n).astype("float16")
-    A_tvm = tvm.nd.array(A, dev)
-    B_tvm = tvm.nd.array(B, dev)
-    result = f_run(executable, dev, A_tvm, B_tvm)
-    np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2)
-
-
-def constructBatchGEMM2(batch, M, N, K):
-    with IRBuilder() as ib:  # pylint: disable=invalid-name
-        with I.ir_module() as frame:
-            with R.function():
-                R.func_name("main")
-                A = R.arg(
-                    "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
-                )  # pylint: disable=invalid-name
-                B = R.arg(
-                    "B", relax.TensorStructInfo((batch, K, N), B_TYPE)
-                )  # pylint: disable=invalid-name
-                with R.dataflow() as df:
-                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
-                    R.output(C)
-                (C,) = df.output_vars
-                R.func_ret_value(C)
-    relax_mod = ib.get()
-    return relax_mod
-
-
[email protected]_cutlass
-def test_cutlass_batch_dense2():
-    b, m, n, k = 2, 128, 256, 64
-    executable = build(constructBatchGEMM2(b, m, n, k))
-    dev = tvm.cuda()
-    A = np.random.randn(b, m, k).astype("float16")
-    B = np.random.randn(b, k, n).astype("float16")
-    A_tvm = tvm.nd.array(A, dev)
-    B_tvm = tvm.nd.array(B, dev)
-    result = f_run(executable, dev, A_tvm, B_tvm)
-    np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2)
-
-
-def constructBatchGEMM_bias(batch, M, N, K):
-    with IRBuilder() as ib:  # pylint: disable=invalid-name
-        with I.ir_module() as frame:
-            with R.function():
-                R.func_name("main")
-                A = R.arg(
-                    "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
-                )  # pylint: disable=invalid-name
-                B = R.arg(
-                    "B", relax.TensorStructInfo((K, N), B_TYPE)
-                )  # pylint: disable=invalid-name
-                bias = R.arg(
-                    "bias", relax.TensorStructInfo((1, N), A_TYPE)
-                )  # pylint: disable=invalid-name
-                with R.dataflow() as df:
-                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
-                    D = R.emit(R.add(C, bias))
-                    R.output(D)
-                (D,) = df.output_vars
-                R.func_ret_value(D)
-    relax_mod = ib.get()
-    return relax_mod
-
-
[email protected]_cutlass
-def test_cutlass_batch_dense_bias():
-    b, m, n, k = 2, 128, 256, 64
-    executable = build(constructBatchGEMM_bias(b, m, n, k))
-    dev = tvm.cuda()
-    A = np.random.randn(b, m, k).astype("float16")
-    B = np.random.randn(k, n).astype("float16")
-    bias = np.random.randn(1, n).astype("float16")
-    A_tvm = tvm.nd.array(A, dev)
-    B_tvm = tvm.nd.array(B, dev)
-    bias_tvm = tvm.nd.array(bias, dev)
-    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
-    np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2)
-
-
-def constructBatchGEMM_bias2(batch, M, N, K):
-    with IRBuilder() as ib:  # pylint: disable=invalid-name
-        with I.ir_module() as frame:
-            with R.function():
-                R.func_name("main")
-                A = R.arg(
-                    "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
-                )  # pylint: disable=invalid-name
-                B = R.arg(
-                    "B", relax.TensorStructInfo((K, N), B_TYPE)
-                )  # pylint: disable=invalid-name
-                bias = R.arg(
-                    "bias", relax.TensorStructInfo((N,), A_TYPE)
-                )  # pylint: disable=invalid-name
-                with R.dataflow() as df:
-                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
-                    D = R.emit(R.add(C, bias))
-                    R.output(D)
-                (D,) = df.output_vars
-                R.func_ret_value(D)
-    relax_mod = ib.get()
-    return relax_mod
-
-
[email protected]_cutlass
-def test_cutlass_batch_dense_bias2():
-    b, m, n, k = 2, 128, 256, 64
-    executable = build(constructBatchGEMM_bias2(b, m, n, k))
-    dev = tvm.cuda()
-    A = np.random.randn(b, m, k).astype("float16")
-    B = np.random.randn(k, n).astype("float16")
-    bias = np.random.randn(n).astype("float16")
-    A_tvm = tvm.nd.array(A, dev)
-    B_tvm = tvm.nd.array(B, dev)
-    bias_tvm = tvm.nd.array(bias, dev)
-    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
-    np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2)
-
-
-def constructBatchGEMM_bias2_gelu(batch, M, N, K):
-    with IRBuilder() as ib:  # pylint: disable=invalid-name
-        with I.ir_module() as frame:
-            with R.function():
-                R.func_name("main")
-                A = R.arg(
-                    "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
-                )  # pylint: disable=invalid-name
-                B = R.arg(
-                    "B", relax.TensorStructInfo((K, N), B_TYPE)
-                )  # pylint: disable=invalid-name
-                bias = R.arg(
-                    "bias", relax.TensorStructInfo((N,), A_TYPE)
-                )  # pylint: disable=invalid-name
-                with R.dataflow() as df:
-                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
-                    D = R.emit(R.add(C, bias))
-                    E = R.emit(R.nn.gelu(D))
-                    R.output(E)
-                (E,) = df.output_vars
-                R.func_ret_value(E)
-    relax_mod = ib.get()
-    return relax_mod
-
-
[email protected]_cutlass
-def test_cutlass_batch_dense_bias2_gelu():
-    b, m, n, k = 2, 128, 64, 256
-    executable = build(constructBatchGEMM_bias2_gelu(b, m, n, k))
-    dev = tvm.cuda()
-    A = np.random.randn(b, m, k).astype("float16")
-    B = np.random.randn(k, n).astype("float16")
-    bias = np.random.randn(n).astype("float16")
-    A_tvm = tvm.nd.array(A, dev)
-    B_tvm = tvm.nd.array(B, dev)
-    bias_tvm = tvm.nd.array(bias, dev)
-    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
-    C = A @ B + bias
-    O = 0.5 * C * (1 + erf(C / np.sqrt(2)))
-    np.testing.assert_allclose(result.numpy(), O, rtol=5e-2, atol=5e-2)
-
-
-def constructBatchGEMM_bias2_mul(batch, M, N, K):
-    with IRBuilder() as ib:  # pylint: disable=invalid-name
-        with I.ir_module() as frame:
-            with R.function():
-                R.func_name("main")
-                A = R.arg(
-                    "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
-                )  # pylint: disable=invalid-name
-                B = R.arg(
-                    "B", relax.TensorStructInfo((K, N), B_TYPE)
-                )  # pylint: disable=invalid-name
-                bias = R.arg(
-                    "bias", relax.TensorStructInfo((N,), A_TYPE)
-                )  # pylint: disable=invalid-name
-                residual = R.arg("residual", relax.TensorStructInfo((batch, M, N), A_TYPE))
-                with R.dataflow() as df:
-                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
-                    D = R.emit(R.add(C, bias))
-                    E = R.emit(R.multiply(D, residual))
-                    R.output(E)
-                (E,) = df.output_vars
-                R.func_ret_value(E)
-    relax_mod = ib.get()
-    return relax_mod
-
-
[email protected]_cutlass
-def test_cutlass_batch_dense_bias2_mul():
-    b, m, n, k = 2, 128, 256, 64
-    executable = build(constructBatchGEMM_bias2_mul(b, m, n, k))
-    dev = tvm.cuda()
-    A = np.random.randn(b, m, k).astype("float16")
-    B = np.random.randn(k, n).astype("float16")
-    bias = np.random.randn(n).astype("float16")
-    residual = np.random.randn(b, m, n).astype("float16")
-    A_tvm = tvm.nd.array(A, dev)
-    B_tvm = tvm.nd.array(B, dev)
-    bias_tvm = tvm.nd.array(bias, dev)
-    residual_tvm = tvm.nd.array(residual, dev)
-    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm, residual_tvm)
-    np.testing.assert_allclose(result.numpy(), ((A @ B) + bias) * residual, rtol=5e-2, atol=5e-2)
-
-
-def constructBatchGEMM2_bias(batch, M, N, K):
-    with IRBuilder() as ib:  # pylint: disable=invalid-name
-        with I.ir_module() as frame:
-            with R.function():
-                R.func_name("main")
-                A = R.arg(
-                    "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
-                )  # pylint: disable=invalid-name
-                B = R.arg(
-                    "B", relax.TensorStructInfo((batch, K, N), B_TYPE)
-                )  # pylint: disable=invalid-name
-                bias = R.arg(
-                    "bias", relax.TensorStructInfo((1, N), A_TYPE)
-                )  # pylint: disable=invalid-name
-                with R.dataflow() as df:
-                    C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
-                    D = R.emit(R.add(C, bias))
-                    R.output(D)
-                (D,) = df.output_vars
-                R.func_ret_value(D)
-    relax_mod = ib.get()
-    return relax_mod
-
-
[email protected]_cutlass
-def test_cutlass_batch_dense2_bias():
-    b, m, n, k = 2, 128, 256, 64
-    executable = build(constructBatchGEMM2_bias(b, m, n, k))
-    dev = tvm.cuda()
-    A = np.random.randn(b, m, k).astype("float16")
-    B = np.random.randn(b, k, n).astype("float16")
-    bias = np.random.randn(1, n).astype("float16")
-    A_tvm = tvm.nd.array(A, dev)
-    B_tvm = tvm.nd.array(B, dev)
-    bias_tvm = tvm.nd.array(bias, dev)
-    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
-    np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2)
-
-
-def constructConv2D(N, C, H, W, KH, KW, O, strides, padding, dilation):
-    from tvm.script.ir_builder import IRBuilder
-    from tvm.script.ir_builder import ir as I
-    from tvm.script.ir_builder import relax as R
-    from tvm.script.ir_builder import tir as T
-
-    with IRBuilder() as ib:  # pylint: disable=invalid-name
-        with I.ir_module() as frame:
-            with R.function():
-                R.func_name("main")
-                x = R.arg(
-                    "x", relax.TensorStructInfo((N, H, W, C), A_TYPE)
-                )  # pylint: disable=invalid-name
-                w = R.arg(
-                    "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE)
-                )  # pylint: disable=invalid-name
-                with R.dataflow() as df:
-                    C = R.emit(
-                        R.nn.conv2d(
-                            x,
-                            w,
-                            strides=strides,
-                            padding=padding,
-                            dilation=dilation,
-                            groups=1,
-                            data_layout="NHWC",
-                            kernel_layout="OHWI",
-                            out_layout="NHWC",
-                            out_dtype=C_TYPE,
-                        )
-                    )
-                    R.output(C)
-                (C,) = df.output_vars
-                R.func_ret_value(C)
-    mod = ib.get()
-    return mod
-
-
[email protected]_cutlass
-def test_cutlass_conv2d():
-    n, c, h, w = 1, 3, 224, 224
-    kh, kw, o = 3, 3, 64
-    for strides in [(1, 1), (2, 2)]:
-        for padding in [(0, 0), (3, 3)]:
-            for dilation in [(1, 1), (4, 4)]:
-                mod = constructConv2D(n, c, h, w, kh, kw, o, strides, padding, dilation)
-                executable = build(mod)
-                dev = tvm.cuda()
-                np.random.seed(0)
-                A = np.random.randn(n, h, w, c).astype("float16")
-                B = np.random.randn(o, kh, kw, c).astype("float16")
-                A_tvm = tvm.nd.array(A, dev)
-                B_tvm = tvm.nd.array(B, dev)
-                result = f_run(executable, dev, A_tvm, B_tvm)
-                result_ref = build_and_run_reference(mod, [A, B])
-                np.testing.assert_allclose(
-                    result.numpy(),
-                    result_ref,
-                    rtol=5e-2,
-                    atol=5e-2,
-                )
-
-
-def constructConv2D_bias(N, C, H, W, KH, KW, O, strides, padding, dilation):
-    from tvm.script.ir_builder import IRBuilder
-    from tvm.script.ir_builder import ir as I
-    from tvm.script.ir_builder import relax as R
-    from tvm.script.ir_builder import tir as T
-
-    with IRBuilder() as ib:  # pylint: disable=invalid-name
-        with I.ir_module() as frame:
-            with R.function():
-                R.func_name("main")
-                x = R.arg(
-                    "x", relax.TensorStructInfo((N, H, W, C), A_TYPE)
-                )  # pylint: disable=invalid-name
-                w = R.arg(
-                    "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE)
-                )  # pylint: disable=invalid-name
-                bias = R.arg(
-                    "bias", relax.TensorStructInfo((1, 1, 1, O), A_TYPE)
-                )  # pylint: disable=invalid-name
-                with R.dataflow() as df:
-                    C = R.emit(
-                        R.nn.conv2d(
-                            x,
-                            w,
-                            strides=strides,
-                            padding=padding,
-                            dilation=dilation,
-                            groups=1,
-                            data_layout="NHWC",
-                            kernel_layout="OHWI",
-                            out_layout="NHWC",
-                            out_dtype=C_TYPE,
-                        )
-                    )
-                    D = R.emit(R.add(C, bias))
-                    R.output(D)
-                (D,) = df.output_vars
-                R.func_ret_value(D)
-    mod = ib.get()
-    return mod
-
-
[email protected]_cutlass
-def test_cutlass_conv2d_bias():
-    c, h, w = 3, 224, 224
-    kh, kw, o = 3, 3, 64
-    for n in [1, 2]:
-        for strides in [(1, 1), (2, 2)]:
-            for padding in [(0, 0), (3, 3)]:
-                for dilation in [(1, 1), (4, 4)]:
-                    mod = constructConv2D_bias(n, c, h, w, kh, kw, o, strides, padding, dilation)
-                    executable = build(mod)
-                    dev = tvm.cuda()
-                    np.random.seed(0)
-                    A = np.random.randn(n, h, w, c).astype("float16")
-                    B = np.random.randn(o, kh, kw, c).astype("float16")
-                    bias = np.random.randn(1, 1, 1, o).astype("float16")
-                    A_tvm = tvm.nd.array(A, dev)
-                    B_tvm = tvm.nd.array(B, dev)
-                    bias_tvm = tvm.nd.array(bias, dev)
-                    result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
-                    result_ref = build_and_run_reference(mod, [A, B, bias])
-                    np.testing.assert_allclose(
-                        result.numpy(),
-                        result_ref,
-                        rtol=5e-2,
-                        atol=5e-2,
-                    )
-
-
-def constructConv2D_bias_add(N, C, H, W, KH, KW, O, OH, OW, strides, padding, dilation):
-    from tvm.script.ir_builder import IRBuilder
-    from tvm.script.ir_builder import ir as I
-    from tvm.script.ir_builder import relax as R
-    from tvm.script.ir_builder import tir as T
-
-    with IRBuilder() as ib:  # pylint: disable=invalid-name
-        with I.ir_module() as frame:
-            with R.function():
-                R.func_name("main")
-                x = R.arg(
-                    "x", relax.TensorStructInfo((N, H, W, C), A_TYPE)
-                )  # pylint: disable=invalid-name
-                w = R.arg(
-                    "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE)
-                )  # pylint: disable=invalid-name
-                bias = R.arg(
-                    "bias", relax.TensorStructInfo((1, 1, 1, O), A_TYPE)
-                )  # pylint: disable=invalid-name
-                res = R.arg(
-                    "res", relax.TensorStructInfo((N, OH, OW, O), A_TYPE)
-                )  # pylint: disable=invalid-name
-                with R.dataflow() as df:
-                    C = R.emit(
-                        R.nn.conv2d(
-                            x,
-                            w,
-                            strides=strides,
-                            padding=padding,
-                            dilation=dilation,
-                            groups=1,
-                            data_layout="NHWC",
-                            kernel_layout="OHWI",
-                            out_layout="NHWC",
-                            out_dtype=C_TYPE,
-                        )
-                    )
-                    D = R.emit(R.add(C, bias))
-                    E = R.emit(R.add(D, res))
-                    R.output(E)
-                (E,) = df.output_vars
-                R.func_ret_value(E)
-    mod = ib.get()
-    return mod
-
-
[email protected]_cutlass
-def test_cutlass_conv2d_bias_add():
-    n, c, h, w = 2, 3, 224, 224
-    kh, kw, o = 3, 3, 64
-    for strides in [(1, 1), (2, 2)]:
-        for padding in [(0, 0), (3, 3)]:
-            for dilation in [(1, 1), (4, 4)]:
-                oh = (h + 2 * padding[0] - dilation[0] * (kh - 1) - 1) // strides[0] + 1
-                ow = (w + 2 * padding[1] - dilation[1] * (kw - 1) - 1) // strides[1] + 1
-                mod = constructConv2D_bias_add(
-                    n, c, h, w, kh, kw, o, oh, ow, strides, padding, dilation
-                )
-                executable = build(mod)
-                dev = tvm.cuda()
-                np.random.seed(0)
-                A = np.random.randn(n, h, w, c).astype("float16")
-                B = np.random.randn(o, kh, kw, c).astype("float16")
-                bias = np.random.randn(1, 1, 1, o).astype("float16")
-                res = np.random.randn(n, oh, ow, o).astype("float16")
-                A_tvm = tvm.nd.array(A, dev)
-                B_tvm = tvm.nd.array(B, dev)
-                bias_tvm = tvm.nd.array(bias, dev)
-                res_tvm = tvm.nd.array(res, dev)
-                result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm, res_tvm)
-                result_ref = build_and_run_reference(mod, [A, B, bias, res])
-                np.testing.assert_allclose(
-                    result.numpy(),
-                    result_ref,
-                    rtol=5e-2,
-                    atol=5e-2,
-                )
-
-
-if __name__ == "__main__":
-    tvm.testing.main()

Reply via email to