This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 9914a77b0f [Unity] Process all Relax functions in CompositeFunctionAnnotator (#14736)
9914a77b0f is described below

commit 9914a77b0f9c202b5145d8e292cb6c5b6e5d9545
Author: Lite Ye <yelite...@gmail.com>
AuthorDate: Thu Apr 27 20:27:34 2023 -0400

    [Unity] Process all Relax functions in CompositeFunctionAnnotator (#14736)
    
    Process all Relax functions when annotating codegen
---
 src/relax/transform/fuse_ops.cc                    | 19 +++--
 .../relax/test_transform_fuse_ops_by_pattern.py    | 86 ++++++++++++++++------
 2 files changed, 78 insertions(+), 27 deletions(-)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index f50a578954..ad1dc3eb98 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -1079,11 +1079,20 @@ class CompositeFunctionAnnotator : public ExprMutator {
 
   IRModule Run() {
     auto mod = builder_->GetContextIRModule();
-    auto gvar = mod->GetGlobalVar("main");
-    auto func = Downcast<Function>(mod->Lookup(gvar));
-    auto new_func =
-        Function(func->params, VisitExpr(func->body), func->ret_struct_info, func->attrs);
-    builder_->UpdateFunction(gvar, new_func);
+    auto all_functions = mod->functions;
+    for (const auto& entry : all_functions) {
+      if (const auto* func = entry.second.as<FunctionNode>()) {
+        if (func->GetAttr<String>(attr::kComposite).defined()) {
+          continue;
+        }
+        auto new_body = VisitExpr(func->body);
+        if (!new_body.same_as(func->body)) {
+          auto new_func = Function(func->params, new_body, func->ret_struct_info,
+                                   func->attrs, func->span);
+          builder_->UpdateFunction(entry.first, new_func);
+        }
+      }
+    }
     return builder_->GetContextIRModule();
   }
 
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 146c8e1ebc..5fb2b3332c 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -21,9 +21,9 @@ import tvm
 from tvm import relax
 from tvm.relax.dpl.pattern import (
     is_op,
+    is_tuple_get_item,
     make_fused_bias_activation_pattern,
     wildcard,
-    is_tuple_get_item,
 )
 from tvm.relax.transform import PatternCheckContext
 from tvm.script import ir as I
@@ -339,7 +339,7 @@ class Branch:
 @tvm.script.ir_module
 class Conv2dx2:
     @R.function
-    def main(
+    def main2(
         data: R.Tensor((16, 32, 32, 16), "float16"),
         weight1: R.Tensor((16, 3, 3, 16), "float16"),
         weight2: R.Tensor((16, 3, 3, 16), "float16"),
@@ -355,26 +355,28 @@ class Conv2dx2:
 
         return conv2
 
-
-@tvm.script.ir_module
-class Conv2dx2_partitioned:
     @R.function
     def main(
-        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
-        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
-        weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
-    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
-        cls = Conv2dx2_partitioned
+        data: R.Tensor((16, 32, 32, 16), "float16"),
+        weight1: R.Tensor((16, 3, 3, 16), "float16"),
+        weight2: R.Tensor((16, 3, 3, 16), "float16"),
+    ):
         with R.dataflow():
-            lv: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_cutlass(
-                data, weight1
+            conv1 = relax.op.nn.conv2d(
+                data, weight1, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
             )
-            gv: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_cutlass(
-                lv, weight2
+            conv2 = relax.op.nn.conv2d(
+                conv1, weight2, padding=(1, 1), data_layout="NHWC", kernel_layout="OHWI"
             )
-            R.output(gv)
-        return gv
+            conv3 = Conv2dx2.main2(data, weight1, weight2)
+            result = conv2 + conv3
+            R.output(result)
+
+        return result
 
+
+@tvm.script.ir_module
+class Conv2dx2_partitioned:
     @R.function
     def fused_relax_nn_conv2d_cutlass(
         data: R.Tensor((16, 32, 32, 16), dtype="float16"),
@@ -383,26 +385,66 @@ class Conv2dx2_partitioned:
         R.func_attr({"Codegen": "cutlass", "global_symbol": 
"fused_relax_nn_conv2d_cutlass"})
 
         @R.function
-        def gv(
+        def gv_1(
             data_1: R.Tensor((16, 32, 32, 16), dtype="float16"),
             weight1_1: R.Tensor((16, 3, 3, 16), dtype="float16"),
         ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
             R.func_attr({"Composite": "cutlass.conv2d", "Primitive": 1})
             with R.dataflow():
-                gv_1: R.Tensor((16, 32, 32, 16), dtype="float16") = R.nn.conv2d(
+                gv_2: R.Tensor((16, 32, 32, 16), dtype="float16") = R.nn.conv2d(
                     data_1,
                     weight1_1,
+                    strides=[1, 1],
                     padding=[1, 1, 1, 1],
+                    dilation=[1, 1],
+                    groups=1,
                     data_layout="NHWC",
                     kernel_layout="OHWI",
                     out_layout="NHWC",
+                    out_dtype="void",
                 )
-                R.output(gv_1)
-            return gv_1
+                R.output(gv_2)
+            return gv_2
 
-        gv1: R.Tensor((16, 32, 32, 16), dtype="float16") = gv(data, weight1)
+        gv1: R.Tensor((16, 32, 32, 16), dtype="float16") = gv_1(data, weight1)
         return gv1
 
+    @R.function
+    def main2(
+        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
+        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
+        weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
+    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
+        cls = Conv2dx2_partitioned
+        with R.dataflow():
+            lv: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_cutlass(
+                data, weight1
+            )
+            gv: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_cutlass(
+                lv, weight2
+            )
+            R.output(gv)
+        return gv
+
+    @R.function
+    def main(
+        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
+        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
+        weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
+    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
+        cls = Conv2dx2_partitioned
+        with R.dataflow():
+            lv1: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_cutlass(
+                data, weight1
+            )
+            lv2: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_cutlass(
+                lv1, weight2
+            )
+            conv3: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.main2(data, weight1, weight2)
+            result: R.Tensor((16, 32, 32, 16), dtype="float16") = R.add(lv2, conv3)
+            R.output(result)
+        return result
+
 
 conv2d_pat = make_fused_bias_activation_pattern("relax.nn.conv2d", activation=None)
 conv2d_relu_pat = make_fused_bias_activation_pattern("relax.nn.conv2d", activation="relax.nn.relu")
@@ -478,7 +520,7 @@ def test_annotate_codegen():
     )
 
 
-def test_multiple_calls_same_extern():
+def test_multiple_entries_multiple_calls_same_extern():
     pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None)
     check(Conv2dx2, [("cutlass.conv2d", pat)], Conv2dx2_partitioned, annotate_codegen=True)
 

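For illustration, a minimal sketch of how the updated pass can be exercised, assuming the Conv2dx2 module defined in the test diff above: with this change, FuseOpsByPattern with annotate_codegen=True rewrites calls in every non-composite Relax function of the module (here both main and main2), rather than only main.

    import tvm
    from tvm import relax
    from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern

    # Same pattern as in the updated test: a bare conv2d, no bias or activation.
    pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None)

    # Partition the matched regions into composite functions and annotate them for
    # the "cutlass" codegen; CompositeFunctionAnnotator now visits every
    # non-composite function in the module, not just "main".
    partitioned = relax.transform.FuseOpsByPattern(
        [("cutlass.conv2d", pat)], annotate_codegen=True
    )(Conv2dx2)  # Conv2dx2 is the test module from the diff above
    print(partitioned)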