This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 0f6c4674c19f6ecc82175d686559ae59db42dd08
Author: masahi <masahi...@gmail.com>
AuthorDate: Mon Feb 20 17:08:44 2023 +0900

    [Unity][BYOC] Add pattern-based partitioning pass (#14054)
    
    This adds a new pass, FuseOpsByPattern, which applies pattern matching to
    each function in the given module, and groups matched expressions into a
    new function. The end result is similar to FuseOps, but fusion is driven
    completely by the provided patterns. The implementation also reuses
    OperatorFusor used by FuseOps to create grouped functions from partitioned
    groups, further illustrating the similarity between the two passes.
    
    The new pass will serve the same role that the MergeComposite pass plays
    in Relay BYOC: grouped functions are annotated with the "composite"
    attribute to denote what operations a given function consists of, and are
    offloaded to external backends. But it can also be useful in non-BYOC
    settings, for example to support advanced fusion that op-kind-based
    fusion doesn't handle (fused MHA, conv2d / gemm + reduction fusion, etc.).
    
    The original PR: https://github.com/tlc-pack/relax/pull/366
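    
    A minimal usage sketch, based on the tests added below ("mod" here is a
    placeholder for any Relax IRModule, e.g. the Conv2dReLU test module):
    
        from tvm import relax
        from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern
    
        # Match conv2d followed by relu; composite functions produced by the
        # pass carry the attribute {"Composite": "dnnl.conv2d_relu"}.
        conv2d_relu_pat = make_fused_bias_activation_pattern(
            "relax.nn.conv2d", activation="relax.nn.relu"
        )
    
        # Partition matched expressions; with annotate_codegen=True, each
        # composite function is additionally wrapped in an outer function
        # annotated with {"Codegen": "dnnl"}.
        mod = relax.transform.FuseOpsByPattern(
            [("dnnl.conv2d_relu", conv2d_relu_pat)], annotate_codegen=True
        )(mod)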
---
 python/tvm/relax/transform/transform.py            |  37 +-
 src/relax/transform/fuse_ops.cc                    | 199 +++++++++
 .../relax/test_transform_fuse_ops_by_pattern.py    | 464 +++++++++++++++++++++
 3 files changed, 699 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 1f14823b5a..bf90ef0b09 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -19,7 +19,7 @@
 import functools
 import inspect
 import types
-from typing import Callable, Dict, Union, Optional, List
+from typing import Callable, Dict, Union, Optional, List, Tuple
 import numpy as np  # type: ignore
 import tvm.ir
 from tvm.runtime import NDArray
@@ -241,6 +241,41 @@ def FuseTIR() -> tvm.ir.transform.Pass:
     return _ffi_api.FuseTIR()  # type: ignore
 
 
+def FuseOpsByPattern(
+    patterns: List[Tuple], annotate_codegen: bool = False
+) -> tvm.ir.transform.Pass:
+    """Apply pattern matching to each function in the given module, and group 
matched expressions
+    into a new function.
+
+    The end result is similar to FuseOps, but fusion is driven completely by the provided patterns.
+
+    Parameters
+    ----------
+    patterns : List[Tuple[str, DFPattern]]
+        The patterns to detect. The order of the patterns determines the order of priority in which
+        they are matched. Higher-priority patterns should come earlier in the list.
+        The string is the name of the corresponding pattern. It becomes the value of the kComposite
+        attribute of a fused function after a successful match.
+
+    annotate_codegen : bool
+        If True, wrap each created composite function with another function, whose body consists
+        only of a call to the composite function, and annotate the outer function with "Codegen"
+        and "global_symbol" attributes. The "Codegen" attribute is set to the prefix of the
+        corresponding pattern name, for example "dnnl" if the pattern name is "dnnl.conv2d_relu".
+
+        This must be True if the created composite functions are intended to be offloaded to
+        an external backend without using the MergeCompositeFunctions pass.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass for pattern-based fusion.
+
+    """
+    pattern_names, df_patterns = zip(*patterns)
+    return _ffi_api.FuseOpsByPattern(pattern_names, df_patterns, annotate_codegen)  # type: ignore
+
+
 def LegalizeOps(customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None):
     """Legalize high-level operator calls in Relax functions to call_tir
     with corresponding low-level TIR PrimFuncs.
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 0a0209bb87..3b78274cec 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -28,12 +28,15 @@
  */
 
 #include <tvm/relax/analysis.h>
+#include <tvm/relax/dataflow_matcher.h>
+#include <tvm/relax/dataflow_pattern.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/struct_info.h>
 #include <tvm/relax/transform.h>
 #include <tvm/tir/function.h>
 
 #include <optional>
+#include <unordered_map>
 
 #include "../../relay/analysis/graph_partitioner.h"
 #include "../../support/arena.h"
@@ -880,6 +883,188 @@ IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) {
   return OperatorFusor(mod, graph, groups, /*lift_constants*/ true).Transform();
 }
 
+IRModule MakeGroupedFunctions(
+    IRModule mod, const std::unordered_map<const Object*, GraphPartitioner::Group*>& partition,
+    bool lift_constants) {
+  return OperatorFusor(mod, partition, lift_constants).Transform();
+}
+
+static Map<Expr, Var> GetBindingInverse(const Map<Var, Expr>& binding) {
+  Map<Expr, Var> value_to_bound_var;
+  for (const auto& [var, val] : binding) {
+    value_to_bound_var.Set(val, var);
+  }
+  return value_to_bound_var;
+}
+
+/*! \brief Create a "partitioning", a map from interior / leaf expr to its representative group,
+ * based on the provided pattern. The result can be passed to OperatorFusor above to fuse operations
+ * in a group and create a grouped function.
+ */
+class PatternBasedPartitioner : ExprVisitor {
+ public:
+  using Group = GraphPartitioner::Group;
+  using GroupMap = OperatorFusor::GroupMap;
+  using ExprVisitor::VisitExpr_;
+
+  static GroupMap Run(String pattern_name, DFPattern pattern, Expr expr, support::Arena* arena) {
+    PatternBasedPartitioner part(pattern_name, pattern, AnalyzeVar2Value(expr));
+    // Initialize each expr to have its own group
+    PostOrderVisit(
+        expr, [arena, &part](const Expr& e) { part.group_map_[e.get()] = arena->make<Group>(); });
+    part.VisitExpr(expr);
+    return part.group_map_;
+  }
+
+  PatternBasedPartitioner(String pattern_name, DFPattern pattern, const Map<Var, Expr>& bindings)
+      : pat_name_(pattern_name),
+        pat_(pattern),
+        bindings_(bindings),
+        value_to_bound_var_(GetBindingInverse(bindings)) {}
+
+  void VisitExpr_(const CallNode* call) override {
+    if (auto matches_opt = ExtractMatchedExpr(pat_, GetRef<Call>(call), bindings_)) {
+      // If a match is found, put all matching expressions into the same group.
+      // OperatorFusor also requires that the bound variable be in the same group as the RHS value.
+      // Since an is_op(...) based pattern only matches against call nodes on the right-hand side,
+      // we need to handle the groups corresponding to the LHS bound variables carefully.
+
+      // In the example below, the conv2d + relu pattern would match if the "call" variable in this
+      // function points to the relu op. We identify the group corresponding to "conv1", and make
+      // it the representative group for relu and conv2d on the RHS and also "lv" on the LHS.
+
+      // with R.dataflow():
+      //   lv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(...)
+      //   conv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
+
+      // parent_group corresponds to the group of "conv1" above.
+      auto parent_group = GetGroupForBoundVar(GetRef<Call>(call));
+      ICHECK(parent_group);
+      parent_group->attrs.Set(attr::kComposite, pat_name_);
+
+      for (const auto& [pat, match] : matches_opt.value()) {
+        ICHECK(group_map_.count(match.get()));
+        // Put all matching call nodes into the parent group.
+        if (pat->IsInstance<CallPatternNode>() && match != GetRef<Call>(call)) {
+          AddToGroup(match, parent_group);
+          // Put the bound variable on the LHS into the same parent group.
+          AddToGroup(value_to_bound_var_[match], parent_group);
+        }
+      }
+    }
+  }
+
+ private:
+  void AddToGroup(Expr e, Group* to) {
+    if (group_map_[e.get()] != to) {
+      --group_map_[e.get()]->num_nodes;
+      group_map_[e.get()]->parent = to;
+      ++to->num_nodes;
+    }
+  }
+
+  Group* GetGroupForBoundVar(Expr e) {
+    ICHECK(value_to_bound_var_.count(e));
+    auto bound_var = value_to_bound_var_[e];
+    ICHECK(group_map_.count(bound_var.get()));
+    return group_map_[bound_var.get()]->FindRoot();
+  }
+
+  String pat_name_;
+  DFPattern pat_;
+  Map<Var, Expr> bindings_;
+  Map<Expr, Var> value_to_bound_var_;
+  GroupMap group_map_;
+};
+
+/*!
+ * \brief Wrap each created composite function with another function, whose body consists
+ * only of a call to the composite function, and annotate the outer function with kCodegen
+ * and kGlobalSymbol attributes.
+ */
+class CompositeFunctionAnnotator : public ExprMutator {
+ public:
+  explicit CompositeFunctionAnnotator(IRModule mod) : ExprMutator(mod) {}
+  using ExprMutator::VisitExpr_;
+
+  IRModule Run() {
+    auto mod = builder_->GetContextIRModule();
+    auto gvar = mod->GetGlobalVar("main");
+    auto func = Downcast<Function>(mod->Lookup(gvar));
+    auto new_func =
+        Function(func->params, VisitExpr(func->body), func->ret_struct_info, func->attrs);
+    builder_->UpdateFunction(gvar, new_func);
+    return builder_->GetContextIRModule();
+  }
+
+  Expr VisitExpr_(const CallNode* call_node) final {
+    if (auto const* gvar = call_node->op.as<GlobalVarNode>()) {
+      if (auto it = gvar_map_.find(gvar); it != gvar_map_.end()) {
+        return Call(it->second, call_node->args);
+      }
+      auto func = builder_->GetContextIRModule()->Lookup(GetRef<GlobalVar>(gvar));
+      if (auto composite_name = func->GetAttr<String>(attr::kComposite)) {
+        auto new_func = Downcast<Function>(VisitExpr(func));
+        auto codegen_name = GetCodegenName(composite_name.value());
+        auto gsymbol = gvar->name_hint + "_" + codegen_name;
+        new_func = WithAttrs(new_func,
+                             {{attr::kCodegen, codegen_name}, {tvm::attr::kGlobalSymbol, gsymbol}});
+        builder_->GetContextIRModule()->Remove(GetRef<GlobalVar>(gvar));
+        auto new_gvar = builder_->AddFunction(new_func, gsymbol);
+        gvar_map_[gvar] = new_gvar;
+        return Call(new_gvar, call_node->args);
+      }
+    }
+    return ExprMutator::VisitExpr_(call_node);
+  }
+
+  Expr VisitExpr_(const FunctionNode* func_node) final {
+    auto f_inner = ExprMutator::VisitExpr_(func_node);
+    auto composite_name = func_node->GetAttr<String>(attr::kComposite);
+    ICHECK(composite_name);
+
+    Array<Var> param_vars;
+    Array<Expr> params;
+
+    for (auto v : func_node->params) {
+      Var new_v(v->name_hint(), GetStructInfo(v));
+      param_vars.push_back(new_v);
+      params.push_back(new_v);
+    }
+
+    return Function(param_vars, Call(f_inner, params), func_node->ret_struct_info);
+  }
+
+ private:
+  String GetCodegenName(const std::string& composite_name) {
+    auto delim_pos = composite_name.find(".");
+    ICHECK(delim_pos != std::string::npos) << "The pattern name for a composite function should "
+                                              "start with a compiler name followed by a period.";
+    return composite_name.substr(0, delim_pos);
+  }
+
+  /*! \brief A map from old global vars to their replacements. */
+  std::unordered_map<const GlobalVarNode*, GlobalVar> gvar_map_;
+};
+
+IRModule FuseOpsByPattern(const tvm::Array<String>& pattern_names,
+                          const tvm::Array<DFPattern>& patterns, IRModule mod,
+                          bool annotate_codegen) {
+  support::Arena arena;
+  for (size_t i = 0; i < pattern_names.size(); ++i) {
+    OperatorFusor::GroupMap group_map;
+    for (const auto& entry : mod->functions) {
+      auto map = PatternBasedPartitioner::Run(pattern_names[i], patterns[i], entry.second, &arena);
+      group_map.insert(map.begin(), map.end());
+    }
+    mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ false);
+  }
+  if (annotate_codegen) {
+    return CompositeFunctionAnnotator(mod).Run();
+  }
+  return mod;
+}
+
 namespace transform {
 
 Pass FuseOps(int fuse_opt_level) {
@@ -897,6 +1082,20 @@ Pass FuseOps(int fuse_opt_level) {
 
 TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps);
 
+Pass FuseOpsByPattern(const tvm::Array<String>& pattern_names,
+                      const tvm::Array<DFPattern>& patterns, bool annotate_codegen) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
+      [=](IRModule m, PassContext pc) {
+        return relax::FuseOpsByPattern(pattern_names, patterns, m, annotate_codegen);
+      };
+  return CreateModulePass(/*pass_function=*/pass_func,       //
+                          /*opt_level=*/0,                   //
+                          /*pass_name=*/"FuseOpsByPattern",  //
+                          /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.FuseOpsByPattern").set_body_typed(FuseOpsByPattern);
+
 }  // namespace transform
 
 }  // namespace relax
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
new file mode 100644
index 0000000000..da5b92fb64
--- /dev/null
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -0,0 +1,464 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import numpy as np
+
+import tvm
+
+from tvm import relax
+from tvm.script import relax as R
+from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern, is_op, wildcard
+
+
+@tvm.script.ir_module
+class Conv2dReLU:
+    @R.function
+    def main(
+        data: R.Tensor((1, 64, 56, 56), "float32"),
+        weight1: R.Tensor((64, 64, 3, 3), "float32"),
+    ):
+        with R.dataflow():
+            conv1 = R.nn.relu(R.nn.conv2d(data, weight1, padding=(1, 1)))
+            R.output(conv1)
+
+        return conv1
+
+
+@tvm.script.ir_module
+class Conv2dReLU_composite_annotated:
+    @R.function
+    def main(
+        data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+        with R.dataflow():
+            gv: R.Tensor(
+                (1, 64, 56, 56), dtype="float32"
+            ) = fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight1)
+            R.output(gv)
+        return gv
+
+    @R.function
+    def fused_relax_nn_conv2d_relax_nn_relu_dnnl(
+        data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+        R.func_attr(
+            {"Codegen": "dnnl", "global_symbol": 
"fused_relax_nn_conv2d_relax_nn_relu_dnnl"}
+        )
+
+        @R.function
+        def gv1(
+            data2: R.Tensor((1, 64, 56, 56), dtype="float32"),
+            weight12: R.Tensor((64, 64, 3, 3), dtype="float32"),
+        ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+            R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"})
+            with R.dataflow():
+                lv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
+                    data2,
+                    weight12,
+                    padding=[1, 1, 1, 1],
+                )
+                gv2: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
+                R.output(gv2)
+            return gv2
+
+        gv11: R.Tensor((1, 64, 56, 56), dtype="float32") = gv1(data1, weight11)
+        return gv11
+
+
+@tvm.script.ir_module
+class Conv2dReLUx2:
+    @R.function
+    def main(
+        data: R.Tensor((1, 64, 56, 56), "float32"),
+        weight1: R.Tensor((64, 64, 3, 3), "float32"),
+        weight2: R.Tensor((64, 64, 3, 3), "float32"),
+    ):
+        with R.dataflow():
+            conv1 = R.nn.relu(R.nn.conv2d(data, weight1, padding=(1, 1)))
+            conv2 = R.nn.relu(R.nn.conv2d(conv1, weight2, padding=(0, 0)))
+            R.output(conv2)
+
+        return conv2
+
+
+@tvm.script.ir_module
+class Conv2dReLUx2Partitioned:
+    @R.function
+    def main(
+        data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+        weight2: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        with R.dataflow():
+            lv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_nn_conv2d_relax_nn_relu(
+                data, weight1
+            )
+            gv: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_conv2d_relax_nn_relu1(
+                lv, weight2
+            )
+            R.output(gv)
+        return gv
+
+    @R.function
+    def fused_relax_nn_conv2d_relax_nn_relu(
+        data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+        R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"})
+        with R.dataflow():
+            lv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
+                data1, weight11, padding=[1, 1, 1, 1]
+            )
+            gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv1)
+            R.output(gv1)
+        return gv1
+
+    @R.function
+    def fused_relax_nn_conv2d_relax_nn_relu1(
+        conv1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight21: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"})
+        with R.dataflow():
+            lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(
+                conv1, weight21, padding=[0, 0, 0, 0]
+            )
+            gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv2)
+            R.output(gv2)
+        return gv2
+
+
+@tvm.script.ir_module
+class Conv2dReLUx2Partitioned_only_conv2d:
+    @R.function
+    def main(
+        data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+        weight2: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        with R.dataflow():
+            lv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_nn_conv2d(data, weight1)
+            conv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv)
+            lv1: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_conv2d1(conv1, weight2)
+            conv2d: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv1)
+            R.output(conv2d)
+        return conv2d
+
+    @R.function
+    def fused_relax_nn_conv2d(
+        data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+        R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d"})
+        with R.dataflow():
+            gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
+                data1, weight11, padding=[1, 1, 1, 1]
+            )
+            R.output(gv)
+        return gv
+
+    @R.function
+    def fused_relax_nn_conv2d1(
+        conv11: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight21: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d"})
+        with R.dataflow():
+            gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(
+                conv11, weight21, padding=[0, 0, 0, 0]
+            )
+            R.output(gv1)
+        return gv1
+
+
+@tvm.script.ir_module
+class Conv2dConv2dReLU:
+    @R.function
+    def main(
+        data: R.Tensor((1, 64, 56, 56), "float32"),
+        weight1: R.Tensor((64, 64, 3, 3), "float32"),
+        weight2: R.Tensor((64, 64, 3, 3), "float32"),
+    ):
+        with R.dataflow():
+            conv1 = R.nn.conv2d(data, weight1, padding=(1, 1))
+            conv2d = R.nn.relu(R.nn.conv2d(conv1, weight2, padding=(0, 0)))
+            R.output(conv2d)
+
+        return conv2d
+
+
+@tvm.script.ir_module
+class Conv2dConv2dReLUPartitioned:
+    @R.function
+    def main(
+        data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+        weight2: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        with R.dataflow():
+            lv: R.Tensor((1, 64, 56, 56), dtype="float32") = fused_relax_nn_conv2d(data, weight1)
+            gv: R.Tensor((1, 64, 54, 54), dtype="float32") = fused_relax_nn_conv2d_relax_nn_relu(
+                lv, weight2
+            )
+            R.output(gv)
+        return gv
+
+    @R.function
+    def fused_relax_nn_conv2d_relax_nn_relu(
+        conv1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight21: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"})
+        with R.dataflow():
+            lv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(
+                conv1, weight21, padding=[0, 0, 0, 0]
+            )
+            gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(lv1)
+            R.output(gv1)
+        return gv1
+
+    @R.function
+    def fused_relax_nn_conv2d(
+        data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight11: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+        R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d"})
+        with R.dataflow():
+            gv2: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
+                data1, weight11, padding=[1, 1, 1, 1]
+            )
+            R.output(gv2)
+        return gv2
+
+
+@tvm.script.ir_module
+class BranchTupleOutput:
+    @R.function
+    def main(
+        data: R.Tensor((1, 64, 56, 56), "float32"),
+        weight: R.Tensor((64, 64, 3, 3), "float32"),
+    ):
+        with R.dataflow():
+            conv1 = R.nn.conv2d(data, weight)
+            relu1 = R.nn.relu(conv1)
+            gelu1 = R.nn.gelu(relu1)
+            gelu2 = R.nn.gelu(conv1)
+            out = relax.op.add(gelu1, gelu2)
+            R.output(out)
+
+        return out
+
+
+@tvm.script.ir_module
+class BranchTupleOutputPartitioned:
+    @R.function
+    def main(
+        data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+        with R.dataflow():
+            lv: R.Tuple(
+                R.Tensor((1, 64, 54, 54), dtype="float32"),
+                R.Tensor((1, 64, 54, 54), dtype="float32"),
+            ) = fused_relax_nn_conv2d_relax_nn_relu(data, weight)
+            lv1: R.Tensor((1, 64, 54, 54), dtype="float32") = lv[1]  # conv1
+            lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv[0]  # relu(conv1)
+            gelu1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv2)
+            gelu2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv1)
+            out: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(gelu1, gelu2)
+            R.output(out)
+        return out
+
+    @R.function
+    def fused_relax_nn_conv2d_relax_nn_relu(
+        data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+        weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+    ) -> R.Tuple(
+        R.Tensor((1, 64, 54, 54), dtype="float32"), R.Tensor((1, 64, 54, 54), dtype="float32")
+    ):
+        R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"})
+        with R.dataflow():
+            gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(data1, weight1)
+            gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.relu(gv)
+            R.output(gv, gv1)
+        return (gv1, gv)
+
+
+@tvm.script.ir_module
+class Branch:
+    @R.function
+    def main(
+        data: R.Tensor((1, 64, 56, 56), "float32"),
+        weight: R.Tensor((64, 64, 3, 3), "float32"),
+    ):
+        with R.dataflow():
+            conv1 = R.nn.conv2d(data, weight)
+            relu1 = R.nn.relu(conv1)
+            gelu1 = R.nn.gelu(conv1)
+
+            out = relax.op.add(relu1, gelu1)
+            R.output(out)
+
+        return out
+
+
+@tvm.script.ir_module
+class Conv2dx2:
+    @R.function
+    def main(
+        data: R.Tensor((16, 32, 32, 16), "float16"),
+        weight1: R.Tensor((16, 3, 3, 16), "float16"),
+        weight2: R.Tensor((16, 3, 3, 16), "float16"),
+    ):
+        with R.dataflow():
+            conv1 = relax.op.nn.conv2d(
+                data, weight1, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
+            )
+            conv2 = relax.op.nn.conv2d(
+                conv1, weight2, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
+            )
+            R.output(conv2)
+
+        return conv2
+
+
+@tvm.script.ir_module
+class Conv2dx2_partitioned:
+    @R.function
+    def main(
+        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
+        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
+        weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
+    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
+        with R.dataflow():
+            lv: R.Tensor((16, 32, 32, 16), dtype="float16") = fused_relax_nn_conv2d_cutlass(
+                data, weight1
+            )
+            gv: R.Tensor((16, 32, 32, 16), dtype="float16") = fused_relax_nn_conv2d_cutlass(
+                lv, weight2
+            )
+            R.output(gv)
+        return gv
+
+    @R.function
+    def fused_relax_nn_conv2d_cutlass(
+        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
+        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
+    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
+        R.func_attr({"Codegen": "cutlass", "global_symbol": 
"fused_relax_nn_conv2d_cutlass"})
+
+        @R.function
+        def gv(
+            data_1: R.Tensor((16, 32, 32, 16), dtype="float16"),
+            weight1_1: R.Tensor((16, 3, 3, 16), dtype="float16"),
+        ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
+            R.func_attr({"Composite": "cutlass.conv2d", "Primitive": 1})
+            with R.dataflow():
+                gv_1: R.Tensor((16, 32, 32, 16), dtype="float16") = R.nn.conv2d(
+                    data_1,
+                    weight1_1,
+                    padding=[1, 1, 1, 1],
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                )
+                R.output(gv_1)
+            return gv_1
+
+        gv1: R.Tensor((16, 32, 32, 16), dtype="float16") = gv(data, weight1)
+        return gv1
+
+
+conv2d_pat = make_fused_bias_activation_pattern("relax.nn.conv2d", activation=None)
+conv2d_relu_pat = make_fused_bias_activation_pattern("relax.nn.conv2d", activation="relax.nn.relu")
+
+
+def check(mod, patterns, expected, annotate_codegen=False):
+    partitioned = relax.transform.FuseOpsByPattern(patterns, annotate_codegen)(mod)
+    tvm.ir.assert_structural_equal(partitioned, expected)
+
+
+def test_partition_conv2d_relu():
+    check(Conv2dReLUx2, [("dnnl.conv2d_relu", conv2d_relu_pat)], Conv2dReLUx2Partitioned)
+
+
+def test_partition_multiple_patterns():
+    check(
+        Conv2dConv2dReLU,
+        [("dnnl.conv2d_relu", conv2d_relu_pat), ("dnnl.conv2d", conv2d_pat)],
+        Conv2dConv2dReLUPartitioned,
+    )
+
+
+def test_partition_order():
+    check(
+        Conv2dReLUx2,
+        [("dnnl.conv2d", conv2d_pat), ("dnnl.conv2d_relu", conv2d_relu_pat)],
+        Conv2dReLUx2Partitioned_only_conv2d,
+    )
+
+
+def test_branch_tuple_output():
+    check(BranchTupleOutput, [("dnnl.conv2d_relu", conv2d_relu_pat)], BranchTupleOutputPartitioned)
+
+
+def test_cyclic_dependency():
+    conv_pat = make_fused_bias_activation_pattern("relax.nn.conv2d")
+    relu_pat = is_op("relax.nn.relu")(conv_pat)
+    add_pat = is_op("relax.add")(relu_pat, wildcard())
+
+    with pytest.raises(tvm.error.TVMError) as err:
+        relax.transform.FuseOpsByPattern([("compiler_A.conv2d_relu_add", 
add_pat)])(Branch)
+
+    assert "A cyclic dependency detected" in str(err.value)
+
+
+def test_bind_params():
+    weight_np = np.random.randn(64, 64, 3, 3).astype("float32")
+    mod = tvm.transform.Sequential(
+        [
+            relax.transform.BindParams("main", {"weight1": weight_np}),
+            relax.transform.FuseOpsByPattern([("dnnl.conv2d_relu", 
conv2d_relu_pat)]),
+        ]
+    )(Conv2dReLU)
+
+    assert "fused_relax_nn_conv2d_relax_nn_relu" in [var.name_hint for var in 
mod.functions.keys()]
+
+    for gvar, f in mod.functions.items():
+        if gvar.name_hint == "fused_relax_nn_conv2d_relax_nn_relu":
+            conv2d = f.body.blocks[0].bindings[0].value
+            assert isinstance(conv2d.args[1], relax.Constant)
+
+
+def test_annotate_codegen():
+    check(
+        Conv2dReLU,
+        [("dnnl.conv2d_relu", conv2d_relu_pat)],
+        Conv2dReLU_composite_annotated,
+        annotate_codegen=True,
+    )
+
+
+def test_multiple_calls_same_extern():
+    pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None)
+    check(Conv2dx2, [("cutlass.conv2d", pat)], Conv2dx2_partitioned, annotate_codegen=True)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
