This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
     new 4e03d85514 [Unity][BYOC] Add relax backend pattern registry (#14106)

4e03d85514 is described below

commit 4e03d85514191996bb5c1fabc1c8e3463efffa2c
Author: Lite Ye <yelite...@gmail.com>
AuthorDate: Fri Feb 24 04:16:47 2023 -0500

    [Unity][BYOC] Add relax backend pattern registry (#14106)

    * Add relax backend pattern registry

    * Add doc
---
 CMakeLists.txt                               |   1 +
 python/tvm/relax/backend/__init__.py         |  20 +++++
 python/tvm/relax/backend/_ffi_api.py         |  21 +++++
 python/tvm/relax/backend/contrib/__init__.py |  20 +++++
 python/tvm/relax/backend/contrib/cutlass.py  |  90 +++++++++++++++++++
 python/tvm/relax/backend/pattern_registry.py | 125 +++++++++++++++++++++++++++
 python/tvm/relax/backend/patterns.py         | 115 ++++++++++++++++++++++++
 python/tvm/relax/dpl/pattern.py              |  27 ++----
 src/relax/backend/pattern_registry.cc        |  82 ++++++++++++++++++
 src/relax/backend/pattern_registry.h         | 106 +++++++++++++++++++++++
 tests/python/relax/test_codegen_cutlass.py   |  67 +++----------
 11 files changed, 598 insertions(+), 76 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 18be118832..22e82e2fb7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -295,6 +295,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
   src/relax/transform/*.cc
   src/relax/backend/vm/*.cc
   src/relax/backend/task_extraction.cc
+  src/relax/backend/pattern_registry.cc
   src/relax/utils.cc
   )
diff --git a/python/tvm/relax/backend/__init__.py b/python/tvm/relax/backend/__init__.py
new file mode 100644
index 0000000000..c3786591e3
--- /dev/null
+++ b/python/tvm/relax/backend/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relax backends"""
+
+from . import contrib
+from .pattern_registry import get_pattern, get_patterns_with_prefix
diff --git a/python/tvm/relax/backend/_ffi_api.py b/python/tvm/relax/backend/_ffi_api.py
new file mode 100644
index 0000000000..d1378b2eac
--- /dev/null
+++ b/python/tvm/relax/backend/_ffi_api.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI API for Relax backend.""" + +import tvm._ffi + +tvm._ffi._init_api("relax.backend", __name__) diff --git a/python/tvm/relax/backend/contrib/__init__.py b/python/tvm/relax/backend/contrib/__init__.py new file mode 100644 index 0000000000..a094c97d24 --- /dev/null +++ b/python/tvm/relax/backend/contrib/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""External backend codegen modules for Relax.""" + +from .cutlass import partition_for_cutlass diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py new file mode 100644 index 0000000000..20cf57a40a --- /dev/null +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pattern table for CUTLASS backend""" + +from tvm.relax import transform + +from ..pattern_registry import get_patterns_with_prefix, register_patterns +from ..patterns import make_fused_bias_activation_pattern, make_matmul_pattern + +register_patterns( + [ + ( + "cutlass.conv2d", + make_fused_bias_activation_pattern( + "relax.nn.conv2d", + with_bias=False, + activation=None, + ), + ), + ( + "cutlass.conv2d_bias_relu", + make_fused_bias_activation_pattern( + "relax.nn.conv2d", + with_bias=True, + activation="relax.nn.relu", + ), + ), + ( + "cutlass.matmul", + make_matmul_pattern( + with_bias=False, + ), + ), + ( + "cutlass.matmul_bias", + make_matmul_pattern( + with_bias=True, + ), + ), + ( + "cutlass.matmul_bias_relu", + make_matmul_pattern( + with_bias=True, + activation="relax.nn.relu", + ), + ), + ( + "cutlass.matmul_bias_gelu", + make_matmul_pattern( + with_bias=True, + activation="relax.nn.gelu", + ), + ), + ] +) + + +def partition_for_cutlass(mod): + """ + Partition the input module into CUTLASS-supported subgraphs. + + Parameters + ---------- + mod: tvm.IRModule + The IRModule to be partitioned. + + Returns + ------- + mod: tvm.IRModule + The resulting IRModule, containing partitioned subgraphs to be + compiled by the CUTLASS backend. 
+ """ + + cutlass_patterns = get_patterns_with_prefix("cutlass") + return transform.FuseOpsByPattern(cutlass_patterns, annotate_codegen=True)(mod) diff --git a/python/tvm/relax/backend/pattern_registry.py b/python/tvm/relax/backend/pattern_registry.py new file mode 100644 index 0000000000..0016de0a50 --- /dev/null +++ b/python/tvm/relax/backend/pattern_registry.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pattern registry for BYOC backends""" + +from typing import List, Mapping, Optional, Tuple, Union + +import tvm +from tvm.relax.dpl import DFPattern +from tvm.runtime import Object + +from . import _ffi_api + + +@tvm._ffi.register_object("relax.backend.PatternRegistryEntry") +class PatternRegistryEntry(Object): + """ + An entry in the pattern registry. This represents a single pattern that + can be used to identify expressions that can be handled by external + backends, like CUTLASS and TensorRT. + + Parameters + ---------- + name: str + The name of pattern. Usually it starts with the name of backend, like 'cutlass.matmul'. + + pattern: DFPattern + The dataflow pattern that will be used to match expressions that can be handled + by external backends. + + arg_patterns: Mapping[str, DFPattern] + The mapping from arg name to its pattern. It can be used to extract arg expression + from match result. All DFPattern in this map should be part of the `pattern`. + """ + + name: str + pattern: DFPattern + arg_patterns: Mapping[str, DFPattern] + + def __init__(self, name: str, pattern: DFPattern, arg_patterns: Mapping[str, DFPattern]): + self.__init_handle_by_constructor__( + _ffi_api.PatternRegistryEntry, name, pattern, arg_patterns # type: ignore + ) + + +Pattern = Union[ + PatternRegistryEntry, + Tuple[str, DFPattern], + Tuple[str, Tuple[DFPattern, Mapping[str, DFPattern]]], +] + + +def register_patterns(patterns: List[Pattern]): + """ + Register patterns which will be used to partition the DataflowBlock into + subgraphs that are supported by external backends. + + Parameters + ---------- + patterns: List[Pattern] + Patterns to be registered. Patterns that appear later in the list have + higher priority when partitioning DataflowBlock. 
+ """ + entries = [] + for item in patterns: + if isinstance(item, PatternRegistryEntry): + entries.append(item) + elif isinstance(item, tuple): + name, pattern_or_tuple = item + if isinstance(pattern_or_tuple, tuple): + pattern, arg_patterns = pattern_or_tuple + else: + pattern, arg_patterns = pattern_or_tuple, {} + entries.append(PatternRegistryEntry(name, pattern, arg_patterns)) + else: + raise TypeError(f"Cannot register type {type(pattern)} as pattern") + _ffi_api.RegisterPatterns(entries) + + +def get_patterns_with_prefix(prefix: str) -> List[PatternRegistryEntry]: + """ + Get a list of patterns whose names startwith `prefix`. + + Parameters + ---------- + prefix: str + The prefix of pattern name. + + Returns + ------- + patterns: PatternRegistryEntry + Matched patterns, ordered by priority from high to low. + """ + return _ffi_api.GetPatternsWithPrefix(prefix) + + +def get_pattern(name: str) -> Optional[PatternRegistryEntry]: + """ + Find the pattern with a particular name. + + Parameters + ---------- + name: str + The pattern name. + + Returns + ------- + pattern: Optional[PatternRegistryEntry] + The matched pattern. Returns None if such pattern is not found. + """ + return _ffi_api.GetPattern(name) diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py new file mode 100644 index 0000000000..2f744af660 --- /dev/null +++ b/python/tvm/relax/backend/patterns.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Common patterns used in BYOC""" + +from typing import Dict, Mapping, Tuple + +from tvm.relax.dpl.pattern import DFPattern, is_op, wildcard + + +def _with_bias_activation_pattern( + out: DFPattern, + args: Dict[str, DFPattern], + with_bias: bool = False, + activation: str = None, +) -> Tuple[DFPattern, Mapping[str, DFPattern]]: + if with_bias: + args["bias"] = bias = wildcard() + out = is_op("relax.add")(out, bias) + + if activation: + out = is_op(activation)(out) + + return out, args + + +def make_fused_bias_activation_pattern( + op_name: str, + with_bias: bool = False, + activation: str = None, +) -> Tuple[DFPattern, Mapping[str, DFPattern]]: + """ + A simple utility to create patterns for an operation fused with bias addition and activation. + + Parameters + ---------- + op_name: str + The name of a Relax op, such as "relax.nn.conv2d" + + with_bias: bool + Whether or not to include bias addition + + activation: str + The name of an activation Relax op, such as "relax.nn.relu" + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a fused operation + + args: Mapping[str, DFPattern] + The mapping from arg name to its pattern. It can be used to extract + arg expression from match result. 
+ """ + lhs = wildcard() + rhs = wildcard() + args = {"lhs": lhs, "rhs": rhs} + out = is_op(op_name)(lhs, rhs) + + return _with_bias_activation_pattern(out, args, with_bias, activation) + + +def make_matmul_pattern( + with_bias: bool = False, + activation: str = None, + transposed_rhs: bool = False, +) -> Tuple[DFPattern, Mapping[str, DFPattern]]: + """ + Create pattern for matrix multiplication. + + Parameters + ---------- + with_bias: bool + Whether or not to include bias addition + + activation: str + The name of an activation Relax op, such as "relax.nn.relu" + + transposed_rhs: bool + Whether the right hand side of multiplication is transposed. + + Returns + ------- + pattern: DFPattern + The resulting pattern describing a matrix multiplication. + + args: Mapping[str, DFPattern] + The mapping from arg name to its pattern. It can be used to extract + arg expression from match result. + """ + + lhs = wildcard() + rhs = wildcard() + args = {"lhs": lhs, "rhs": rhs} + + if transposed_rhs: + rhs = is_op("relax.permute_dims")(rhs) + + out = is_op("relax.matmul")(lhs, rhs) + + return _with_bias_activation_pattern(out, args, with_bias, activation) diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index 44faa0c93a..9e1963f7ed 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -1046,17 +1046,6 @@ def _only_used_by( return ffi.only_used_by(lhs, rhs, index) # type: ignore -def _add_bias_activation_pattern(out, with_bias=False, activation=None): - if with_bias: - bias = wildcard() - out = is_op("relax.add")(out, bias) - - if activation: - return is_op(activation)(out) - - return out - - def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None): """ A simple utility to create patterns for an operation fused with bias addition and activation. @@ -1081,15 +1070,11 @@ def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None rhs = wildcard() out = is_op(op_name)(lhs, rhs) - return _add_bias_activation_pattern(out, with_bias, activation) - + if with_bias: + bias = wildcard() + out = is_op("relax.add")(out, bias) -def make_matmul_pattern(with_bias=False, activation=None, transposed_b=False): - lhs = wildcard() - if transposed_b: - rhs = is_op("relax.permute_dims")(wildcard()) - else: - rhs = wildcard() - out = is_op("relax.matmul")(lhs, rhs) + if activation: + return is_op(activation)(out) - return _add_bias_activation_pattern(out, with_bias, activation) + return out diff --git a/src/relax/backend/pattern_registry.cc b/src/relax/backend/pattern_registry.cc new file mode 100644 index 0000000000..3ca7973365 --- /dev/null +++ b/src/relax/backend/pattern_registry.cc @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
diff --git a/src/relax/backend/pattern_registry.cc b/src/relax/backend/pattern_registry.cc
new file mode 100644
index 0000000000..3ca7973365
--- /dev/null
+++ b/src/relax/backend/pattern_registry.cc
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./pattern_registry.h"
+
+#include "../../support/utils.h"
+
+namespace tvm {
+namespace relax {
+namespace backend {
+
+PatternRegistryEntry::PatternRegistryEntry(String name, DFPattern pattern,
+                                           Map<String, DFPattern> arg_patterns) {
+  ObjectPtr<PatternRegistryEntryNode> n = make_object<PatternRegistryEntryNode>();
+  n->name = std::move(name);
+  n->pattern = std::move(pattern);
+  n->arg_patterns = std::move(arg_patterns);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(PatternRegistryEntryNode);
+
+static std::vector<PatternRegistryEntry>* GetRegistryTable() {
+  static std::vector<PatternRegistryEntry> table;
+  return &table;
+}
+
+void RegisterPatterns(Array<PatternRegistryEntry> entries) {
+  auto* table = GetRegistryTable();
+  for (const auto& entry : entries) {
+    table->push_back(entry);
+  }
+}
+
+Array<PatternRegistryEntry> GetPatternsWithPrefix(const String& prefix) {
+  auto* table = GetRegistryTable();
+  Array<PatternRegistryEntry> result;
+  for (auto it = table->rbegin(); it != table->rend(); ++it) {
+    if (support::StartsWith((*it)->name, prefix.data())) {
+      result.push_back(*it);
+    }
+  }
+  return result;
+}
+
+Optional<PatternRegistryEntry> GetPattern(const String& pattern_name) {
+  auto* table = GetRegistryTable();
+  for (auto it = table->rbegin(); it != table->rend(); ++it) {
+    if ((*it)->name == pattern_name) {
+      return *it;
+    }
+  }
+  return NullOpt;
+}
+
+TVM_REGISTER_GLOBAL("relax.backend.PatternRegistryEntry")
+    .set_body_typed([](String name, DFPattern pattern, Map<String, DFPattern> arg_patterns) {
+      return PatternRegistryEntry(name, pattern, arg_patterns);
+    });
+TVM_REGISTER_GLOBAL("relax.backend.RegisterPatterns").set_body_typed(RegisterPatterns);
+TVM_REGISTER_GLOBAL("relax.backend.GetPatternsWithPrefix").set_body_typed(GetPatternsWithPrefix);
+TVM_REGISTER_GLOBAL("relax.backend.GetPattern").set_body_typed(GetPattern);
+
+}  // namespace backend
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/backend/pattern_registry.h b/src/relax/backend/pattern_registry.h
new file mode 100644
index 0000000000..2e199a2bb1
--- /dev/null
+++ b/src/relax/backend/pattern_registry.h
@@ -0,0 +1,106 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relax/backend/pattern_registry.h
+ * \brief Functions related to registering and retrieving patterns for
+ * functions handled by backends.
+ */
+#ifndef TVM_RELAX_BACKEND_PATTERN_REGISTRY_H_
+#define TVM_RELAX_BACKEND_PATTERN_REGISTRY_H_
+
+#include <tvm/relax/dataflow_pattern.h>
+#include <tvm/relax/expr.h>
+#include <tvm/runtime/container/optional.h>
+#include <tvm/runtime/object.h>
+
+namespace tvm {
+namespace relax {
+namespace backend {
+
+/*!
+ * \brief An entry in the pattern registry. This represents a single pattern that
+ * can be used to identify expressions that can be handled by external
+ * backends, like CUTLASS and TensorRT.
+ */
+class PatternRegistryEntryNode : public Object {
+ public:
+  /*!
+   * \brief The name of the pattern. Usually it starts with the name of the
+   * backend, like 'cutlass.matmul'.
+   */
+  String name;
+  /*!
+   * \brief The dataflow pattern that will be used to match expressions that can
+   * be handled by external backends.
+   */
+  DFPattern pattern;
+  /*!
+   * \brief The mapping from arg name to its pattern. It can be used to extract
+   * arg expressions from match results. All DFPatterns in this map should be part
+   * of the `pattern`.
+   */
+  Map<String, DFPattern> arg_patterns;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("name", &name);
+    v->Visit("pattern", &pattern);
+    v->Visit("arg_patterns", &arg_patterns);
+  }
+
+  static constexpr const char* _type_key = "relax.backend.PatternRegistryEntry";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PatternRegistryEntryNode, Object);
+};
+
+class PatternRegistryEntry : public ObjectRef {
+ public:
+  PatternRegistryEntry(String name, DFPattern pattern, Map<String, DFPattern> arg_patterns);
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternRegistryEntry, ObjectRef,
+                                            PatternRegistryEntryNode);
+};
+
+/*!
+ * \brief Register patterns which will be used to partition the DataflowBlock
+ * into subgraphs that are supported by external backends.
+ * \param entries Patterns to be registered. Patterns that appear later in the list
+ * have higher priority when partitioning the DataflowBlock.
+ */
+void RegisterPatterns(Array<PatternRegistryEntry> entries);
+
+/*!
+ * \brief Find patterns whose name starts with a particular prefix.
+ * \param prefix The pattern name prefix.
+ * \return Matched patterns, ordered by priority from high to low.
+ */
+Array<PatternRegistryEntry> GetPatternsWithPrefix(const String& prefix);
+
+/*!
+ * \brief Find the pattern with a particular name.
+ * \param name The pattern name.
+ * \return The matched pattern. NullOpt if not found.
+ */
+Optional<PatternRegistryEntry> GetPattern(const String& name);
+
+}  // namespace backend
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_BACKEND_PATTERN_REGISTRY_H_
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 5556d1e5d9..673155342c 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -23,7 +23,7 @@ import pytest
 import tvm
 import tvm.testing
 from tvm import relax, relay
-from tvm.relax.dpl import make_fused_bias_activation_pattern, make_matmul_pattern
+from tvm.relax.backend import get_patterns_with_prefix
 from tvm.script import relax as R
 
 
@@ -219,7 +219,11 @@ cutlass_enabled = pytest.mark.skipif(
 pytestmark = [cutlass_enabled]
 
 
-def get_result_with_relax_cutlass_offload(mod, patterns: List[Tuple], *args):
+def get_result_with_relax_cutlass_offload(mod, *args):
+    patterns = [(entry.name, entry.pattern) for entry in get_patterns_with_prefix("cutlass")]
+
+    assert len(patterns) != 0, "Cannot find cutlass patterns"
+
     seq = tvm.transform.Sequential(
         [
             relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True),
@@ -243,15 +247,7 @@ def test_conv2d_offload():
     weight = np.random.randn(32, 3, 3, 16).astype("float16")
     bias = np.random.randn(1, 1, 1, 32).astype("float16")
 
-    patterns = [
-        (
-            "cutlass.conv2d_bias_relu",
-            make_fused_bias_activation_pattern(
-                "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu"
-            ),
-        )
-    ]
-    out = get_result_with_relax_cutlass_offload(Conv2dBiasReLU, patterns, data, weight, bias)
+    out = get_result_with_relax_cutlass_offload(Conv2dBiasReLU, data, weight, bias)
 
     ref_relay_expr = get_relay_conv2d_bias_relu(data.shape, weight.shape)
     ref = get_relay_ref(ref_relay_expr, data, weight, bias)
@@ -327,17 +323,8 @@ def matmul_bias(matmul_size, target_dtype):
 def test_matmul_offload(matmul_x, matmul_y):
     x, y = matmul_x, matmul_y
 
-    patterns = [
-        (
-            "cutlass.matmul",
-            make_matmul_pattern(
-                with_bias=False,
-            ),
-        ),
-    ]
-
     mod = get_relax_matmul_module(x, y)
-    out = get_result_with_relax_cutlass_offload(mod, patterns, x, y)
+    out = get_result_with_relax_cutlass_offload(mod, x, y)
 
     ref_relay_expr = get_relay_matmul(x.shape, y.shape[::-1])
     ref = get_relay_ref(ref_relay_expr, x, y.transpose())
@@ -347,16 +334,8 @@ def test_matmul_bias_offload(matmul_x, matmul_y, matmul_bias):
     x, y, bias = matmul_x, matmul_y, matmul_bias
 
-    patterns = [
-        (
-            "cutlass.matmul_bias",
-            make_matmul_pattern(
-                with_bias=True,
-            ),
-        ),
-    ]
     mod = get_relax_matmul_module(x, y, with_bias=True)
-    out = get_result_with_relax_cutlass_offload(mod, patterns, x, y, bias)
+    out = get_result_with_relax_cutlass_offload(mod, x, y, bias)
 
     ref_relay_expr = get_relay_matmul_bias(x.shape, y.shape[::-1])
     ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias)
@@ -367,17 +346,8 @@ def test_matmul_bias_relu_offload(matmul_x, matmul_y, matmul_bias):
     x, y, bias = matmul_x, matmul_y, matmul_bias
 
-    patterns = [
-        (
-            "cutlass.matmul_bias_relu",
-            make_matmul_pattern(
-                with_bias=True,
-                activation="relax.nn.relu",
-            ),
-        ),
-    ]
     mod = get_relax_matmul_module(x, y, with_bias=True, activation=R.nn.relu)
-    out = get_result_with_relax_cutlass_offload(mod, patterns, x, y, bias)
+    out = get_result_with_relax_cutlass_offload(mod, x, y, bias)
 
     ref_relay_expr = get_relay_matmul_bias_relu(x.shape, y.shape[::-1])
     ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias)
@@ -388,17 +358,8 @@ def test_matmul_bias_relu_offload(matmul_x, matmul_y, matmul_bias):
 def test_matmul_bias_gelu_offload(matmul_x, matmul_y, matmul_bias):
     x, y, bias = matmul_x, matmul_y, matmul_bias
 
-    patterns = [
-        (
-            "cutlass.matmul_bias_gelu",
-            make_matmul_pattern(
-                with_bias=True,
-                activation="relax.nn.gelu",
-            ),
-        ),
-    ]
     mod = get_relax_matmul_module(x, y, with_bias=True, activation=R.nn.gelu)
-    out = get_result_with_relax_cutlass_offload(mod, patterns, x, y, bias)
+    out = get_result_with_relax_cutlass_offload(mod, x, y, bias)
 
     ref_relay_expr = get_relay_matmul_bias_gelu(x.shape, y.shape[::-1])
     ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias)
@@ -411,11 +372,7 @@ def test_kernel_sharing():
     weight1_np = np.random.randn(16, 3, 3, 16).astype("float16")
     weight2_np = np.random.randn(16, 3, 3, 16).astype("float16")
 
-    pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None)
-
-    out = get_result_with_relax_cutlass_offload(
-        Conv2dx2, [("cutlass.conv2d", pat)], data_np, weight1_np, weight2_np
-    )
+    out = get_result_with_relax_cutlass_offload(Conv2dx2, data_np, weight1_np, weight2_np)
 
     relay_expr = get_relay_conv2d_relu_x2(data_np.shape, weight1_np.shape)
     ref = get_relay_ref(relay_expr, data_np, weight1_np, weight2_np)
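[Editor's note: for completeness, the pipeline the updated tests exercise reduces to roughly the helper below. The `FuseOpsByPattern` call and the pattern conversion come straight from the hunk above; the trailing `RunCodegen` step is elided there, so treat this as a sketch of the surrounding context, not a verbatim excerpt. `offload_to_cutlass` is an invented name.]

    import tvm
    from tvm import relax
    from tvm.relax.backend import get_patterns_with_prefix


    def offload_to_cutlass(mod: tvm.IRModule) -> tvm.IRModule:
        # Convert registry entries into the (name, pattern) tuples that
        # FuseOpsByPattern expects.
        patterns = [(entry.name, entry.pattern) for entry in get_patterns_with_prefix("cutlass")]
        seq = tvm.transform.Sequential(
            [
                relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True),
                relax.transform.RunCodegen(),  # assumed step; elided from the hunk above
            ]
        )
        return seq(mod)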