This is an automated email from the ASF dual-hosted git repository.

echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 050b23ff06 [Relay]Disable InferType if it was done and no changes 
after previous pass (#17585)
050b23ff06 is described below

commit 050b23ff069ed810b424106df74c06562b07f34f
Author: Andrey Malyshev <[email protected]>
AuthorDate: Mon Jan 27 17:28:00 2025 +0200

    [Relay]Disable InferType if it was done and no changes after previous pass 
(#17585)
    
    Disable InferType if it was already done and the previous pass made no changes
    
    This optimization speeds up PatternRewriter transformations by
    reusing the previously type-inferred expression instead of performing
    InferType multiple times
---
 src/relay/ir/dataflow_matcher.cc | 22 +++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 3e86e1c8ea..9d117adbbc 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -851,24 +851,32 @@ Expr PatternRewriter::Rewrite(const 
Array<DFPatternCallback>& callbacks, const E
   std::unordered_map<DFPatternCallback, bool, ObjectPtrHash, ObjectPtrEqual> 
done;
   do {
     last = post;
+    // We don't have to call InferType if previous pass has not modified 
anything
+    // We can just take previous typed state of the expression
+    bool types_invalidated = true;
     for (auto callback : callbacks) {
       if (!done[callback]) {
         auto before = post;
+        auto post_typed = post;
         callback_ = callback;
-        if (callback_->require_type) {
-          post = InferTypeWithModule(post, mod_);
+        if (callback_->require_type && types_invalidated) {
+          post_typed = InferTypeWithModule(post, mod_);
         }
         auto grouper = PatternGrouper();
-        groups_ = grouper.GroupMatches(callback_->pattern, post);
+        groups_ = grouper.GroupMatches(callback_->pattern, post_typed);
         gid_assignments_ = grouper.GetGIDAssignments();
         memo_.clear();
         VLOG(1) << "pre rewritten:" << std::endl << PrettyPrint(pre);
-        post = this->VisitExpr(post);
+        post = this->VisitExpr(post_typed);
         VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post);
         count++;
-        if (callback_->rewrite_once) {
-          bool current_equal = (*structural_equal)(before, post, false, true);
-          if (!current_equal) {
+        bool current_equal = (*structural_equal)(before, post, false, true);
+        if (callback_->require_type && current_equal) {
+          types_invalidated = false;
+          post = post_typed;
+        } else {
+          types_invalidated = true;
+          if (callback_->rewrite_once) {
             done[callback] = true;
           }
         }

Reply via email to