This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 4e03d85514 [Unity][BYOC]Add relax backend pattern registry (#14106)
4e03d85514 is described below

commit 4e03d85514191996bb5c1fabc1c8e3463efffa2c
Author: Lite Ye <yelite...@gmail.com>
AuthorDate: Fri Feb 24 04:16:47 2023 -0500

    [Unity][BYOC]Add relax backend pattern registry (#14106)
    
    * Add relax backend pattern registry
    
    * Add doc
---
 CMakeLists.txt                               |   1 +
 python/tvm/relax/backend/__init__.py         |  20 +++++
 python/tvm/relax/backend/_ffi_api.py         |  21 +++++
 python/tvm/relax/backend/contrib/__init__.py |  20 +++++
 python/tvm/relax/backend/contrib/cutlass.py  |  90 +++++++++++++++++++
 python/tvm/relax/backend/pattern_registry.py | 125 +++++++++++++++++++++++++++
 python/tvm/relax/backend/patterns.py         | 115 ++++++++++++++++++++++++
 python/tvm/relax/dpl/pattern.py              |  27 ++----
 src/relax/backend/pattern_registry.cc        |  82 ++++++++++++++++++
 src/relax/backend/pattern_registry.h         | 106 +++++++++++++++++++++++
 tests/python/relax/test_codegen_cutlass.py   |  67 +++-----------
 11 files changed, 598 insertions(+), 76 deletions(-)
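
For context before the diff, a minimal usage sketch of the API added by this
commit (the pattern name "my_backend.matmul" is a hypothetical example; the
imports and functions are the ones introduced below):

    from tvm.relax.backend.pattern_registry import (
        get_pattern,
        get_patterns_with_prefix,
        register_patterns,
    )
    from tvm.relax.backend.patterns import make_matmul_pattern

    # make_matmul_pattern() returns a (pattern, arg_patterns) tuple, which
    # register_patterns accepts directly alongside a pattern name.
    # "my_backend.matmul" is a hypothetical name used only for illustration.
    register_patterns([("my_backend.matmul", make_matmul_pattern())])

    entries = get_patterns_with_prefix("my_backend")  # priority high to low
    entry = get_pattern("my_backend.matmul")  # the entry, or None if absent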

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 18be118832..22e82e2fb7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -295,6 +295,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
     src/relax/transform/*.cc
     src/relax/backend/vm/*.cc
     src/relax/backend/task_extraction.cc
+    src/relax/backend/pattern_registry.cc
     src/relax/utils.cc
     )
 
diff --git a/python/tvm/relax/backend/__init__.py b/python/tvm/relax/backend/__init__.py
new file mode 100644
index 0000000000..c3786591e3
--- /dev/null
+++ b/python/tvm/relax/backend/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relax backends"""
+
+from . import contrib
+from .pattern_registry import get_pattern, get_patterns_with_prefix
diff --git a/python/tvm/relax/backend/_ffi_api.py b/python/tvm/relax/backend/_ffi_api.py
new file mode 100644
index 0000000000..d1378b2eac
--- /dev/null
+++ b/python/tvm/relax/backend/_ffi_api.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI API for Relax backend."""
+
+import tvm._ffi
+
+tvm._ffi._init_api("relax.backend", __name__)
diff --git a/python/tvm/relax/backend/contrib/__init__.py b/python/tvm/relax/backend/contrib/__init__.py
new file mode 100644
index 0000000000..a094c97d24
--- /dev/null
+++ b/python/tvm/relax/backend/contrib/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""External backend codegen modules for Relax."""
+
+from .cutlass import partition_for_cutlass
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
new file mode 100644
index 0000000000..20cf57a40a
--- /dev/null
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Pattern table for CUTLASS backend"""
+
+from tvm.relax import transform
+
+from ..pattern_registry import get_patterns_with_prefix, register_patterns
+from ..patterns import make_fused_bias_activation_pattern, make_matmul_pattern
+
+register_patterns(
+    [
+        (
+            "cutlass.conv2d",
+            make_fused_bias_activation_pattern(
+                "relax.nn.conv2d",
+                with_bias=False,
+                activation=None,
+            ),
+        ),
+        (
+            "cutlass.conv2d_bias_relu",
+            make_fused_bias_activation_pattern(
+                "relax.nn.conv2d",
+                with_bias=True,
+                activation="relax.nn.relu",
+            ),
+        ),
+        (
+            "cutlass.matmul",
+            make_matmul_pattern(
+                with_bias=False,
+            ),
+        ),
+        (
+            "cutlass.matmul_bias",
+            make_matmul_pattern(
+                with_bias=True,
+            ),
+        ),
+        (
+            "cutlass.matmul_bias_relu",
+            make_matmul_pattern(
+                with_bias=True,
+                activation="relax.nn.relu",
+            ),
+        ),
+        (
+            "cutlass.matmul_bias_gelu",
+            make_matmul_pattern(
+                with_bias=True,
+                activation="relax.nn.gelu",
+            ),
+        ),
+    ]
+)
+
+
+def partition_for_cutlass(mod):
+    """
+    Partition the input module into CUTLASS-supported subgraphs.
+
+    Parameters
+    ----------
+    mod: tvm.IRModule
+        The IRModule to be partitioned.
+
+    Returns
+    -------
+    mod: tvm.IRModule
+        The resulting IRModule, containing partitioned subgraphs to be
+        compiled by the CUTLASS backend.
+    """
+
+    cutlass_patterns = get_patterns_with_prefix("cutlass")
+    return transform.FuseOpsByPattern(cutlass_patterns, annotate_codegen=True)(mod)
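
A usage sketch for the entry point above; `MyModule` is a placeholder for any
Relax IRModule, and compiling the partitioned result still requires a
CUTLASS-enabled TVM build:

    from tvm.relax.backend.contrib.cutlass import partition_for_cutlass

    # Subgraphs that match the registered "cutlass.*" patterns are outlined
    # into composite functions annotated for the CUTLASS codegen.
    # MyModule is a placeholder for an existing tvm.IRModule.
    mod = partition_for_cutlass(MyModule)
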
diff --git a/python/tvm/relax/backend/pattern_registry.py b/python/tvm/relax/backend/pattern_registry.py
new file mode 100644
index 0000000000..0016de0a50
--- /dev/null
+++ b/python/tvm/relax/backend/pattern_registry.py
@@ -0,0 +1,125 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Pattern registry for BYOC backends"""
+
+from typing import List, Mapping, Optional, Tuple, Union
+
+import tvm
+from tvm.relax.dpl import DFPattern
+from tvm.runtime import Object
+
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("relax.backend.PatternRegistryEntry")
+class PatternRegistryEntry(Object):
+    """
+    An entry in the pattern registry. This represents a single pattern that
+    can be used to identify expressions that can be handled by external
+    backends, like CUTLASS and TensorRT.
+
+    Parameters
+    ----------
+    name: str
+        The name of the pattern. Usually it starts with the name of the backend, like 'cutlass.matmul'.
+
+    pattern: DFPattern
+        The dataflow pattern that will be used to match expressions that can be handled
+        by external backends.
+
+    arg_patterns: Mapping[str, DFPattern]
+        The mapping from arg name to its pattern. It can be used to extract the arg expression
+        from the match result. All DFPatterns in this map should be part of the `pattern`.
+    """
+
+    name: str
+    pattern: DFPattern
+    arg_patterns: Mapping[str, DFPattern]
+
+    def __init__(self, name: str, pattern: DFPattern, arg_patterns: Mapping[str, DFPattern]):
+        self.__init_handle_by_constructor__(
+            _ffi_api.PatternRegistryEntry, name, pattern, arg_patterns  # type: ignore
+        )
+
+
+Pattern = Union[
+    PatternRegistryEntry,
+    Tuple[str, DFPattern],
+    Tuple[str, Tuple[DFPattern, Mapping[str, DFPattern]]],
+]
+
+
+def register_patterns(patterns: List[Pattern]):
+    """
+    Register patterns which will be used to partition the DataflowBlock into
+    subgraphs that are supported by external backends.
+
+    Parameters
+    ----------
+    patterns: List[Pattern]
+        Patterns to be registered. Patterns that appear later in the list have
+        higher priority when partitioning a DataflowBlock.
+    """
+    entries = []
+    for item in patterns:
+        if isinstance(item, PatternRegistryEntry):
+            entries.append(item)
+        elif isinstance(item, tuple):
+            name, pattern_or_tuple = item
+            if isinstance(pattern_or_tuple, tuple):
+                pattern, arg_patterns = pattern_or_tuple
+            else:
+                pattern, arg_patterns = pattern_or_tuple, {}
+            entries.append(PatternRegistryEntry(name, pattern, arg_patterns))
+        else:
+            raise TypeError(f"Cannot register type {type(item)} as pattern")
+    _ffi_api.RegisterPatterns(entries)
+
+
+def get_patterns_with_prefix(prefix: str) -> List[PatternRegistryEntry]:
+    """
+    Get a list of patterns whose names start with `prefix`.
+
+    Parameters
+    ----------
+    prefix: str
+        The prefix of pattern name.
+
+    Returns
+    -------
+    patterns: List[PatternRegistryEntry]
+        Matched patterns, ordered by priority from high to low.
+    """
+    return _ffi_api.GetPatternsWithPrefix(prefix)
+
+
+def get_pattern(name: str) -> Optional[PatternRegistryEntry]:
+    """
+    Find the pattern with a particular name.
+
+    Parameters
+    ----------
+    name: str
+        The pattern name.
+
+    Returns
+    -------
+    pattern: Optional[PatternRegistryEntry]
+        The matched pattern. Returns None if no such pattern is found.
+    """
+    return _ffi_api.GetPattern(name)
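
To make the accepted forms concrete, a sketch of the three `Pattern` shapes
that register_patterns handles (the "example.*" names are hypothetical):

    from tvm.relax.backend.pattern_registry import (
        PatternRegistryEntry,
        register_patterns,
    )
    from tvm.relax.dpl.pattern import is_op, wildcard

    lhs, rhs = wildcard(), wildcard()
    matmul = is_op("relax.matmul")(lhs, rhs)

    register_patterns(
        [
            # (name, pattern)
            ("example.matmul", matmul),
            # (name, (pattern, arg_patterns))
            ("example.matmul_args", (matmul, {"lhs": lhs, "rhs": rhs})),
            # an explicit PatternRegistryEntry
            PatternRegistryEntry("example.entry", matmul, {}),
        ]
    )
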
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
new file mode 100644
index 0000000000..2f744af660
--- /dev/null
+++ b/python/tvm/relax/backend/patterns.py
@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Common patterns used in BYOC"""
+
+from typing import Dict, Mapping, Optional, Tuple
+
+from tvm.relax.dpl.pattern import DFPattern, is_op, wildcard
+
+
+def _with_bias_activation_pattern(
+    out: DFPattern,
+    args: Dict[str, DFPattern],
+    with_bias: bool = False,
+    activation: Optional[str] = None,
+) -> Tuple[DFPattern, Mapping[str, DFPattern]]:
+    if with_bias:
+        args["bias"] = bias = wildcard()
+        out = is_op("relax.add")(out, bias)
+
+    if activation:
+        out = is_op(activation)(out)
+
+    return out, args
+
+
+def make_fused_bias_activation_pattern(
+    op_name: str,
+    with_bias: bool = False,
+    activation: Optional[str] = None,
+) -> Tuple[DFPattern, Mapping[str, DFPattern]]:
+    """
+    A simple utility to create patterns for an operation fused with bias addition and activation.
+
+    Parameters
+    ----------
+    op_name: str
+        The name of a Relax op, such as "relax.nn.conv2d"
+
+    with_bias: bool
+        Whether or not to include bias addition
+
+    activation: str
+        The name of an activation Relax op, such as "relax.nn.relu"
+
+    Returns
+    -------
+    pattern: DFPattern
+        The resulting pattern describing a fused operation
+
+    args: Mapping[str, DFPattern]
+        The mapping from arg name to its pattern. It can be used to extract
+        the arg expression from the match result.
+    """
+    lhs = wildcard()
+    rhs = wildcard()
+    args = {"lhs": lhs, "rhs": rhs}
+    out = is_op(op_name)(lhs, rhs)
+
+    return _with_bias_activation_pattern(out, args, with_bias, activation)
+
+
+def make_matmul_pattern(
+    with_bias: bool = False,
+    activation: Optional[str] = None,
+    transposed_rhs: bool = False,
+) -> Tuple[DFPattern, Mapping[str, DFPattern]]:
+    """
+    Create pattern for matrix multiplication.
+
+    Parameters
+    ----------
+    with_bias: bool
+        Whether or not to include bias addition
+
+    activation: str
+        The name of an activation Relax op, such as "relax.nn.relu"
+
+    transposed_rhs: bool
+        Whether the right hand side of multiplication is transposed.
+
+    Returns
+    -------
+    pattern: DFPattern
+        The resulting pattern describing a matrix multiplication.
+
+    args: Mapping[str, DFPattern]
+        The mapping from arg name to its pattern. It can be used to extract
+        the arg expression from the match result.
+    """
+
+    lhs = wildcard()
+    rhs = wildcard()
+    args = {"lhs": lhs, "rhs": rhs}
+
+    if transposed_rhs:
+        rhs = is_op("relax.permute_dims")(rhs)
+
+    out = is_op("relax.matmul")(lhs, rhs)
+
+    return _with_bias_activation_pattern(out, args, with_bias, activation)
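
A sketch of what these helpers return; only names defined in this file are
assumed:

    from tvm.relax.backend.patterns import make_matmul_pattern

    # `pattern` matches a matmul -> add(bias) -> gelu chain; `args` maps
    # "lhs", "rhs" and "bias" to the wildcard sub-patterns, so a match
    # result can be traced back to the corresponding argument expressions.
    pattern, args = make_matmul_pattern(with_bias=True, activation="relax.nn.gelu")
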
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index 44faa0c93a..9e1963f7ed 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -1046,17 +1046,6 @@ def _only_used_by(
     return ffi.only_used_by(lhs, rhs, index)  # type: ignore
 
 
-def _add_bias_activation_pattern(out, with_bias=False, activation=None):
-    if with_bias:
-        bias = wildcard()
-        out = is_op("relax.add")(out, bias)
-
-    if activation:
-        return is_op(activation)(out)
-
-    return out
-
-
 def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None):
     """
     A simple utility to create patterns for an operation fused with bias addition and activation.
@@ -1081,15 +1070,11 @@ def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None
     rhs = wildcard()
     out = is_op(op_name)(lhs, rhs)
 
-    return _add_bias_activation_pattern(out, with_bias, activation)
-
+    if with_bias:
+        bias = wildcard()
+        out = is_op("relax.add")(out, bias)
 
-def make_matmul_pattern(with_bias=False, activation=None, transposed_b=False):
-    lhs = wildcard()
-    if transposed_b:
-        rhs = is_op("relax.permute_dims")(wildcard())
-    else:
-        rhs = wildcard()
-    out = is_op("relax.matmul")(lhs, rhs)
+    if activation:
+        return is_op(activation)(out)
 
-    return _add_bias_activation_pattern(out, with_bias, activation)
+    return out
diff --git a/src/relax/backend/pattern_registry.cc b/src/relax/backend/pattern_registry.cc
new file mode 100644
index 0000000000..3ca7973365
--- /dev/null
+++ b/src/relax/backend/pattern_registry.cc
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./pattern_registry.h"
+
+#include "../../support/utils.h"
+
+namespace tvm {
+namespace relax {
+namespace backend {
+
+PatternRegistryEntry::PatternRegistryEntry(String name, DFPattern pattern,
+                                           Map<String, DFPattern> arg_patterns) {
+  ObjectPtr<PatternRegistryEntryNode> n = make_object<PatternRegistryEntryNode>();
+  n->name = std::move(name);
+  n->pattern = std::move(pattern);
+  n->arg_patterns = std::move(arg_patterns);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(PatternRegistryEntryNode);
+
+static std::vector<PatternRegistryEntry>* GetRegistryTable() {
+  static std::vector<PatternRegistryEntry> table;
+  return &table;
+}
+
+void RegisterPatterns(Array<PatternRegistryEntry> entries) {
+  auto* table = GetRegistryTable();
+  for (const auto& entry : entries) {
+    table->push_back(entry);
+  }
+}
+
+Array<PatternRegistryEntry> GetPatternsWithPrefix(const String& prefix) {
+  auto* table = GetRegistryTable();
+  Array<PatternRegistryEntry> result;
+  for (auto it = table->rbegin(); it != table->rend(); ++it) {
+    if (support::StartsWith((*it)->name, prefix.data())) {
+      result.push_back(*it);
+    }
+  }
+  return result;
+}
+
+Optional<PatternRegistryEntry> GetPattern(const String& pattern_name) {
+  auto* table = GetRegistryTable();
+  for (auto it = table->rbegin(); it != table->rend(); ++it) {
+    if ((*it)->name == pattern_name) {
+      return *it;
+    }
+  }
+  return NullOpt;
+}
+
+TVM_REGISTER_GLOBAL("relax.backend.PatternRegistryEntry")
+    .set_body_typed([](String name, DFPattern pattern, Map<String, DFPattern> arg_patterns) {
+      return PatternRegistryEntry(name, pattern, arg_patterns);
+    });
+TVM_REGISTER_GLOBAL("relax.backend.RegisterPatterns").set_body_typed(RegisterPatterns);
+TVM_REGISTER_GLOBAL("relax.backend.GetPatternsWithPrefix").set_body_typed(GetPatternsWithPrefix);
+TVM_REGISTER_GLOBAL("relax.backend.GetPattern").set_body_typed(GetPattern);
+
+}  // namespace backend
+}  // namespace relax
+}  // namespace tvm
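
Since these functions are exposed through TVM_REGISTER_GLOBAL, they are also
reachable without the Python wrapper module; a sketch using the standard FFI
lookup tvm.get_global_func (the "cutlass.matmul" name is registered by
python/tvm/relax/backend/contrib/cutlass.py above):

    import tvm

    # Resolve the registered global and query the table populated above.
    get_pattern = tvm.get_global_func("relax.backend.GetPattern")
    entry = get_pattern("cutlass.matmul")  # PatternRegistryEntry, or None
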
diff --git a/src/relax/backend/pattern_registry.h b/src/relax/backend/pattern_registry.h
new file mode 100644
index 0000000000..2e199a2bb1
--- /dev/null
+++ b/src/relax/backend/pattern_registry.h
@@ -0,0 +1,106 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relax/backend/pattern_registry.h
+ * \brief Functions related to registering and retrieving patterns for
+ * functions handled by backends.
+ */
+#ifndef TVM_RELAX_BACKEND_PATTERN_REGISTRY_H_
+#define TVM_RELAX_BACKEND_PATTERN_REGISTRY_H_
+
+#include <tvm/relax/dataflow_pattern.h>
+#include <tvm/relax/expr.h>
+#include <tvm/runtime/container/optional.h>
+#include <tvm/runtime/object.h>
+
+namespace tvm {
+namespace relax {
+namespace backend {
+
+/*!
+ * \brief An entry in the pattern registry. This represents a single pattern that
+ * can be used to identify expressions that can be handled by external
+ * backends, like CUTLASS and TensorRT.
+ */
+class PatternRegistryEntryNode : public Object {
+ public:
+  /*!
+   * \brief The name of the pattern. Usually it starts with the name of the backend, like
+   * 'cutlass.matmul'.
+   */
+  String name;
+  /*!
+   * \brief The dataflow pattern that will be used to match expressions that can
+   * be handled by external backends.
+   */
+  DFPattern pattern;
+  /*!
+   * \brief The mapping from arg name to its pattern. It can be used to extract
+   * the arg expression from the match result. All DFPatterns in this map should be part of
+   * the `pattern`.
+   */
+  Map<String, DFPattern> arg_patterns;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("name", &name);
+    v->Visit("pattern", &pattern);
+    v->Visit("arg_patterns", &arg_patterns);
+  }
+
+  static constexpr const char* _type_key = "relax.backend.PatternRegistryEntry";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PatternRegistryEntryNode, Object);
+};
+
+class PatternRegistryEntry : public ObjectRef {
+ public:
+  PatternRegistryEntry(String name, DFPattern pattern, Map<String, DFPattern> arg_patterns);
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternRegistryEntry, ObjectRef,
+                                            PatternRegistryEntryNode);
+};
+
+/*!
+ * \brief Register patterns which will be used to partition the DataflowBlock
+ *        into subgraphs that are supported by external backends.
+ * \param entries Patterns to be registered. Patterns that appear later in the list have
+ *        higher priority when partitioning DataflowBlock.
+ */
+void RegisterPatterns(Array<PatternRegistryEntry> entries);
+
+/*!
+ * \brief Find patterns whose name starts with a particular prefix.
+ * \param prefix The pattern name prefix.
+ * \return Matched patterns, ordered by priority from high to low.
+ */
+Array<PatternRegistryEntry> GetPatternsWithPrefix(const String& prefix);
+
+/*!
+ * \brief Find the pattern with a particular name.
+ * \param name The pattern name.
+ * \return The matched pattern. NullOpt if not found.
+ */
+Optional<PatternRegistryEntry> GetPattern(const String& name);
+
+}  // namespace backend
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_BACKEND_PATTERN_REGISTRY_H_
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 5556d1e5d9..673155342c 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -23,7 +23,7 @@ import pytest
 import tvm
 import tvm.testing
 from tvm import relax, relay
-from tvm.relax.dpl import make_fused_bias_activation_pattern, make_matmul_pattern
+from tvm.relax.backend import get_patterns_with_prefix
 from tvm.script import relax as R
 
 
@@ -219,7 +219,11 @@ cutlass_enabled = pytest.mark.skipif(
 pytestmark = [cutlass_enabled]
 
 
-def get_result_with_relax_cutlass_offload(mod, patterns: List[Tuple], *args):
+def get_result_with_relax_cutlass_offload(mod, *args):
+    patterns = [(entry.name, entry.pattern) for entry in get_patterns_with_prefix("cutlass")]
+
+    assert len(patterns) != 0, "Cannot find cutlass patterns"
+
     seq = tvm.transform.Sequential(
         [
             relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True),
@@ -243,15 +247,7 @@ def test_conv2d_offload():
     weight = np.random.randn(32, 3, 3, 16).astype("float16")
     bias = np.random.randn(1, 1, 1, 32).astype("float16")
 
-    patterns = [
-        (
-            "cutlass.conv2d_bias_relu",
-            make_fused_bias_activation_pattern(
-                "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu"
-            ),
-        )
-    ]
-    out = get_result_with_relax_cutlass_offload(Conv2dBiasReLU, patterns, data, weight, bias)
+    out = get_result_with_relax_cutlass_offload(Conv2dBiasReLU, data, weight, bias)
 
     ref_relay_expr = get_relay_conv2d_bias_relu(data.shape, weight.shape)
     ref = get_relay_ref(ref_relay_expr, data, weight, bias)
@@ -327,17 +323,8 @@ def matmul_bias(matmul_size, target_dtype):
 def test_matmul_offload(matmul_x, matmul_y):
     x, y = matmul_x, matmul_y
 
-    patterns = [
-        (
-            "cutlass.matmul",
-            make_matmul_pattern(
-                with_bias=False,
-            ),
-        ),
-    ]
-
     mod = get_relax_matmul_module(x, y)
-    out = get_result_with_relax_cutlass_offload(mod, patterns, x, y)
+    out = get_result_with_relax_cutlass_offload(mod, x, y)
     ref_relay_expr = get_relay_matmul(x.shape, y.shape[::-1])
     ref = get_relay_ref(ref_relay_expr, x, y.transpose())
 
@@ -347,16 +334,8 @@ def test_matmul_offload(matmul_x, matmul_y):
 def test_matmul_bias_offload(matmul_x, matmul_y, matmul_bias):
     x, y, bias = matmul_x, matmul_y, matmul_bias
 
-    patterns = [
-        (
-            "cutlass.matmul_bias",
-            make_matmul_pattern(
-                with_bias=True,
-            ),
-        ),
-    ]
     mod = get_relax_matmul_module(x, y, with_bias=True)
-    out = get_result_with_relax_cutlass_offload(mod, patterns, x, y, bias)
+    out = get_result_with_relax_cutlass_offload(mod, x, y, bias)
 
     ref_relay_expr = get_relay_matmul_bias(x.shape, y.shape[::-1])
     ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias)
@@ -367,17 +346,8 @@ def test_matmul_bias_offload(matmul_x, matmul_y, matmul_bias):
 def test_matmul_bias_relu_offload(matmul_x, matmul_y, matmul_bias):
     x, y, bias = matmul_x, matmul_y, matmul_bias
 
-    patterns = [
-        (
-            "cutlass.matmul_bias_relu",
-            make_matmul_pattern(
-                with_bias=True,
-                activation="relax.nn.relu",
-            ),
-        ),
-    ]
     mod = get_relax_matmul_module(x, y, with_bias=True, activation=R.nn.relu)
-    out = get_result_with_relax_cutlass_offload(mod, patterns, x, y, bias)
+    out = get_result_with_relax_cutlass_offload(mod, x, y, bias)
 
     ref_relay_expr = get_relay_matmul_bias_relu(x.shape, y.shape[::-1])
     ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias)
@@ -388,17 +358,8 @@ def test_matmul_bias_relu_offload(matmul_x, matmul_y, matmul_bias):
 def test_matmul_bias_gelu_offload(matmul_x, matmul_y, matmul_bias):
     x, y, bias = matmul_x, matmul_y, matmul_bias
 
-    patterns = [
-        (
-            "cutlass.matmul_bias_gelu",
-            make_matmul_pattern(
-                with_bias=True,
-                activation="relax.nn.gelu",
-            ),
-        ),
-    ]
     mod = get_relax_matmul_module(x, y, with_bias=True, activation=R.nn.gelu)
-    out = get_result_with_relax_cutlass_offload(mod, patterns, x, y, bias)
+    out = get_result_with_relax_cutlass_offload(mod, x, y, bias)
 
     ref_relay_expr = get_relay_matmul_bias_gelu(x.shape, y.shape[::-1])
     ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias)
@@ -411,11 +372,7 @@ def test_kernel_sharing():
     weight1_np = np.random.randn(16, 3, 3, 16).astype("float16")
     weight2_np = np.random.randn(16, 3, 3, 16).astype("float16")
 
-    pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None)
-
-    out = get_result_with_relax_cutlass_offload(
-        Conv2dx2, [("cutlass.conv2d", pat)], data_np, weight1_np, weight2_np
-    )
+    out = get_result_with_relax_cutlass_offload(Conv2dx2, data_np, weight1_np, weight2_np)
 
     relay_expr = get_relay_conv2d_relu_x2(data_np.shape, weight1_np.shape)
     ref = get_relay_ref(relay_expr, data_np, weight1_np, weight2_np)
