This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d4ca123afc [BugFix] Support rewrite_once when the number of callbacks > 1 (#14344)
d4ca123afc is described below

commit d4ca123afc54ebabe3c9b0666a5456aaf25eeaa2
Author: sisleyli <43139237+sisle...@users.noreply.github.com>
AuthorDate: Wed Mar 22 02:34:01 2023 +0800

    [BugFix] Support rewrite_once when the number of callbacks > 1 (#14344)
    
    * [BugFix] Support rewrite_once when the number of callbacks > 1
    
    * callbacks_map -> done, swapping false and true
    
    ---------
    
    Co-authored-by: Bin Li <bin...@amd.com>
---
 src/relay/ir/dataflow_matcher.cc            | 37 +++++++++-----
 tests/python/relay/test_dataflow_pattern.py | 79 +++++++++++++++++++++++++----
 2 files changed, 94 insertions(+), 22 deletions(-)

diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index cf186c474e..67c6bae6c5 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -796,24 +796,35 @@ Expr PatternRewriter::Rewrite(const Array<DFPatternCallback>& callbacks, const E
   bool equal = true;
   static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual");
   ICHECK(structural_equal) << "node.StructuralEqual is not registered.";
+  // Keep track of callbacks that have finished rewriting
+  std::unordered_map<DFPatternCallback, bool, ObjectPtrHash, ObjectPtrEqual> 
done;
   do {
     last = post;
     for (auto callback : callbacks) {
-      callback_ = callback;
-      if (callback_->require_type) {
-        post = InferTypeWithModule(post, mod_);
-      }
-      auto grouper = PatternGrouper();
-      groups_ = grouper.GroupMatches(callback_->pattern, post);
-      gid_assignments_ = grouper.GetGIDAssignments();
-      memo_.clear();
-      VLOG(1) << "pre rewritten:" << std::endl << PrettyPrint(pre);
-      post = this->VisitExpr(post);
-      VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post);
-      count++;
+      if (!done[callback]) {
+        auto before = post;
+        callback_ = callback;
+        if (callback_->require_type) {
+          post = InferTypeWithModule(post, mod_);
+        }
+        auto grouper = PatternGrouper();
+        groups_ = grouper.GroupMatches(callback_->pattern, post);
+        gid_assignments_ = grouper.GetGIDAssignments();
+        memo_.clear();
+        VLOG(1) << "pre rewritten:" << std::endl << PrettyPrint(pre);
+        post = this->VisitExpr(post);
+        VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post);
+        count++;
+        if (callback_->rewrite_once) {
+          bool current_equal = (*structural_equal)(before, post, false, true);
+          if (!current_equal) {
+            done[callback] = true;
+          }
+        }
+      }
     }
     equal = (*structural_equal)(last, post, false, true);
-  } while (!equal && count < 100 && !callback_->rewrite_once);
+  } while (!equal && count < 100);
   if (count >= 100) {
     LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?";
   }
diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py
index 1bd05f5258..bcb665121b 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -1804,22 +1804,83 @@ def test_rewrite_once():
             if new_args:
                 return relay.op.concatenate(relay.expr.Tuple(new_args), axis=0)
             else:
-                return concat_args
+                return concat_args[0]
 
     x = relay.var("x")
     y = relay.var("y")
     z = relay.var("z")
     concat = relay.op.concatenate(relay.expr.Tuple([x, y, z]), axis=0)
 
-    # Let the rewriter run recursively
-    out = rewrite(ConcatRewriter(False), concat)
-    expected = relay.expr.Tuple([x])
-    assert tvm.ir.structural_equal(out, expected)
+    def test_one_callback():
+        # Let the rewriter run recursively
+        out = rewrite(ConcatRewriter(False), concat)
+        expected = x
+        assert tvm.ir.structural_equal(out, expected)
+
+        # Run the rewriter once
+        out = rewrite(ConcatRewriter(True), concat)
+        expected = relay.op.concatenate(relay.expr.Tuple([x, y]), axis=0)
+        assert tvm.ir.structural_equal(out, expected)
+
+    def test_multi_callbacks():
+        # This class recursively add a nn.relu operator after nn.softmax
+        class OneMoreReluRewriter(DFPatternCallback):
+            def __init__(self, rewrite_once):
+                super().__init__(rewrite_once=rewrite_once)
+                self.pattern = is_op("nn.softmax")(None)
+
+            def callback(self, pre, post, node_map):
+                return relay.nn.relu(post)
+
+        def before():
+            # Before:
+            #    x    y    z
+            #    |    |    |
+            #       concat
+            #         |
+            #      softmax
+            return relay.nn.softmax(concat)
+
+        def once_concat():
+            # ConcatRewrite once, OneMoreReluRewrite once
+            # Expected:
+            #   x    y
+            #   |    |
+            #   concat
+            #      |
+            #   softmax
+            #      |
+            #    relu
+            return relay.nn.relu(
+                relay.nn.softmax(relay.op.concatenate(relay.expr.Tuple([x, y]), axis=0))
+            )
+
+        def recursive_concat():
+            # ConcatRewrite recursively, OneMoreReluRewrite once
+            # Expected:
+            #      x
+            #      |
+            #   softmax
+            #      |
+            #    relu
+            return relay.nn.relu(relay.nn.softmax(x))
+
+        # Run ConcatRewriter once, OneMoreReluRewriter once
+        out = rewrite(
+            [OneMoreReluRewriter(True), ConcatRewriter(True)],
+            before(),
+        )
+        assert tvm.ir.structural_equal(out, once_concat())
+
+        # Run ConcatRewriter recursively, OneMoreReluRewriter once
+        out = rewrite(
+            [OneMoreReluRewriter(True), ConcatRewriter(False)],
+            before(),
+        )
+        assert tvm.ir.structural_equal(out, recursive_concat())
 
-    # Run the rewriter once
-    out = rewrite(ConcatRewriter(True), concat)
-    expected = relay.op.concatenate(relay.expr.Tuple([x, y]), axis=0)
-    assert tvm.ir.structural_equal(out, expected)
+    test_one_callback()
+    test_multi_callbacks()
 
 
 def test_matched_outside_but_dominated():

Reply via email to