slyubomirsky commented on code in PR #16732:
URL: https://github.com/apache/tvm/pull/16732#discussion_r1538340938


##########
src/relax/ir/dataflow_matcher.cc:
##########
@@ -1140,34 +1071,173 @@ class PatternRewriter : ExprMutator {
     return block;
   }
 
-  /*! \brief The pattern for rewriting call nodes */
-  Optional<DFPattern> pattern_;
   /*! \brief The pattern constraint contexts for rewriting dataflow blocks */
-  Optional<PatternContext> ctx_;
+  PatternContext ctx_;
   /*!
    * \brief The user-provided rewriter function. Its signature and semantics 
are:
-   * - (Call, Map<DFPattern, Expr>) -> Call for call node rewriting. Given the 
matched
-   *    call node and the map of patterns and matched expressions, it should 
return a new call node
-   *    to replace the original one or the original matched call node as is.
-   * - (Map<DFPattern, Var>, Map<Var, Expr>) -> Map<Var, Expr> for dataflow 
block rewriting.
-   *    Given the map of patterns and corresponding variables (bound variables 
or parameters),
-   *    it should return a map that specifies new values for matched bound 
variables. It can refer
+   *
+   * - (Map<DFPattern, Var>, Map<Var, Expr>) -> Map<Var, Expr>
+   *
+   *    Given the map of patterns and corresponding variables (bound
+   *    variables or parameters), it should return a map that
+   *    specifies new values for matched bound variables. It can refer
    *    to the passed bindings to create the replacement expressions.
    */
-  PackedFunc rewriter_func_;
-  std::unordered_set<const VarNode*> params_;
+  TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> 
rewriter_func_;
+};
+
+/*!
+ * \brief Apply pattern matching to each expression, replacing
+ * matches with the output of a user-provided rewriter function.
+ */
+class ExprPatternRewriter : ExprMutator {
+ public:
+  using ExprMutator::VisitBindingBlock_;
+  using ExprMutator::VisitExpr_;
+
+  ExprPatternRewriter(DFPattern pat,
+                      TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> 
rewriter_func)
+      : pattern_(pat), rewriter_func_(rewriter_func) {}
+
+  template <typename PatternType>
+  static Function Run(PatternType pat,
+                      TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> 
rewriter_func,
+                      Function func) {
+    ExprPatternRewriter rewriter(pat, rewriter_func);
+    func = Downcast<Function>(rewriter(func));
+    func = Downcast<Function>(RemoveAllUnused(func));
+    return func;
+  }
+
+  Expr VisitExpr_(const SeqExprNode* seq) override {
+    auto cache = bindings_;
+    SeqExpr prev = GetRef<SeqExpr>(seq);
+
+    StructuralEqual struct_equal;
+
+    while (true) {
+      SeqExpr next = 
Downcast<SeqExpr>(builder_->Normalize(ExprMutator::VisitExpr_(prev.get())));
+      if (struct_equal(prev, next)) {
+        return std::move(next);
+      }
+
+      // Canonicalization may result in two previously-different
+      // expressions being recognized as identical.  Elimination of
+      // common subexpressions may result in trival var-to-var
+      // bindings that can be canonicalized.  Therefore, iterate the
+      // simplification steps until converged.
+      while (true) {
+        auto start_of_loop = next;
+        next = Downcast<SeqExpr>(CanonicalizeBindings(next));
+        next = Downcast<SeqExpr>(EliminateCommonSubexpr(next));
+        next = Downcast<SeqExpr>(RemoveAllUnused(next));
+        if (struct_equal(start_of_loop, next)) {
+          break;
+        }
+      }
+
+      if (struct_equal(prev, next)) {
+        return std::move(next);
+      }
+
+      // Reset all knowledge of bindings that were collected from
+      // this SeqExpr.  The collected bindings are only after
+      // the point where they were collected, and we are repeating
+      // the mutation of this SeqExpr.
+      bindings_ = cache;
+      prev = next;
+    }
+  }
+
+  void VisitBinding_(const VarBindingNode* binding) override {
+    auto expr = VisitExpr(binding->value);
+    bindings_.Set(binding->var, expr);
+    ReEmitBinding(binding, expr);
+  }
+
+  Expr VisitExpr(const Expr& expr) override {
+    auto node = ExprMutator::VisitExpr(expr);
+
+    std::vector<DFPattern> matches_top_level;
+    if (auto rewritten = TryRewrite(node, pattern_, &matches_top_level)) {
+      return builder_->Normalize(rewritten.value());
+    }
+
+    return node;
+  }
+
+ private:
+  Optional<Expr> TryRewrite(const Expr& expr, const DFPattern& pattern,
+                            std::vector<DFPattern>* matches_top_level) {
+    ICHECK(matches_top_level);
+
+    // Special handling if the user-supplied pattern is a `OrPattern`.
+    // While the `ExtractMatchedExpr` can handle match the

Review Comment:
   Looks like a typo. I assume "can handle match the" is supposed to be "can handle matching the", correct?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to