This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 046c1ba996 [Relax] Move TIR backend to gpu_generic (#17749)
046c1ba996 is described below
commit 046c1ba996ec0382ef65f450cfcc3e0475dd035b
Author: Siyuan Feng <[email protected]>
AuthorDate: Sun Mar 16 21:41:22 2025 +0800
[Relax] Move TIR backend to gpu_generic (#17749)
* [Relax] Move TIR backend to gpu_generic
This PR moves the TIR backend code to the gpu_generic backend directory
structure.
The main changes include:
- Relocating TIR implementation files from relax/backend_tir to
relax/backend/gpu_generic/tir
- Updating import paths in dispatch_sampling.py and dispatch_sort_scan.py
- Removing unused files (pattern.py, cutlass.py and related tests)
* remove tir folder
---
python/tvm/relax/backend/dispatch_sampling.py | 2 +-
python/tvm/relax/backend/dispatch_sort_scan.py | 2 +-
python/tvm/relax/backend/gpu_generic/__init__.py | 4 +-
.../{backend_tir => backend/gpu_generic}/cumsum.py | 0
.../gpu_generic}/sampling.py | 0
python/tvm/relax/backend_tir/__init__.py | 22 -
python/tvm/relax/backend_tir/contrib/__init__.py | 20 -
python/tvm/relax/backend_tir/contrib/cutlass.py | 720 ---------------------
python/tvm/relax/backend_tir/pattern.py | 576 -----------------
tests/python/relax/test_codegen_tir_cutlass.py | 702 --------------------
10 files changed, 5 insertions(+), 2043 deletions(-)
diff --git a/python/tvm/relax/backend/dispatch_sampling.py
b/python/tvm/relax/backend/dispatch_sampling.py
index 68d162fdf1..528529c723 100644
--- a/python/tvm/relax/backend/dispatch_sampling.py
+++ b/python/tvm/relax/backend/dispatch_sampling.py
@@ -36,7 +36,7 @@ class SamplingDispatcher(BackendDispatcher):
return super().visit_call_(call)
if call.op.name == "relax.multinomial_from_uniform":
- from tvm.relax.backend_tir import ( # pylint:
disable=import-outside-toplevel
+ from tvm.relax.backend.gpu_generic import ( # pylint:
disable=import-outside-toplevel
generic_get_sample_index,
gpu_multinomial_from_uniform,
)
diff --git a/python/tvm/relax/backend/dispatch_sort_scan.py
b/python/tvm/relax/backend/dispatch_sort_scan.py
index b5a94619c2..9f7cbaee9a 100644
--- a/python/tvm/relax/backend/dispatch_sort_scan.py
+++ b/python/tvm/relax/backend/dispatch_sort_scan.py
@@ -141,7 +141,7 @@ class SortScanDispatcher(BackendDispatcher):
and call.op.name == "relax.cumsum"
and call.attrs.exclusive == 0
):
- from tvm.relax.backend_tir import ( # pylint:
disable=import-outside-toplevel
+ from tvm.relax.backend.gpu_generic import ( # pylint:
disable=import-outside-toplevel
gpu_2d_continuous_cumsum,
)
diff --git a/python/tvm/relax/backend/gpu_generic/__init__.py
b/python/tvm/relax/backend/gpu_generic/__init__.py
index ea2d2a2afb..d7c316d28c 100644
--- a/python/tvm/relax/backend/gpu_generic/__init__.py
+++ b/python/tvm/relax/backend/gpu_generic/__init__.py
@@ -15,10 +15,12 @@
# specific language governing permissions and limitations
# under the License.
"""The Relax Metal backend compilation pipeline and other passes."""
+from .cumsum import gpu_2d_continuous_cumsum
from .pipeline import (
+ dataflow_lower_passes,
finalize_passes,
get_default_pipeline,
legalize_passes,
- dataflow_lower_passes,
library_dispatch_passes,
)
+from .sampling import generic_get_sample_index, gpu_multinomial_from_uniform
diff --git a/python/tvm/relax/backend_tir/cumsum.py
b/python/tvm/relax/backend/gpu_generic/cumsum.py
similarity index 100%
rename from python/tvm/relax/backend_tir/cumsum.py
rename to python/tvm/relax/backend/gpu_generic/cumsum.py
diff --git a/python/tvm/relax/backend_tir/sampling.py
b/python/tvm/relax/backend/gpu_generic/sampling.py
similarity index 100%
rename from python/tvm/relax/backend_tir/sampling.py
rename to python/tvm/relax/backend/gpu_generic/sampling.py
diff --git a/python/tvm/relax/backend_tir/__init__.py
b/python/tvm/relax/backend_tir/__init__.py
deleted file mode 100644
index b64bdcda6b..0000000000
--- a/python/tvm/relax/backend_tir/__init__.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Relax backends, tir based"""
-
-from . import contrib
-from .cumsum import gpu_2d_continuous_cumsum
-from .pattern import get_tir_pattern
-from .sampling import gpu_multinomial_from_uniform, generic_get_sample_index
diff --git a/python/tvm/relax/backend_tir/contrib/__init__.py
b/python/tvm/relax/backend_tir/contrib/__init__.py
deleted file mode 100644
index 9274f22374..0000000000
--- a/python/tvm/relax/backend_tir/contrib/__init__.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""External backend codegen modules for Relax, tir based."""
-
-from .cutlass import cutlass_fcodegen
diff --git a/python/tvm/relax/backend_tir/contrib/cutlass.py
b/python/tvm/relax/backend_tir/contrib/cutlass.py
deleted file mode 100644
index 0dbe31c468..0000000000
--- a/python/tvm/relax/backend_tir/contrib/cutlass.py
+++ /dev/null
@@ -1,720 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint:
disable=invalid-name,comparison-with-callable,unused-variable,missing-function-docstring
-"""codegen for cutlass"""
-import operator
-from functools import reduce
-from typing import List, Dict, Any
-
-from tvm.contrib.cutlass.build import _get_cutlass_path,
_get_cutlass_compile_options
-from tvm.contrib.nvcc import get_target_compute_version
-from tvm.contrib.cutlass.library import LayoutType, ConvKind
-from tvm.contrib.cutlass.gen_tensor_op import instantiate_template
-from tvm.contrib.cutlass.gen_gemm import CutlassGemmProfiler
-from tvm.contrib.cutlass.gen_conv2d import CutlassConv2DProfiler
-from ..pattern import (
- MatchResult,
- matmul_rrr_fp16,
- bias_row_2d_fp16,
- bias_row_1d_fp16,
- batch_bias_row_2d_fp16,
- batch_bias_row_1d_fp16,
- relu_fp16,
- erf_3d_fp32,
- batch_matmul_rrr_2d_fp16,
- batch_matmul_rrr_3d_fp16,
- conv2d_nhwc_fp16,
- padding_2d_nhwc_fp16,
- copy_4d_fp16,
- bias_add_nhwc_2d_fp16,
- bias_add_nhwc_1d_fp16,
- elem_add_4d_fp16,
- elem_mul_3d_fp16,
- scalar_add_3d_fp16,
- scalar_mul_3d_fp16,
- cast_3d_fp16,
- cast_3d_fp32,
-)
-
-#### helper functions ####
-# list representing the anchor ops
-# in the future more layouts/dtypes can be supported
-MATMUL_LIST = [matmul_rrr_fp16]
-MATMUL_BIAS_LIST = [bias_row_2d_fp16, bias_row_1d_fp16]
-BATCH_MATMUL_LIST = [batch_matmul_rrr_2d_fp16, batch_matmul_rrr_3d_fp16]
-BATCH_MATMUL_BIAS_LIST = [batch_bias_row_2d_fp16, batch_bias_row_1d_fp16]
-CONV2D_LIST = [conv2d_nhwc_fp16]
-
-# attributes for anchor ops used in code generation
-OP_PATTERN_ATTR_LIST = {
- matmul_rrr_fp16: {
- "arg0_dtype": "float16",
- "arg1_dtype": "float16",
- "ret_dtype": "float16",
- },
- batch_matmul_rrr_2d_fp16: {
- "arg0_dtype": "float16",
- "arg1_dtype": "float16",
- "ret_dtype": "float16",
- },
- batch_matmul_rrr_3d_fp16: {
- "arg0_dtype": "float16",
- "arg1_dtype": "float16",
- "ret_dtype": "float16",
- },
- conv2d_nhwc_fp16: {
- "arg0_dtype": "float16",
- "arg1_dtype": "float16",
- "ret_dtype": "float16",
- # in the future we can add layout here
- },
-}
-
-
-def _get_cutlass_code(attr):
- pattern = attr["op_type"]
- if pattern.startswith("cutlass.matmul"):
- return cutlass_codegen_gemm(attr)
- elif pattern.startswith("cutlass.conv2d"):
- return cutlass_codegen_conv2d(attr)
- else:
- raise ValueError("op not supported")
-
-
-def _final_code(code, headers, func_args):
- res = ""
- res += "#define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h>\n"
- res += "#include <tvm/runtime/c_runtime_api.h>\n"
- res += "#include <tvm/runtime/packed_func.h>\n"
- res += "#include <dlpack/dlpack.h>\n"
- res += "#include <cuda_fp16.h>\n"
- res += "#include <cutlass/cutlass.h>\n"
- res += "#include <cutlass/coord.h>\n"
- res += "#include <cutlass/tensor_ref.h>\n"
- res += "#include <cutlass/util/host_tensor.h>\n"
-
- for header in headers:
- res += "#include <" + header + ">\n"
- res += "namespace {\n"
- res += "using namespace tvm;\n"
- res += "using namespace tvm::runtime;\n"
- res += "void _cutlass_kernel("
- for arg in func_args:
- res += "NDArray " + arg + ", "
- res += "NDArray out0) {"
- res += code
- res += "}\n"
- res += "} // namespace\n"
- res += "TVM_DLL_EXPORT_TYPED_FUNC({global_symbol}, _cutlass_kernel);\n"
- return res
-
-
-#### cutlass patterns ####
-def matmul_bias_relu(match_results, attr, get_code=True):
- if len(match_results) < 3:
- return None
- attr = matmul_bias(match_results[:2], attr, get_code=False)
- if attr is None or match_results[2].pattern != relu_fp16:
- return None
- m_bias, n_bias = match_results[1].symbol_values
- m_relu, n_relu = match_results[2].symbol_values
- A_bias, B_bias, C_bias = match_results[1].matched_buffers
- A_relu, B_relu = match_results[2].matched_buffers
- if m_bias == m_relu and n_bias == n_relu and C_bias == A_relu:
- attr["op_type"] = "cutlass.matmul_bias_relu"
- return [_get_cutlass_code(attr=attr), 3, attr["args"]] if get_code
else attr
- return None
-
-
-def matmul_bias(match_results, attr, get_code=True):
- if len(match_results) < 2:
- return None
- attr = matmul(match_results[:1], attr, get_code=False)
- if attr is None or match_results[1].pattern not in MATMUL_BIAS_LIST:
- return None
- m_matmul, n_matmul, k_matmul = match_results[0].symbol_values
- m_bias, n_bias = match_results[1].symbol_values
- A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers
- A_bias, B_bias, C_bias = match_results[1].matched_buffers
- if m_matmul == m_bias and n_matmul == n_bias and C_matmul == A_bias:
- attr["op_type"] = "cutlass.matmul_bias"
- attr["bias_arg_idx"] = 2
- attr["args"].append(B_bias)
- return [_get_cutlass_code(attr=attr), 2, attr["args"]] if get_code
else attr
- return None
-
-
-def matmul(match_results, attr, get_code=True):
- if len(match_results) < 1:
- return None
- if match_results[0].pattern in MATMUL_LIST:
- # matmul
- attr["op_type"] = "cutlass.matmul"
- return [_get_cutlass_code(attr=attr), 1, attr["args"]] if get_code
else attr
- return None
-
-
-def batch_matmul_bias_gelu(match_results, attr, get_code=True):
- if len(match_results) < 9:
- return None
- attr = batch_matmul_bias(match_results[:2], attr, get_code=False) #
batch_matmul, batch_bias
- if (
- attr is None
- or match_results[2].pattern != scalar_mul_3d_fp16
- or match_results[3].pattern != cast_3d_fp32
- or match_results[4].pattern != erf_3d_fp32
- or match_results[5].pattern != cast_3d_fp16
- or match_results[6].pattern != scalar_mul_3d_fp16
- or match_results[7].pattern != scalar_add_3d_fp16
- or match_results[8].pattern != elem_mul_3d_fp16
- ):
- return None
-
- def shape_match_3d(shape1, shape2):
- if len(shape1) < 3 or len(shape2) < 3:
- return False
- return shape1[0] == shape2[0] and shape1[1] == shape2[1] and shape1[2]
== shape2[2]
-
- for i in range(1, 8):
- if not shape_match_3d(match_results[i].symbol_values, match_results[i
+ 1].symbol_values):
- return None
-
- if not (
- match_results[1].matched_buffers[-1] ==
match_results[2].matched_buffers[0]
- and match_results[2].matched_buffers[-1] ==
match_results[3].matched_buffers[0]
- and match_results[3].matched_buffers[-1] ==
match_results[4].matched_buffers[0]
- and match_results[4].matched_buffers[-1] ==
match_results[5].matched_buffers[0]
- and match_results[5].matched_buffers[-1] ==
match_results[6].matched_buffers[0]
- and match_results[6].matched_buffers[-1] ==
match_results[7].matched_buffers[0]
- and match_results[1].matched_buffers[-1] ==
match_results[8].matched_buffers[0]
- and match_results[7].matched_buffers[-1] ==
match_results[8].matched_buffers[1]
- ):
- return None
-
- if (
- abs(float(match_results[2].symbol_values[-1] - 0.5**0.5)) > 1e-5
- or abs(float(match_results[6].symbol_values[-1] - 0.5)) > 1e-5
- or abs(float(match_results[7].symbol_values[-1] - 0.5)) > 1e-5
- ):
- return None
-
- attr["op_type"] = "cutlass.matmul_bias_gelu"
- return [_get_cutlass_code(attr=attr), 9, attr["args"]] if get_code else
attr
-
-
-def batch_matmul_bias_residual_mul(match_results, attr, get_code=True):
- if len(match_results) < 3:
- return None
- attr = batch_matmul_bias(match_results[:2], attr, get_code=False) #
batch_matmul, batch_bias
- if attr is None or match_results[2].pattern != elem_mul_3d_fp16:
- return None
- (
- b_bias,
- m_bias,
- n_bias,
- ) = match_results[1].symbol_values
- (
- b_mul,
- m_mul,
- n_mul,
- ) = match_results[2].symbol_values
- A_bias, B_bias, C_bias = match_results[1].matched_buffers
- A_mul, B_mul, C_mul = match_results[2].matched_buffers
- if b_bias == b_mul and m_bias == m_mul and n_bias == n_mul and C_bias ==
A_mul:
- attr["op_type"] = "cutlass.matmul_bias_residual_multiply"
- attr["residual_arg_idx"] = 3
- return [_get_cutlass_code(attr=attr), 3, attr["args"]] if get_code
else attr
- return None
-
-
-def batch_matmul_bias(match_results, attr, get_code=True):
- if len(match_results) < 2:
- return None
- attr = batch_matmul(match_results[:1], attr, get_code=False)
- if attr is None or match_results[1].pattern not in BATCH_MATMUL_BIAS_LIST:
- return None
- (
- b_matmul,
- m_matmul,
- n_matmul,
- k_matmul,
- ) = match_results[0].symbol_values
- (
- b_bias,
- m_bias,
- n_bias,
- ) = match_results[1].symbol_values
- A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers
- A_bias, B_bias, C_bias = match_results[1].matched_buffers
- if b_matmul == b_bias and m_matmul == m_bias and n_matmul == n_bias and
C_matmul == A_bias:
- attr["op_type"] = "cutlass.matmul_bias"
- attr["bias_arg_idx"] = 2
- attr["args"].append(B_bias)
- return [_get_cutlass_code(attr=attr), 2, attr["args"]] if get_code
else attr
- return None
-
-
-def batch_matmul(match_results, attr, get_code=True):
- if len(match_results) < 1:
- return None
- if match_results[0].pattern in BATCH_MATMUL_LIST:
- attr["op_type"] = "cutlass.matmul"
- return [_get_cutlass_code(attr=attr), 1, attr["args"]] if get_code
else attr
- return None
-
-
-def conv2d_bias_residual_add(match_results, attr, get_code=True):
- if len(match_results) < 4:
- return None
- attr = conv2d_bias(match_results[:3], attr, get_code=False)
- if attr is None or match_results[3].pattern != elem_add_4d_fp16:
- return None
- N_bias, H_bias, W_bias, C_bias = match_results[2].symbol_values
- in1_bias, in2_bias, out_bias = match_results[2].matched_buffers
- N_add, H_add, W_add, C_add = match_results[3].symbol_values
- in1_add, in2_add, out_add = match_results[3].matched_buffers
- if (
- N_bias == N_add
- and H_bias == H_add
- and W_bias == W_add
- and C_bias == C_add
- and out_bias in [in1_add, in2_add]
- ):
- attr["op_type"] = "cutlass.conv2d_bias_residual_add"
- attr["residual_arg_idx"] = 3
- attr["args"].append(in2_add if out_bias == in1_add else in1_add)
- return [_get_cutlass_code(attr=attr), 4, attr["args"]] if get_code
else attr
- return None
-
-
-def conv2d_bias(match_results, attr, get_code=True):
- if len(match_results) < 3:
- return None
- attr = conv2d(match_results[:2], attr, get_code=False)
- if attr is None or (
- match_results[2].pattern not in [bias_add_nhwc_2d_fp16,
bias_add_nhwc_1d_fp16]
- ):
- return None
- (N_conv, pH_conv, pW_conv, H_conv, W_conv, C_conv, O_conv,) =
match_results[
- 1
- ].symbol_values[:7]
- A_pad_conv, B_conv, out_conv = match_results[1].matched_buffers
- N_bias, H_bias, W_bias, C_bias = match_results[2].symbol_values
- A_bias, B_bias, out_bias = match_results[2].matched_buffers
- if (
- N_bias == N_conv
- and H_bias == H_conv
- and W_bias == W_conv
- and C_bias == O_conv
- and out_conv == A_bias
- ):
- attr["op_type"] = "cutlass.conv2d_bias"
- attr["bias_arg_idx"] = 2
- attr["args"].append(B_bias)
- return [_get_cutlass_code(attr=attr), 3, attr["args"]] if get_code
else attr
- return None
-
-
-def conv2d(match_results, attr, get_code=True):
- if len(match_results) < 2:
- return None
- if (
- match_results[0].pattern in [padding_2d_nhwc_fp16, copy_4d_fp16]
- and match_results[1].pattern == conv2d_nhwc_fp16
- ):
- if match_results[0].pattern == padding_2d_nhwc_fp16:
- (
- N_pad,
- H_pad,
- W_pad,
- C_pad,
- pH_pad,
- pW_pad,
- lH_pad,
- lW_pad,
- rH_pad,
- rW_pad,
- ) = match_results[0].symbol_values
- else:
- (
- N_pad,
- H_pad,
- W_pad,
- C_pad,
- ) = match_results[0].symbol_values
- pH_pad = rH_pad = H_pad
- pW_pad = rW_pad = W_pad
- lH_pad = lW_pad = 0
- (
- N_conv,
- pH_conv,
- pW_conv,
- H_conv,
- W_conv,
- C_conv,
- O_conv,
- KH_conv,
- KW_conv,
- stride_h_conv,
- stride_w_conv,
- dilation_h_conv,
- dilation_w_conv,
- ) = match_results[1].symbol_values
- A, A_pad = match_results[0].matched_buffers
- A_pad_conv, B_conv, out_conv = match_results[1].matched_buffers
- if (
- N_pad == N_conv
- and pH_pad == pH_conv
- and pW_pad == pW_conv
- and C_pad == C_conv
- and A_pad == A_pad_conv
- ):
- if (
- lH_pad == pH_pad - rH_pad
- and lW_pad == pW_pad - rW_pad
- and lH_pad + H_pad == rH_pad
- and lW_pad + W_pad == rW_pad
- ):
- padding = (lH_pad, lW_pad)
- strides = (stride_h_conv, stride_w_conv)
- dilation = (dilation_h_conv, dilation_w_conv)
- attr["padding"] = padding
- attr["strides"] = strides
- attr["dilation"] = dilation
- attr["op_type"] = "cutlass.conv2d"
- return [_get_cutlass_code(attr=attr), 2, attr["args"]] if
get_code else attr
- return None
-
-
-### cutlass codegen functions ###
-def compile_options(target, threads=-1, use_fast_math=False):
- compute_version =
int("".join(get_target_compute_version(target).split(".")))
- kwargs = _get_cutlass_compile_options(compute_version, threads,
use_fast_math)
- kwargs["options"].remove("-c")
- return kwargs
-
-
-def cutlass_fcodegen(sm=80, bin_dir="./bin"):
- gemm_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), bin_dir)
- conv2d_profiler = CutlassConv2DProfiler(sm, _get_cutlass_path(), bin_dir)
-
- def cutlass_codegen_with_match_results(match_results: List[MatchResult]):
- """generate cutlass code with match results"""
- nonlocal gemm_profiler
- nonlocal conv2d_profiler
-
- assert len(match_results) > 0
-
- # add shape into attr
- if match_results[0].pattern in MATMUL_LIST:
- A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers
- attr: Dict[Any, Any] =
OP_PATTERN_ATTR_LIST[match_results[0].pattern]
- attr["args"] = [A_matmul, B_matmul]
- attr["arg0_shape"] = A_matmul.shape
- attr["arg1_shape"] = B_matmul.shape
- attr["ret_shape"] = C_matmul.shape
- attr["lhs_arg_idx"] = 0
- attr["rhs_arg_idx"] = 1
- elif match_results[0].pattern in BATCH_MATMUL_LIST:
- A_matmul, B_matmul, C_matmul = match_results[0].matched_buffers
- attr = OP_PATTERN_ATTR_LIST[match_results[0].pattern]
- attr["args"] = [A_matmul, B_matmul]
- attr["arg0_shape"] = A_matmul.shape
- attr["arg1_shape"] = B_matmul.shape
- attr["ret_shape"] = C_matmul.shape
- attr["lhs_arg_idx"] = 0
- attr["rhs_arg_idx"] = 1
- elif len(match_results) >= 1 and match_results[1].pattern in
CONV2D_LIST:
- A_input = match_results[0].matched_buffers[0]
- A_conv2d, B_conv2d, C_conv2d = match_results[1].matched_buffers
- attr = OP_PATTERN_ATTR_LIST[match_results[1].pattern]
- attr["args"] = [A_input, B_conv2d]
- attr["arg0_shape"] = A_input.shape
- attr["arg1_shape"] = B_conv2d.shape
- attr["ret_shape"] = C_conv2d.shape
- attr["lhs_arg_idx"] = 0
- attr["rhs_arg_idx"] = 1
- else:
- return ["", 0]
-
- # add profiler into attr
- attr["gemm_profiler"] = gemm_profiler
- attr["conv2d_profiler"] = conv2d_profiler
-
- cutlass_patterns = [
- # 9
- batch_matmul_bias_gelu,
- # 4
- conv2d_bias_residual_add,
- # 3
- batch_matmul_bias_residual_mul,
- matmul_bias_relu,
- conv2d_bias,
- # 2
- matmul_bias,
- batch_matmul_bias,
- conv2d,
- # 1
- matmul,
- batch_matmul,
- ]
- for pattern in cutlass_patterns:
- res = pattern(match_results, attr)
- if res is not None:
- return res
-
- return ["", 0]
-
- return cutlass_codegen_with_match_results
-
-
-def cutlass_codegen_gemm(attrs):
- """cutlass codegen for gemm"""
- gemm_profiler = attrs["gemm_profiler"]
- op_type = attrs["op_type"]
- lhs_shape = attrs["arg0_shape"]
- rhs_shape = attrs["arg1_shape"]
- MM = lhs_shape[-2]
- KK = lhs_shape[-1]
- if "transposed" in op_type:
- NN = rhs_shape[-2]
- ldb = "K"
- layout_b = LayoutType.ColumnMajor
- else:
- NN = rhs_shape[-1]
- ldb = "N"
- layout_b = LayoutType.RowMajor
-
- lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
- rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
- if lhs_batches == 1 and rhs_batches == 1:
- # Regular matmul
- is_batched = False
- batch_attrs = {}
- else:
- is_batched = True
- batch_attrs = {
- # If both lhs_batches and rhs_batches are greater than 1,
- # they must be equal. This is checked by
is_shape_valid_for_cutlass_matmul.
- "batch": lhs_batches if rhs_batches == 1 else rhs_batches,
- "batch_stride_A": 0 if lhs_batches == 1 else MM * KK,
- "batch_stride_B": 0 if rhs_batches == 1 else KK * NN,
- "batch_stride_C": MM * NN,
- }
- op_name, op_def, _ = gemm_profiler.profile(
- op_type,
- MM,
- NN,
- KK,
- attrs["ret_dtype"],
- attrs["arg0_dtype"],
- attrs["arg1_dtype"],
- False,
- batched=is_batched,
- find_first_valid=False,
- use_multiprocessing=True,
- layout_b=layout_b,
- )
- attrs["cutlass_op_name"] = op_name
- attrs["cutlass_op_def"] = op_def
- attrs["lda"] = "K"
- attrs["ldb"] = ldb
- attrs["ldc"] = "N"
- attrs.update(batch_attrs)
- del attrs["gemm_profiler"]
- del attrs["conv2d_profiler"]
-
- nargs = 2
- if "bias_arg_idx" in attrs:
- nargs += 1
- if "residual_arg_idx" in attrs:
- nargs += 1
- func_args = ["inp" + str(i) for i in range(nargs)]
-
- # A temporary solution to handle batch matmul residual cases
- # TODO(@bohan): remove this after initialize_template supports bmm residual
- if op_type in [
- "cutlass.matmul_bias_residual_multiply",
- ]:
-
- def _convert_dtype_str(dtype):
- if isinstance(dtype, list):
- arr = []
- for t in dtype:
- arr.append(_convert_dtype_str(t))
- return arr
- elif isinstance(dtype, str):
- if dtype == "float16":
- return "cutlass::half_t"
- elif dtype == "float32":
- return "float"
- raise ValueError("dtype not supported")
-
- typea, typeb, typec = _convert_dtype_str(
- [attrs["arg0_dtype"], attrs["arg1_dtype"], attrs["ret_dtype"]]
- )
-
- text = f"""
-#define CUTLASS_ENABLE_CUBLAS 1
-#define CUTLASS_NAMESPACE cutlass
-#define CUTLASS_ENABLE_TENSOR_CORE_MMA 1
-#define NDEBUG
-#include <cutlass/cutlass.h>
-#include <cutlass/tensor_ref.h>
-#include <cutlass/util/host_tensor.h>
-#include <cutlass/gemm/device/gemm.h>
-#include <cutlass/gemm/device/gemm_batched.h>
-#include <cutlass/layout/matrix.h>
-#include <cutlass/numeric_types.h>
-#include "cutlass/epilogue/thread/activation.h"
-#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
-#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
-#include <fstream>
-#include <iostream>
-#include <sstream>
-#include <vector>
-#define DMLC_USE_LOGGING_LIBRARY <tvm/runtime/logging.h>
-#include <tvm/runtime/logging.h>
-#include <tvm/runtime/ndarray.h>
-#include <tvm/runtime/packed_func.h>
-namespace {{
-using namespace tvm;
-using namespace tvm::runtime;
-void _BHGEMM(NDArray A, NDArray B, NDArray Bias, NDArray D, NDArray C) {{
- // A: [Batch, M, K], B: [1, K, N]/[K, N], Bias: [1, N]/[N], D: [Batch, M,
N], C: [Batch, M, N]
- CHECK_EQ(A->ndim, 3);
- int bdim = B->ndim;
- int bias_dim = Bias->ndim;
- CHECK_EQ(C->ndim, 3);
- CHECK_EQ(A->shape[2], B->shape[bdim - 2]);
- CHECK_EQ(Bias->shape[bias_dim - 1], B->shape[bdim - 1]);
- CHECK_EQ(D->ndim, 3);
- CHECK_EQ(D->shape[0], A->shape[0]);
- CHECK_EQ(D->shape[1], A->shape[1]);
- CHECK_EQ(D->shape[2], B->shape[bdim - 1]);
- CHECK_EQ(C->shape[0], A->shape[0]);
- CHECK_EQ(C->shape[1], A->shape[1]);
- CHECK_EQ(C->shape[2], B->shape[bdim - 1]);
- int64_t M = A->shape[0] * A->shape[1];
- int64_t N = B->shape[bdim - 1];
- int64_t K = A->shape[2];
- int64_t input_a_batch_stride = M * K;
- int64_t input_a_stride = K;
- int64_t input_a_offset = 0; // default to 0
- int64_t input_b_batch_stride = K * N;
- int64_t input_b_stride = N;
- int64_t input_b_offset = 0; // default to 0
- int64_t output_stride = N;
- int64_t output_offset = 0;
- int64_t a_size = 1;
- a_size *= A->shape[0];
- a_size *= A->shape[1];
- a_size *= A->shape[2];
-
- int64_t b_size = 1;
- b_size *= B->shape[bias_dim - 2];
- b_size *= B->shape[bias_dim - 1];
-
- int64_t c_size = 1;
- c_size *= C->shape[0];
- c_size *= C->shape[1];
- c_size *= C->shape[2];
-
- // Define the GEMM operation
- {op_def}
- using kernel = Operation_{op_name};
- using ElementComputeEpilogue = typename kernel::ElementAccumulator;
- typename kernel::Arguments arguments({{
- cutlass::gemm::GemmUniversalMode::kGemm, // GemmUniversalMode mode
- {{M, N, K}}, // GemmCoord problem_size
- 1, // int batch_count
- {{ElementComputeEpilogue(1), ElementComputeEpilogue(1)}}, // typename
EpilogueOutputOp::Params epilogue
- ({typea}*)(A->data) + input_a_offset, // void const * ptr_A
- ({typeb}*)(B->data) + input_b_offset, // void const * ptr_B
- ({typec}*)(D->data), // void const * ptr_C1
- ({typec}*)(C->data) + output_offset, // void * ptr_D
- ({typea}*)(Bias->data), // void * ptr_Vector
- nullptr, // void * ptr_Tensor
- input_a_batch_stride, // int64_t batch_stride_A
- input_b_batch_stride, // int64_t batch_stride_B
- 0, // int64_t batch_stride_C1
- 0, // int64_t batch_stride_D
- 0, // int64_t batch_stride_Vector
- 0, // int64_t batch_stride_Tensor
- input_a_stride, // typename LayoutA::Stride::Index lda
- input_b_stride, // typename LayoutB::Stride::Index ldb
- N, // typename LayoutC::Stride::Index ldc1
- output_stride, // typename LayoutC::Stride::Index ldd
- 0, // typename LayoutC::Stride::Index ldr
- 0, // typename LayoutC::Stride::Index ldt
- }});
- kernel gemm_op;
- size_t workspace_size = gemm_op.get_workspace_size(arguments);
- cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
- cutlass::Status status = gemm_op.can_implement(arguments);
- CHECK(status == cutlass::Status::kSuccess);
- status = gemm_op.initialize(arguments, workspace.get());
- CHECK(status == cutlass::Status::kSuccess);
- status = gemm_op();
- CHECK(status == cutlass::Status::kSuccess);
- return;
-}}
-}} // namespace
-TVM_DLL_EXPORT_TYPED_FUNC({{global_symbol}}, _BHGEMM);
- """
- return text
-
- code = instantiate_template(op_type, attrs, func_args)
- return _final_code(code.code, code.headers, func_args)
-
-
-def cutlass_codegen_conv2d(attrs):
- """cutlass codegen for conv2d"""
- # cutlass backend only supports nhwc for now
- conv2d_profiler = attrs["conv2d_profiler"]
- op_type = attrs["op_type"]
- conv_kind = ConvKind.Fprop
- op_name, op_def, _ = conv2d_profiler.profile(
- op_type=attrs["op_type"],
- d_shape=attrs["arg0_shape"],
- w_shape=attrs["arg1_shape"],
- padding=attrs["padding"],
- stride=attrs["strides"],
- dilation=attrs["dilation"],
- out_dtype=attrs["ret_dtype"],
- data_dtype=attrs["arg0_dtype"],
- weight_dtype=attrs["arg1_dtype"],
- use_3xtf32=False,
- conv_kind=conv_kind,
- split_k_slices=[1],
- profile_all_alignments=True,
- find_first_valid=False,
- use_multiprocessing=True,
- )
- attrs["cutlass_op_def"] = op_def
- attrs["cutlass_op_name"] = op_name
- del attrs["gemm_profiler"]
- del attrs["conv2d_profiler"]
-
- nargs = 2
- if "bias_arg_idx" in attrs:
- nargs += 1
- if "residual_arg_idx" in attrs:
- nargs += 1
- func_args = ["inp" + str(i) for i in range(nargs)]
- code = instantiate_template(op_type, attrs, func_args)
- return _final_code(code.code, code.headers, func_args)
diff --git a/python/tvm/relax/backend_tir/pattern.py
b/python/tvm/relax/backend_tir/pattern.py
deleted file mode 100644
index 10f7a3b162..0000000000
--- a/python/tvm/relax/backend_tir/pattern.py
+++ /dev/null
@@ -1,576 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=invalid-name,missing-function-docstring,chained-comparison
-"""TIR Patterns"""
-from typing import List
-
-import tvm
-from tvm.runtime import Object
-import tvm._ffi
-
-from tvm.script import tir as T
-
-
-@tvm._ffi.register_object("relax.MatchResult")
-class MatchResult(Object):
- """The match result of a TIR pattern."""
-
- def __init__(self, pattern, symbol_values, matched_buffers):
- self.__init_handle_by_constructor__(
- tvm._ffi.MatchResult, pattern, symbol_values, matched_buffers
- )
-
-
[email protected]_func
-def matmul_rrr_fp16(
- var_rxplaceholder: T.handle,
- var_rxplaceholder_1: T.handle,
- var_matmul: T.handle,
- M: T.int64,
- N: T.int64,
- K: T.int64,
-) -> None:
- # function attr dict
- T.func_attr({"tir.noalias": True})
- rxplaceholder = T.match_buffer(var_rxplaceholder, [M, K], dtype="float16")
- rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [K, N],
dtype="float16")
- matmul = T.match_buffer(var_matmul, [M, N], dtype="float16")
- # body
- # with T.block("root")
- for i0, i1, i2 in T.grid(M, N, K):
- with T.block("matmul"):
- i0_1, i1_1, k = T.axis.remap("SSR", [i0, i1, i2])
- T.reads(rxplaceholder[i0_1, k], rxplaceholder_1[k, i1_1])
- T.writes(matmul[i0_1, i1_1])
- with T.init():
- matmul[i0_1, i1_1] = T.float16(0)
- matmul[i0_1, i1_1] = (
- matmul[i0_1, i1_1] + rxplaceholder[i0_1, k] *
rxplaceholder_1[k, i1_1]
- )
-
-
[email protected]_func
-def bias_row_2d_fp16(
- var_rxplaceholder: T.handle,
- var_rxplaceholder_1: T.handle,
- var_T_add: T.handle,
- M: T.int64,
- N: T.int64,
-) -> None:
- # function attr dict
- T.func_attr({"tir.noalias": True})
- rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16")
- rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [T.int64(1), N],
dtype="float16")
- T_add = T.match_buffer(var_T_add, [M, N], dtype="float16")
- # body
- # with T.block("root")
- for i0, i1 in T.grid(M, N):
- with T.block("T_add"):
- ax0, ax1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[T.int64(0), ax1])
- T.writes(T_add[ax0, ax1])
- T_add[ax0, ax1] = rxplaceholder[ax0, ax1] +
rxplaceholder_1[T.int64(0), ax1]
-
-
[email protected]_func
-def bias_row_1d_fp16(
- var_rxplaceholder: T.handle,
- var_rxplaceholder_1: T.handle,
- var_T_add: T.handle,
- M: T.int64,
- N: T.int64,
-) -> None:
- # function attr dict
- T.func_attr({"tir.noalias": True})
- rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16")
- rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [N], dtype="float16")
- T_add = T.match_buffer(var_T_add, [M, N], dtype="float16")
- # body
- # with T.block("root")
- for i0, i1 in T.grid(M, N):
- with T.block("T_add"):
- ax0, ax1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax1])
- T.writes(T_add[ax0, ax1])
- T_add[ax0, ax1] = rxplaceholder[ax0, ax1] + rxplaceholder_1[ax1]
-
-
[email protected]_func
-def batch_bias_row_2d_fp16(
- var_rxplaceholder: T.handle,
- var_rxplaceholder_1: T.handle,
- var_T_add: T.handle,
- batch: T.int64,
- M: T.int64,
- N: T.int64,
-) -> None:
- # function attr dict
- T.func_attr({"tir.noalias": True})
- rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, N],
dtype="float16")
- rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [T.int64(1), N],
dtype="float16")
- T_add = T.match_buffer(var_T_add, [batch, M, N], dtype="float16")
- # body
- # with T.block("root")
- for i0, i1, i2 in T.grid(batch, M, N):
- with T.block("T_add"):
- ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
- T.reads(rxplaceholder[ax0, ax1, ax2], rxplaceholder_1[T.int64(0),
ax2])
- T.writes(T_add[ax0, ax1, ax2])
- T_add[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] +
rxplaceholder_1[T.int64(0), ax2]
-
-
[email protected]_func
-def batch_bias_row_1d_fp16(
- var_rxplaceholder: T.handle,
- var_rxplaceholder_1: T.handle,
- var_T_add: T.handle,
- batch: T.int64,
- M: T.int64,
- N: T.int64,
-) -> None:
- # function attr dict
- T.func_attr({"tir.noalias": True})
- rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, N],
dtype="float16")
- rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [N], dtype="float16")
- T_add = T.match_buffer(var_T_add, [batch, M, N], dtype="float16")
- # body
- # with T.block("root")
- for i0, i1, i2 in T.grid(batch, M, N):
- with T.block("T_add"):
- ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
- T.reads(rxplaceholder[ax0, ax1, ax2], rxplaceholder_1[ax2])
- T.writes(T_add[ax0, ax1, ax2])
- T_add[ax0, ax1, ax2] = rxplaceholder[ax0, ax1, ax2] +
rxplaceholder_1[ax2]
-
-
[email protected]_func
-def relu_fp16(var_rxplaceholder: T.handle, var_compute: T.handle, M: T.int64,
N: T.int64) -> None:
- # function attr dict
- T.func_attr({"tir.noalias": True})
- rxplaceholder = T.match_buffer(var_rxplaceholder, [M, N], dtype="float16")
- compute = T.match_buffer(var_compute, [M, N], dtype="float16")
- # body
- # with T.block("root")
- for i0, i1 in T.grid(M, N):
- with T.block("compute"):
- i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
- T.reads(rxplaceholder[i0_1, i1_1])
- T.writes(compute[i0_1, i1_1])
- compute[i0_1, i1_1] = T.max(rxplaceholder[i0_1, i1_1],
T.float16(0))
-
-
[email protected]_func
-def batch_matmul_rrr_2d_fp16(
- var_rxplaceholder: T.handle,
- var_rxplaceholder_1: T.handle,
- var_matmul: T.handle,
- batch: T.int64,
- M: T.int64,
- N: T.int64,
- K: T.int64,
-) -> None:
- # function attr dict
- T.func_attr({"tir.noalias": True})
- rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, K],
dtype="float16")
- rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [K, N],
dtype="float16")
- matmul = T.match_buffer(var_matmul, [batch, M, N], dtype="float16")
- # body
- # with T.block("root")
- for i0, i1, i2, i3 in T.grid(batch, M, N, K):
- with T.block("matmul"):
- i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3])
- T.reads(rxplaceholder[i0_1, i1_1, k], rxplaceholder_1[k, i2_1])
- T.writes(matmul[i0_1, i1_1, i2_1])
- with T.init():
- matmul[i0_1, i1_1, i2_1] = T.float16(0)
- matmul[i0_1, i1_1, i2_1] = (
- matmul[i0_1, i1_1, i2_1] + rxplaceholder[i0_1, i1_1, k] *
rxplaceholder_1[k, i2_1]
- )
-
-
[email protected]_func
-def batch_matmul_rrr_3d_fp16(
- var_rxplaceholder: T.handle,
- var_rxplaceholder_1: T.handle,
- var_matmul: T.handle,
- batch: T.int64,
- M: T.int64,
- N: T.int64,
- K: T.int64,
-) -> None:
- # function attr dict
- T.func_attr({"tir.noalias": True})
- rxplaceholder = T.match_buffer(var_rxplaceholder, [batch, M, K],
dtype="float16")
- rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [batch, K, N],
dtype="float16")
- matmul = T.match_buffer(var_matmul, [batch, M, N], dtype="float16")
- # body
- # with T.block("root")
- for i0, i1, i2, i3 in T.grid(batch, M, N, K):
- with T.block("matmul"):
- i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3])
- T.reads(rxplaceholder[i0_1, i1_1, k], rxplaceholder_1[i0_1, k,
i2_1])
- T.writes(matmul[i0_1, i1_1, i2_1])
- with T.init():
- matmul[i0_1, i1_1, i2_1] = T.float16(0)
- matmul[i0_1, i1_1, i2_1] = (
- matmul[i0_1, i1_1, i2_1]
- + rxplaceholder[i0_1, i1_1, k] * rxplaceholder_1[i0_1, k, i2_1]
- )
-
-
[email protected]_func
-def copy_4d_fp16(
- A_handle: T.handle,
- B_handle: T.handle,
- N: T.int64,
- H: T.int64,
- W: T.int64,
- C: T.int64,
-) -> None:
- A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
- B = T.match_buffer(B_handle, [N, H, W, C], dtype="float16")
- # body
- # with T.block("root")
- for n, h, w, c in T.grid(N, H, W, C):
- with T.block("copy"):
- vn, vh, vw, vc = T.axis.remap("SSSS", [n, h, w, c])
- T.reads(A[vn, vh, vw, vc])
- T.writes(B[vn, vh, vw, vc])
- B[vn, vh, vw, vc] = A[vn, vh, vw, vc]
-
-
[email protected]_func
-def padding_2d_nhwc_fp16(
- A_handle: T.handle,
- B_handle: T.handle,
- N: T.int64,
- H: T.int64,
- W: T.int64,
- C: T.int64,
- pH: T.int64,
- pW: T.int64,
- lH: T.int64,
- lW: T.int64,
- rH: T.int64,
- rW: T.int64,
-) -> None:
- A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
- B = T.match_buffer(B_handle, [N, pH, pW, C], dtype="float16")
- # body
- # with T.block("root")
- for v, v_1, v_2, v_3 in T.grid(N, pH, pW, C):
- with T.block("copy"):
- v_4, v_5, v_6, v_7 = T.axis.remap("SSSS", [v, v_1, v_2, v_3])
- T.reads(A[v_4, v_5 - lH, v_6 - lW, v_7])
- T.writes(B[v_4, v_5, v_6, v_7])
- B[v_4, v_5, v_6, v_7] = T.if_then_else(
- lH <= v_5 and v_5 < rH and lW <= v_6 and v_6 < rW,
- A[v_4, v_5 - lH, v_6 - lW, v_7],
- T.float16(0),
- dtype="float16",
- )
-
-
[email protected]_func
-def conv2d_nhwc_fp16(
- A_handle: T.handle,
- B_handle: T.handle,
- out_handle: T.handle,
- N: T.int64,
- pH: T.int64,
- pW: T.int64,
- H: T.int64,
- W: T.int64,
- C: T.int64,
- O: T.int64,
- KH: T.int64,
- KW: T.int64,
- StrideH: T.int64,
- StrideW: T.int64,
- DilateH: T.int64,
- DilateW: T.int64,
-) -> None:
- A = T.match_buffer(A_handle, [N, pH, pW, C], dtype="float16")
- B = T.match_buffer(B_handle, [O, KH, KW, C], dtype="float16")
- out = T.match_buffer(out_handle, [N, H, W, O], dtype="float16")
- # body
- # with T.block("root")
- for v, v_1, v_2, v_3, v_4, v_5, v_6 in T.grid(N, H, W, O, KH, KW, C):
- with T.block("conv"):
- v_7, v_8, v_9, v_10, v_11, v_12, v_13 = T.axis.remap(
- "SSSSRRR", [v, v_1, v_2, v_3, v_4, v_5, v_6]
- )
- T.reads(
- A[v_7, v_11 * DilateH + v_8 * StrideH, v_12 * DilateW + v_9 *
StrideW, v_13],
- B[v_10, v_11, v_12, v_13],
- )
- T.writes(out[v_7, v_8, v_9, v_10])
- with T.init():
- out[v_7, v_8, v_9, v_10] = T.float16(0)
- out[v_7, v_8, v_9, v_10] = (
- out[v_7, v_8, v_9, v_10]
- + A[v_7, v_11 * DilateH + v_8 * StrideH, v_12 * DilateW + v_9
* StrideW, v_13]
- * B[v_10, v_11, v_12, v_13]
- )
-
-
[email protected]_func
-def bias_add_nhwc_2d_fp16(
- A_handle: T.handle,
- B_handle: T.handle,
- out_handle: T.handle,
- N: T.int64,
- H: T.int64,
- W: T.int64,
- C: T.int64,
-):
- A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
- B = T.match_buffer(B_handle, [1, 1, 1, C], dtype="float16")
- out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16")
- for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C):
- with T.block("T_add"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2,
ax3])
- T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0),
T.int64(0), v_ax3])
- T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3])
- out[v_ax0, v_ax1, v_ax2, v_ax3] = (
- A[v_ax0, v_ax1, v_ax2, v_ax3] + B[v_ax0, T.int64(0),
T.int64(0), v_ax3]
- )
-
-
[email protected]_func
-def bias_add_nhwc_1d_fp16(
- A_handle: T.handle,
- B_handle: T.handle,
- out_handle: T.handle,
- N: T.int64,
- H: T.int64,
- W: T.int64,
- C: T.int64,
-):
- A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
- B = T.match_buffer(B_handle, [1, 1, 1, C], dtype="float16")
- out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16")
- for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C):
- with T.block("T_add"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2,
ax3])
- T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[T.int64(0), T.int64(0),
T.int64(0), v_ax3])
- T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3])
- out[v_ax0, v_ax1, v_ax2, v_ax3] = (
- A[v_ax0, v_ax1, v_ax2, v_ax3] + B[T.int64(0), T.int64(0),
T.int64(0), v_ax3]
- )
-
-
[email protected]_func
-def elem_add_2d_fp16(
- in0_handle: T.handle,
- in1_handle: T.handle,
- out_handle: T.handle,
- N: T.int64,
- M: T.int64,
-):
- in0 = T.match_buffer(in0_handle, [N, M], dtype="float16")
- in1 = T.match_buffer(in1_handle, [N, M], dtype="float16")
- out = T.match_buffer(out_handle, [N, M], dtype="float16")
- for ax0, ax1 in T.grid(N, M):
- with T.block("T_add"):
- v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
- T.reads(in0[v_ax0, v_ax1], in1[v_ax0, v_ax1])
- T.writes(out[v_ax0, v_ax1])
- out[v_ax0, v_ax1] = in0[v_ax0, v_ax1] + in1[v_ax0, v_ax1]
-
-
[email protected]_func
-def elem_add_3d_fp16(
- in0_handle: T.handle,
- in1_handle: T.handle,
- out_handle: T.handle,
- B: T.int64,
- N: T.int64,
- M: T.int64,
-):
- in0 = T.match_buffer(in0_handle, [B, N, M], dtype="float16")
- in1 = T.match_buffer(in1_handle, [B, N, M], dtype="float16")
- out = T.match_buffer(out_handle, [B, N, M], dtype="float16")
- for ax0, ax1, ax2 in T.grid(B, N, M):
- with T.block("T_add"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(in0[v_ax0, v_ax1, v_ax2], in1[v_ax0, v_ax1, v_ax2])
- T.writes(out[v_ax0, v_ax1, v_ax2])
- out[v_ax0, v_ax1, v_ax2] = in0[v_ax0, v_ax1, v_ax2] + in1[v_ax0,
v_ax1, v_ax2]
-
-
[email protected]_func
-def elem_add_4d_fp16(
- A_handle: T.handle,
- B_handle: T.handle,
- out_handle: T.handle,
- N: T.int64,
- H: T.int64,
- W: T.int64,
- C: T.int64,
-):
- A = T.match_buffer(A_handle, [N, H, W, C], dtype="float16")
- B = T.match_buffer(B_handle, [N, H, W, C], dtype="float16")
- out = T.match_buffer(out_handle, [N, H, W, C], dtype="float16")
- for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C):
- with T.block("T_add"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2,
ax3])
- T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, v_ax1, v_ax2,
v_ax3])
- T.writes(out[v_ax0, v_ax1, v_ax2, v_ax3])
- out[v_ax0, v_ax1, v_ax2, v_ax3] = (
- A[v_ax0, v_ax1, v_ax2, v_ax3] + B[v_ax0, v_ax1, v_ax2, v_ax3]
- )
-
-
[email protected]_func
-def scalar_mul_3d_fp16(
- inp0_handle: T.handle,
- out_handle: T.handle,
- D1: T.int64,
- D2: T.int64,
- D3: T.int64,
- scalar: T.float16,
-):
- inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16")
- out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16")
- for ax0, ax1, ax2 in T.grid(D1, D2, D3):
- with T.block("T_mul"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(inp0[v_ax0, v_ax1, v_ax2])
- T.writes(out[v_ax0, v_ax1, v_ax2])
- out[v_ax0, v_ax1, v_ax2] = inp0[v_ax0, v_ax1, v_ax2] * scalar
-
-
[email protected]_func
-def erf_3d_fp32(
- inp0_handle: T.handle,
- out_handle: T.handle,
- D1: T.int64,
- D2: T.int64,
- D3: T.int64,
-):
- inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float32")
- out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float32")
- for ax0, ax1, ax2 in T.grid(D1, D2, D3):
- with T.block("T_erf"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(inp0[v_ax0, v_ax1, v_ax2])
- T.writes(out[v_ax0, v_ax1, v_ax2])
- out[v_ax0, v_ax1, v_ax2] = T.erf(inp0[v_ax0, v_ax1, v_ax2])
-
-
[email protected]_func
-def scalar_add_3d_fp16(
- inp0_handle: T.handle,
- out_handle: T.handle,
- D1: T.int64,
- D2: T.int64,
- D3: T.int64,
- scalar: T.float16,
-):
- inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16")
- out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16")
- for ax0, ax1, ax2 in T.grid(D1, D2, D3):
- with T.block("T_add"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(inp0[v_ax0, v_ax1, v_ax2])
- T.writes(out[v_ax0, v_ax1, v_ax2])
- out[v_ax0, v_ax1, v_ax2] = scalar + inp0[v_ax0, v_ax1, v_ax2]
-
-
[email protected]_func
-def elem_mul_3d_fp16(
- inp0_handle: T.handle,
- inp1_handle: T.handle,
- out_handle: T.handle,
- D1: T.int64,
- D2: T.int64,
- D3: T.int64,
-):
- inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16")
- inp1 = T.match_buffer(inp1_handle, [D1, D2, D3], dtype="float16")
- out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16")
- for ax0, ax1, ax2 in T.grid(D1, D2, D3):
- with T.block("T_mul"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(inp0[v_ax0, v_ax1, v_ax2], inp1[v_ax0, v_ax1, v_ax2])
- T.writes(out[v_ax0, v_ax1, v_ax2])
- out[v_ax0, v_ax1, v_ax2] = inp0[v_ax0, v_ax1, v_ax2] * inp1[v_ax0,
v_ax1, v_ax2]
-
-
[email protected]_func
-def cast_3d_fp16(
- inp0_handle: T.handle,
- out_handle: T.handle,
- D1: T.int64,
- D2: T.int64,
- D3: T.int64,
-):
- inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float32")
- out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float16")
- for ax0, ax1, ax2 in T.grid(D1, D2, D3):
- with T.block("T_cast"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(inp0[v_ax0, v_ax1, v_ax2])
- T.writes(out[v_ax0, v_ax1, v_ax2])
- out[v_ax0, v_ax1, v_ax2] = T.Cast("float16", inp0[v_ax0, v_ax1,
v_ax2])
-
-
[email protected]_func
-def cast_3d_fp32(
- inp0_handle: T.handle,
- out_handle: T.handle,
- D1: T.int64,
- D2: T.int64,
- D3: T.int64,
-):
- inp0 = T.match_buffer(inp0_handle, [D1, D2, D3], dtype="float16")
- out = T.match_buffer(out_handle, [D1, D2, D3], dtype="float32")
- for ax0, ax1, ax2 in T.grid(D1, D2, D3):
- with T.block("T_cast"):
- v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(inp0[v_ax0, v_ax1, v_ax2])
- T.writes(out[v_ax0, v_ax1, v_ax2])
- out[v_ax0, v_ax1, v_ax2] = T.Cast("float32", inp0[v_ax0, v_ax1,
v_ax2])
-
-
-def get_tir_pattern() -> List[tvm.tir.PrimFunc]:
- """Get the tir patterns for backend dispatch."""
- return [
- matmul_rrr_fp16,
- bias_row_2d_fp16,
- bias_row_1d_fp16,
- batch_bias_row_2d_fp16,
- batch_bias_row_1d_fp16,
- relu_fp16,
- erf_3d_fp32,
- batch_matmul_rrr_2d_fp16,
- batch_matmul_rrr_3d_fp16,
- copy_4d_fp16,
- padding_2d_nhwc_fp16,
- conv2d_nhwc_fp16,
- bias_add_nhwc_2d_fp16,
- bias_add_nhwc_1d_fp16,
- elem_add_2d_fp16,
- elem_add_3d_fp16,
- elem_add_4d_fp16,
- elem_mul_3d_fp16,
- scalar_add_3d_fp16,
- scalar_mul_3d_fp16,
- cast_3d_fp16,
- cast_3d_fp32,
- ]
diff --git a/tests/python/relax/test_codegen_tir_cutlass.py
b/tests/python/relax/test_codegen_tir_cutlass.py
deleted file mode 100644
index 9670f15986..0000000000
--- a/tests/python/relax/test_codegen_tir_cutlass.py
+++ /dev/null
@@ -1,702 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from __future__ import annotations
-import tempfile
-
-from tvm import relax, runtime
-import tvm
-import tvm.testing
-from tvm import relax
-import scipy
-from scipy.special import erf
-import numpy as np
-from tvm.target import Target
-from tvm.relax.vm_build import build as relax_build
-from tvm.script.ir_builder import relax as R
-from tvm.script.ir_builder import ir as I
-from tvm.script.ir_builder import tir as T
-from tvm.script.ir_builder import IRBuilder
-
-from tvm.relax.backend_tir import get_tir_pattern
-from tvm.relax.backend_tir.contrib.cutlass import cutlass_fcodegen,
compile_options
-
-A_TYPE = "float16"
-B_TYPE = "float16"
-C_TYPE = "float16"
-
-target = Target("cuda")
-
-
-def f_run(rt_mod: runtime.Module, device: runtime.ndarray.Device, *input):
- vm = relax.vm.VirtualMachine(rt_mod=rt_mod, device=device)
- return vm["main"](*input)
-
-
-def build(mod):
- mod = relax.transform.LegalizeOps()(mod)
- mod = relax.transform.AnnotateTIROpPattern()(mod)
- mod = relax.transform.FuseOps()(mod)
- mod = relax.transform.FuseTIR()(mod)
- mod = relax.transform.SplitCallTIRByPattern(get_tir_pattern(),
cutlass_fcodegen())(mod)
- mod = relax.transform.DeadCodeElimination()(mod)
- executable = tvm.compile(mod, target)
- return executable.jit(**compile_options(target))
-
-
-def build_and_run_reference(mod, inputs_np):
- dev = tvm.device("llvm", 0)
- ex = tvm.compile(mod, "llvm")
- vm = relax.VirtualMachine(ex, dev)
- f = vm["main"]
- inputs = [tvm.nd.array(inp, dev) for inp in inputs_np]
- return f(*inputs).numpy()
-
-
-def constructGEMM(M, N, K):
- with IRBuilder() as ib: # pylint: disable=invalid-name
- with I.ir_module() as frame:
- with R.function():
- R.func_name("main")
- A = R.arg(
- "A", relax.TensorStructInfo((M, K), A_TYPE)
- ) # pylint: disable=invalid-name
- B = R.arg(
- "B", relax.TensorStructInfo((K, N), B_TYPE)
- ) # pylint: disable=invalid-name
- with R.dataflow() as df:
- C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
- R.output(C)
- (C,) = df.output_vars
- R.func_ret_value(C)
- relax_mod = ib.get()
- return relax_mod
-
-
[email protected]_cutlass
-def test_cutlass_dense():
- m, n, k = 128, 64, 256
- executable = build(constructGEMM(m, n, k))
- dev = tvm.cuda()
- A = np.random.randn(m, k).astype("float16")
- B = np.random.randn(k, n).astype("float16")
- A_tvm = tvm.nd.array(A, dev)
- B_tvm = tvm.nd.array(B, dev)
- result = f_run(executable, dev, A_tvm, B_tvm)
- np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2)
-
-
-def constructGEMM_bias(M, N, K):
- with IRBuilder() as ib: # pylint: disable=invalid-name
- with I.ir_module() as frame:
- with R.function():
- R.func_name("main")
- A = R.arg(
- "A", relax.TensorStructInfo((M, K), A_TYPE)
- ) # pylint: disable=invalid-name
- B = R.arg(
- "B", relax.TensorStructInfo((K, N), B_TYPE)
- ) # pylint: disable=invalid-name
- bias = R.arg(
- "bias", relax.TensorStructInfo((1, N), A_TYPE)
- ) # pylint: disable=invalid-name
- with R.dataflow() as df:
- C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
- D = R.emit(R.add(C, bias))
- R.output(D)
- (D,) = df.output_vars
- R.func_ret_value(D)
- relax_mod = ib.get()
- return relax_mod
-
-
-def constructGEMM_bias2(M, N, K):
- with IRBuilder() as ib: # pylint: disable=invalid-name
- with I.ir_module() as frame:
- with R.function():
- R.func_name("main")
- A = R.arg(
- "A", relax.TensorStructInfo((M, K), A_TYPE)
- ) # pylint: disable=invalid-name
- B = R.arg(
- "B", relax.TensorStructInfo((K, N), B_TYPE)
- ) # pylint: disable=invalid-name
- bias = R.arg(
- "bias", relax.TensorStructInfo((N,), A_TYPE)
- ) # pylint: disable=invalid-name
- with R.dataflow() as df:
- C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
- D = R.emit(R.add(C, bias))
- R.output(D)
- (D,) = df.output_vars
- R.func_ret_value(D)
- relax_mod = ib.get()
- return relax_mod
-
-
[email protected]_cutlass
-def test_cutlass_dense_bias():
- m, n, k = 128, 64, 256
- executable = build(constructGEMM_bias(m, n, k))
- dev = tvm.cuda()
- A = np.random.randn(m, k).astype("float16")
- B = np.random.randn(k, n).astype("float16")
- bias = np.random.randn(1, n).astype("float16")
- A_tvm = tvm.nd.array(A, dev)
- B_tvm = tvm.nd.array(B, dev)
- bias_tvm = tvm.nd.array(bias, dev)
- result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
- np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2,
atol=5e-2)
-
-
[email protected]_cutlass
-def test_cutlass_dense_bias2():
- m, n, k = 128, 64, 256
- executable = build(constructGEMM_bias2(m, n, k))
- dev = tvm.cuda()
- A = np.random.randn(m, k).astype("float16")
- B = np.random.randn(k, n).astype("float16")
- bias = np.random.randn(n).astype("float16")
- A_tvm = tvm.nd.array(A, dev)
- B_tvm = tvm.nd.array(B, dev)
- bias_tvm = tvm.nd.array(bias, dev)
- result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
- np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2,
atol=5e-2)
-
-
-def constructGEMM_bias_relu(M, N, K):
- with IRBuilder() as ib: # pylint: disable=invalid-name
- with I.ir_module() as frame:
- with R.function():
- R.func_name("main")
- A = R.arg(
- "A", relax.TensorStructInfo((M, K), A_TYPE)
- ) # pylint: disable=invalid-name
- B = R.arg(
- "B", relax.TensorStructInfo((K, N), B_TYPE)
- ) # pylint: disable=invalid-name
- bias = R.arg(
- "bias", relax.TensorStructInfo((1, N), A_TYPE)
- ) # pylint: disable=invalid-name
- with R.dataflow() as df:
- C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
- D = R.emit(R.add(C, bias))
- E = R.emit(R.nn.relu(D))
- R.output(E)
- (E,) = df.output_vars
- R.func_ret_value(E)
- relax_mod = ib.get()
- return relax_mod
-
-
[email protected]_cutlass
-def test_cutlass_dense_bias_relu():
- m, n, k = 128, 64, 256
- executable = build(constructGEMM_bias_relu(m, n, k))
- dev = tvm.cuda()
- A = np.random.randn(m, k).astype("float16")
- B = np.random.randn(k, n).astype("float16")
- bias = np.random.randn(1, n).astype("float16")
- A_tvm = tvm.nd.array(A, dev)
- B_tvm = tvm.nd.array(B, dev)
- bias_tvm = tvm.nd.array(bias, dev)
- result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
- np.testing.assert_allclose(result.numpy(), np.maximum(A @ B + bias, 0),
rtol=5e-2, atol=5e-2)
-
-
-def constructBatchGEMM(batch, M, N, K):
- with IRBuilder() as ib: # pylint: disable=invalid-name
- with I.ir_module() as frame:
- with R.function():
- R.func_name("main")
- A = R.arg(
- "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
- ) # pylint: disable=invalid-name
- B = R.arg(
- "B", relax.TensorStructInfo((K, N), B_TYPE)
- ) # pylint: disable=invalid-name
- with R.dataflow() as df:
- C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
- R.output(C)
- (C,) = df.output_vars
- R.func_ret_value(C)
- relax_mod = ib.get()
- return relax_mod
-
-
[email protected]_cutlass
-def test_cutlass_batch_dense():
- b, m, n, k = 2, 128, 256, 64
- executable = build(constructBatchGEMM(b, m, n, k))
- dev = tvm.cuda()
- A = np.random.randn(b, m, k).astype("float16")
- B = np.random.randn(k, n).astype("float16")
- A_tvm = tvm.nd.array(A, dev)
- B_tvm = tvm.nd.array(B, dev)
- result = f_run(executable, dev, A_tvm, B_tvm)
- np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2)
-
-
-def constructBatchGEMM2(batch, M, N, K):
- with IRBuilder() as ib: # pylint: disable=invalid-name
- with I.ir_module() as frame:
- with R.function():
- R.func_name("main")
- A = R.arg(
- "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
- ) # pylint: disable=invalid-name
- B = R.arg(
- "B", relax.TensorStructInfo((batch, K, N), B_TYPE)
- ) # pylint: disable=invalid-name
- with R.dataflow() as df:
- C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
- R.output(C)
- (C,) = df.output_vars
- R.func_ret_value(C)
- relax_mod = ib.get()
- return relax_mod
-
-
[email protected]_cutlass
-def test_cutlass_batch_dense2():
- b, m, n, k = 2, 128, 256, 64
- executable = build(constructBatchGEMM2(b, m, n, k))
- dev = tvm.cuda()
- A = np.random.randn(b, m, k).astype("float16")
- B = np.random.randn(b, k, n).astype("float16")
- A_tvm = tvm.nd.array(A, dev)
- B_tvm = tvm.nd.array(B, dev)
- result = f_run(executable, dev, A_tvm, B_tvm)
- np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2)
-
-
-def constructBatchGEMM_bias(batch, M, N, K):
- with IRBuilder() as ib: # pylint: disable=invalid-name
- with I.ir_module() as frame:
- with R.function():
- R.func_name("main")
- A = R.arg(
- "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
- ) # pylint: disable=invalid-name
- B = R.arg(
- "B", relax.TensorStructInfo((K, N), B_TYPE)
- ) # pylint: disable=invalid-name
- bias = R.arg(
- "bias", relax.TensorStructInfo((1, N), A_TYPE)
- ) # pylint: disable=invalid-name
- with R.dataflow() as df:
- C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
- D = R.emit(R.add(C, bias))
- R.output(D)
- (D,) = df.output_vars
- R.func_ret_value(D)
- relax_mod = ib.get()
- return relax_mod
-
-
[email protected]_cutlass
-def test_cutlass_batch_dense_bias():
- b, m, n, k = 2, 128, 256, 64
- executable = build(constructBatchGEMM_bias(b, m, n, k))
- dev = tvm.cuda()
- A = np.random.randn(b, m, k).astype("float16")
- B = np.random.randn(k, n).astype("float16")
- bias = np.random.randn(1, n).astype("float16")
- A_tvm = tvm.nd.array(A, dev)
- B_tvm = tvm.nd.array(B, dev)
- bias_tvm = tvm.nd.array(bias, dev)
- result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
- np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2,
atol=5e-2)
-
-
-def constructBatchGEMM_bias2(batch, M, N, K):
- with IRBuilder() as ib: # pylint: disable=invalid-name
- with I.ir_module() as frame:
- with R.function():
- R.func_name("main")
- A = R.arg(
- "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
- ) # pylint: disable=invalid-name
- B = R.arg(
- "B", relax.TensorStructInfo((K, N), B_TYPE)
- ) # pylint: disable=invalid-name
- bias = R.arg(
- "bias", relax.TensorStructInfo((N,), A_TYPE)
- ) # pylint: disable=invalid-name
- with R.dataflow() as df:
- C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
- D = R.emit(R.add(C, bias))
- R.output(D)
- (D,) = df.output_vars
- R.func_ret_value(D)
- relax_mod = ib.get()
- return relax_mod
-
-
[email protected]_cutlass
-def test_cutlass_batch_dense_bias2():
- b, m, n, k = 2, 128, 256, 64
- executable = build(constructBatchGEMM_bias2(b, m, n, k))
- dev = tvm.cuda()
- A = np.random.randn(b, m, k).astype("float16")
- B = np.random.randn(k, n).astype("float16")
- bias = np.random.randn(n).astype("float16")
- A_tvm = tvm.nd.array(A, dev)
- B_tvm = tvm.nd.array(B, dev)
- bias_tvm = tvm.nd.array(bias, dev)
- result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
- np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2,
atol=5e-2)
-
-
-def constructBatchGEMM_bias2_gelu(batch, M, N, K):
- with IRBuilder() as ib: # pylint: disable=invalid-name
- with I.ir_module() as frame:
- with R.function():
- R.func_name("main")
- A = R.arg(
- "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
- ) # pylint: disable=invalid-name
- B = R.arg(
- "B", relax.TensorStructInfo((K, N), B_TYPE)
- ) # pylint: disable=invalid-name
- bias = R.arg(
- "bias", relax.TensorStructInfo((N,), A_TYPE)
- ) # pylint: disable=invalid-name
- with R.dataflow() as df:
- C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
- D = R.emit(R.add(C, bias))
- E = R.emit(R.nn.gelu(D))
- R.output(E)
- (E,) = df.output_vars
- R.func_ret_value(E)
- relax_mod = ib.get()
- return relax_mod
-
-
[email protected]_cutlass
-def test_cutlass_batch_dense_bias2_gelu():
- b, m, n, k = 2, 128, 64, 256
- executable = build(constructBatchGEMM_bias2_gelu(b, m, n, k))
- dev = tvm.cuda()
- A = np.random.randn(b, m, k).astype("float16")
- B = np.random.randn(k, n).astype("float16")
- bias = np.random.randn(n).astype("float16")
- A_tvm = tvm.nd.array(A, dev)
- B_tvm = tvm.nd.array(B, dev)
- bias_tvm = tvm.nd.array(bias, dev)
- result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
- C = A @ B + bias
- O = 0.5 * C * (1 + erf(C / np.sqrt(2)))
- np.testing.assert_allclose(result.numpy(), O, rtol=5e-2, atol=5e-2)
-
-
-def constructBatchGEMM_bias2_mul(batch, M, N, K):
- with IRBuilder() as ib: # pylint: disable=invalid-name
- with I.ir_module() as frame:
- with R.function():
- R.func_name("main")
- A = R.arg(
- "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
- ) # pylint: disable=invalid-name
- B = R.arg(
- "B", relax.TensorStructInfo((K, N), B_TYPE)
- ) # pylint: disable=invalid-name
- bias = R.arg(
- "bias", relax.TensorStructInfo((N,), A_TYPE)
- ) # pylint: disable=invalid-name
- residual = R.arg("residual", relax.TensorStructInfo((batch, M,
N), A_TYPE))
- with R.dataflow() as df:
- C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
- D = R.emit(R.add(C, bias))
- E = R.emit(R.multiply(D, residual))
- R.output(E)
- (E,) = df.output_vars
- R.func_ret_value(E)
- relax_mod = ib.get()
- return relax_mod
-
-
[email protected]_cutlass
-def test_cutlass_batch_dense_bias2_mul():
- b, m, n, k = 2, 128, 256, 64
- executable = build(constructBatchGEMM_bias2_mul(b, m, n, k))
- dev = tvm.cuda()
- A = np.random.randn(b, m, k).astype("float16")
- B = np.random.randn(k, n).astype("float16")
- bias = np.random.randn(n).astype("float16")
- residual = np.random.randn(b, m, n).astype("float16")
- A_tvm = tvm.nd.array(A, dev)
- B_tvm = tvm.nd.array(B, dev)
- bias_tvm = tvm.nd.array(bias, dev)
- residual_tvm = tvm.nd.array(residual, dev)
- result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm, residual_tvm)
-    np.testing.assert_allclose(result.numpy(), ((A @ B) + bias) * residual, rtol=5e-2, atol=5e-2)
-
-
-def constructBatchGEMM2_bias(batch, M, N, K):
- with IRBuilder() as ib: # pylint: disable=invalid-name
- with I.ir_module() as frame:
- with R.function():
- R.func_name("main")
- A = R.arg(
- "A", relax.TensorStructInfo((batch, M, K), A_TYPE)
- ) # pylint: disable=invalid-name
- B = R.arg(
- "B", relax.TensorStructInfo((batch, K, N), B_TYPE)
- ) # pylint: disable=invalid-name
- bias = R.arg(
- "bias", relax.TensorStructInfo((1, N), A_TYPE)
- ) # pylint: disable=invalid-name
- with R.dataflow() as df:
- C = R.emit(R.matmul(A, B, out_dtype=C_TYPE))
- D = R.emit(R.add(C, bias))
- R.output(D)
- (D,) = df.output_vars
- R.func_ret_value(D)
- relax_mod = ib.get()
- return relax_mod
-
-
-@tvm.testing.requires_cutlass
-def test_cutlass_batch_dense2_bias():
- b, m, n, k = 2, 128, 256, 64
- executable = build(constructBatchGEMM2_bias(b, m, n, k))
- dev = tvm.cuda()
- A = np.random.randn(b, m, k).astype("float16")
- B = np.random.randn(b, k, n).astype("float16")
- bias = np.random.randn(1, n).astype("float16")
- A_tvm = tvm.nd.array(A, dev)
- B_tvm = tvm.nd.array(B, dev)
- bias_tvm = tvm.nd.array(bias, dev)
- result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
-    np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2)
-
-
-def constructConv2D(N, C, H, W, KH, KW, O, strides, padding, dilation):
- from tvm.script.ir_builder import IRBuilder
- from tvm.script.ir_builder import ir as I
- from tvm.script.ir_builder import relax as R
- from tvm.script.ir_builder import tir as T
-
- with IRBuilder() as ib: # pylint: disable=invalid-name
- with I.ir_module() as frame:
- with R.function():
- R.func_name("main")
- x = R.arg(
- "x", relax.TensorStructInfo((N, H, W, C), A_TYPE)
- ) # pylint: disable=invalid-name
- w = R.arg(
- "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE)
- ) # pylint: disable=invalid-name
- with R.dataflow() as df:
- C = R.emit(
- R.nn.conv2d(
- x,
- w,
- strides=strides,
- padding=padding,
- dilation=dilation,
- groups=1,
- data_layout="NHWC",
- kernel_layout="OHWI",
- out_layout="NHWC",
- out_dtype=C_TYPE,
- )
- )
- R.output(C)
- (C,) = df.output_vars
- R.func_ret_value(C)
- mod = ib.get()
- return mod
-
-
-@tvm.testing.requires_cutlass
-def test_cutlass_conv2d():
- n, c, h, w = 1, 3, 224, 224
- kh, kw, o = 3, 3, 64
- for strides in [(1, 1), (2, 2)]:
- for padding in [(0, 0), (3, 3)]:
- for dilation in [(1, 1), (4, 4)]:
-                mod = constructConv2D(n, c, h, w, kh, kw, o, strides, padding, dilation)
- executable = build(mod)
- dev = tvm.cuda()
- np.random.seed(0)
- A = np.random.randn(n, h, w, c).astype("float16")
- B = np.random.randn(o, kh, kw, c).astype("float16")
- A_tvm = tvm.nd.array(A, dev)
- B_tvm = tvm.nd.array(B, dev)
- result = f_run(executable, dev, A_tvm, B_tvm)
- result_ref = build_and_run_reference(mod, [A, B])
- np.testing.assert_allclose(
- result.numpy(),
- result_ref,
- rtol=5e-2,
- atol=5e-2,
- )
-
-
-def constructConv2D_bias(N, C, H, W, KH, KW, O, strides, padding, dilation):
- from tvm.script.ir_builder import IRBuilder
- from tvm.script.ir_builder import ir as I
- from tvm.script.ir_builder import relax as R
- from tvm.script.ir_builder import tir as T
-
- with IRBuilder() as ib: # pylint: disable=invalid-name
- with I.ir_module() as frame:
- with R.function():
- R.func_name("main")
- x = R.arg(
- "x", relax.TensorStructInfo((N, H, W, C), A_TYPE)
- ) # pylint: disable=invalid-name
- w = R.arg(
- "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE)
- ) # pylint: disable=invalid-name
- bias = R.arg(
- "bias", relax.TensorStructInfo((1, 1, 1, O), A_TYPE)
- ) # pylint: disable=invalid-name
- with R.dataflow() as df:
- C = R.emit(
- R.nn.conv2d(
- x,
- w,
- strides=strides,
- padding=padding,
- dilation=dilation,
- groups=1,
- data_layout="NHWC",
- kernel_layout="OHWI",
- out_layout="NHWC",
- out_dtype=C_TYPE,
- )
- )
- D = R.emit(R.add(C, bias))
- R.output(D)
- (D,) = df.output_vars
- R.func_ret_value(D)
- mod = ib.get()
- return mod
-
-
-@tvm.testing.requires_cutlass
-def test_cutlass_conv2d_bias():
- c, h, w = 3, 224, 224
- kh, kw, o = 3, 3, 64
- for n in [1, 2]:
- for strides in [(1, 1), (2, 2)]:
- for padding in [(0, 0), (3, 3)]:
- for dilation in [(1, 1), (4, 4)]:
-                    mod = constructConv2D_bias(n, c, h, w, kh, kw, o, strides, padding, dilation)
- executable = build(mod)
- dev = tvm.cuda()
- np.random.seed(0)
- A = np.random.randn(n, h, w, c).astype("float16")
- B = np.random.randn(o, kh, kw, c).astype("float16")
- bias = np.random.randn(1, 1, 1, o).astype("float16")
- A_tvm = tvm.nd.array(A, dev)
- B_tvm = tvm.nd.array(B, dev)
- bias_tvm = tvm.nd.array(bias, dev)
- result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm)
- result_ref = build_and_run_reference(mod, [A, B, bias])
- np.testing.assert_allclose(
- result.numpy(),
- result_ref,
- rtol=5e-2,
- atol=5e-2,
- )
-
-
-def constructConv2D_bias_add(N, C, H, W, KH, KW, O, OH, OW, strides, padding, dilation):
- from tvm.script.ir_builder import IRBuilder
- from tvm.script.ir_builder import ir as I
- from tvm.script.ir_builder import relax as R
- from tvm.script.ir_builder import tir as T
-
- with IRBuilder() as ib: # pylint: disable=invalid-name
- with I.ir_module() as frame:
- with R.function():
- R.func_name("main")
- x = R.arg(
- "x", relax.TensorStructInfo((N, H, W, C), A_TYPE)
- ) # pylint: disable=invalid-name
- w = R.arg(
- "w", relax.TensorStructInfo((O, KH, KW, C), B_TYPE)
- ) # pylint: disable=invalid-name
- bias = R.arg(
- "bias", relax.TensorStructInfo((1, 1, 1, O), A_TYPE)
- ) # pylint: disable=invalid-name
- res = R.arg(
- "res", relax.TensorStructInfo((N, OH, OW, O), A_TYPE)
- ) # pylint: disable=invalid-name
- with R.dataflow() as df:
- C = R.emit(
- R.nn.conv2d(
- x,
- w,
- strides=strides,
- padding=padding,
- dilation=dilation,
- groups=1,
- data_layout="NHWC",
- kernel_layout="OHWI",
- out_layout="NHWC",
- out_dtype=C_TYPE,
- )
- )
- D = R.emit(R.add(C, bias))
- E = R.emit(R.add(D, res))
- R.output(E)
- (E,) = df.output_vars
- R.func_ret_value(E)
- mod = ib.get()
- return mod
-
-
-@tvm.testing.requires_cutlass
-def test_cutlass_conv2d_bias_add():
- n, c, h, w = 2, 3, 224, 224
- kh, kw, o = 3, 3, 64
- for strides in [(1, 1), (2, 2)]:
- for padding in [(0, 0), (3, 3)]:
- for dilation in [(1, 1), (4, 4)]:
-                oh = (h + 2 * padding[0] - dilation[0] * (kh - 1) - 1) // strides[0] + 1
-                ow = (w + 2 * padding[1] - dilation[1] * (kw - 1) - 1) // strides[1] + 1
- mod = constructConv2D_bias_add(
- n, c, h, w, kh, kw, o, oh, ow, strides, padding, dilation
- )
- executable = build(mod)
- dev = tvm.cuda()
- np.random.seed(0)
- A = np.random.randn(n, h, w, c).astype("float16")
- B = np.random.randn(o, kh, kw, c).astype("float16")
- bias = np.random.randn(1, 1, 1, o).astype("float16")
- res = np.random.randn(n, oh, ow, o).astype("float16")
- A_tvm = tvm.nd.array(A, dev)
- B_tvm = tvm.nd.array(B, dev)
- bias_tvm = tvm.nd.array(bias, dev)
- res_tvm = tvm.nd.array(res, dev)
-                result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm, res_tvm)
- result_ref = build_and_run_reference(mod, [A, B, bias, res])
- np.testing.assert_allclose(
- result.numpy(),
- result_ref,
- rtol=5e-2,
- atol=5e-2,
- )
-
-
-if __name__ == "__main__":
- tvm.testing.main()