This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new d4ca123afc [BugFix] Support rewrite_once when the number of callbacks > 1 (#14344) d4ca123afc is described below commit d4ca123afc54ebabe3c9b0666a5456aaf25eeaa2 Author: sisleyli <43139237+sisle...@users.noreply.github.com> AuthorDate: Wed Mar 22 02:34:01 2023 +0800 [BugFix] Support rewrite_once when the number of callbacks > 1 (#14344) * [BugFix] Support rewrite_once when the number of callbacks > 1 * callbacks_map -> done, swapping false and true --------- Co-authored-by: Bin Li <bin...@amd.com> --- src/relay/ir/dataflow_matcher.cc | 37 +++++++++----- tests/python/relay/test_dataflow_pattern.py | 79 +++++++++++++++++++++++++---- 2 files changed, 94 insertions(+), 22 deletions(-) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index cf186c474e..67c6bae6c5 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -796,24 +796,35 @@ Expr PatternRewriter::Rewrite(const Array<DFPatternCallback>& callbacks, const E bool equal = true; static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual"); ICHECK(structural_equal) << "node.StructuralEqual is not registered."; + // Keep track of callbacks that have finished rewriting + std::unordered_map<DFPatternCallback, bool, ObjectPtrHash, ObjectPtrEqual> done; do { last = post; for (auto callback : callbacks) { - callback_ = callback; - if (callback_->require_type) { - post = InferTypeWithModule(post, mod_); - } - auto grouper = PatternGrouper(); - groups_ = grouper.GroupMatches(callback_->pattern, post); - gid_assignments_ = grouper.GetGIDAssignments(); - memo_.clear(); - VLOG(1) << "pre rewritten:" << std::endl << PrettyPrint(pre); - post = this->VisitExpr(post); - VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post); - count++; + if (!done[callback]) { + auto before = post; + callback_ = callback; + if (callback_->require_type) { + post = InferTypeWithModule(post, mod_); + } + auto grouper = PatternGrouper(); + groups_ = grouper.GroupMatches(callback_->pattern, post); + gid_assignments_ = grouper.GetGIDAssignments(); + memo_.clear(); + VLOG(1) << "pre rewritten:" << std::endl << PrettyPrint(pre); + post = this->VisitExpr(post); + VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post); + count++; + if (callback_->rewrite_once) { + bool current_equal = (*structural_equal)(before, post, false, true); + if (!current_equal) { + done[callback] = true; + } + } + } } equal = (*structural_equal)(last, post, false, true); - } while (!equal && count < 100 && !callback_->rewrite_once); + } while (!equal && count < 100); if (count >= 100) { LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?"; } diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 1bd05f5258..bcb665121b 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -1804,22 +1804,83 @@ def test_rewrite_once(): if new_args: return relay.op.concatenate(relay.expr.Tuple(new_args), axis=0) else: - return concat_args + return concat_args[0] x = relay.var("x") y = relay.var("y") z = relay.var("z") concat = relay.op.concatenate(relay.expr.Tuple([x, y, z]), axis=0) - # Let the rewriter run recursively - out = rewrite(ConcatRewriter(False), concat) - expected = relay.expr.Tuple([x]) - assert tvm.ir.structural_equal(out, expected) + def test_one_callback(): + # Let the rewriter run recursively + out = rewrite(ConcatRewriter(False), concat) + expected = x + assert tvm.ir.structural_equal(out, expected) + + # Run the rewriter once + out = rewrite(ConcatRewriter(True), concat) + expected = relay.op.concatenate(relay.expr.Tuple([x, y]), axis=0) + assert tvm.ir.structural_equal(out, expected) + + def test_multi_callbacks(): + # This class recursively add a nn.relu operator after nn.softmax + class OneMoreReluRewriter(DFPatternCallback): + def __init__(self, rewrite_once): + super().__init__(rewrite_once=rewrite_once) + self.pattern = is_op("nn.softmax")(None) + + def callback(self, pre, post, node_map): + return relay.nn.relu(post) + + def before(): + # Before: + # x y z + # | | | + # concat + # | + # softmax + return relay.nn.softmax(concat) + + def once_concat(): + # ConcatRewrite once, OneMoreReluRewrite once + # Expected: + # x y + # | | + # concat + # | + # softmax + # | + # relu + return relay.nn.relu( + relay.nn.softmax(relay.op.concatenate(relay.expr.Tuple([x, y]), axis=0)) + ) + + def recursive_concat(): + # ConcatRewrite recursively, OneMoreReluRewrite once + # Expected: + # x + # | + # softmax + # | + # relu + return relay.nn.relu(relay.nn.softmax(x)) + + # Run ConcatRewriter once, OneMoreReluRewriter once + out = rewrite( + [OneMoreReluRewriter(True), ConcatRewriter(True)], + before(), + ) + assert tvm.ir.structural_equal(out, once_concat()) + + # Run ConcatRewriter recursively, OneMoreReluRewriter once + out = rewrite( + [OneMoreReluRewriter(True), ConcatRewriter(False)], + before(), + ) + assert tvm.ir.structural_equal(out, recursive_concat()) - # Run the rewriter once - out = rewrite(ConcatRewriter(True), concat) - expected = relay.op.concatenate(relay.expr.Tuple([x, y]), axis=0) - assert tvm.ir.structural_equal(out, expected) + test_one_callback() + test_multi_callbacks() def test_matched_outside_but_dominated():