slyubomirsky commented on code in PR #16732: URL: https://github.com/apache/tvm/pull/16732#discussion_r1538340938
########## src/relax/ir/dataflow_matcher.cc: ########## @@ -1140,34 +1071,173 @@ class PatternRewriter : ExprMutator { return block; } - /*! \brief The pattern for rewriting call nodes */ - Optional<DFPattern> pattern_; /*! \brief The pattern constraint contexts for rewriting dataflow blocks */ - Optional<PatternContext> ctx_; + PatternContext ctx_; /*! * \brief The user-provided rewriter function. Its signature and semantics are: - * - (Call, Map<DFPattern, Expr>) -> Call for call node rewriting. Given the matched - * call node and the map of patterns and matched expressions, it should return a new call node - * to replace the original one or the original matched call node as is. - * - (Map<DFPattern, Var>, Map<Var, Expr>) -> Map<Var, Expr> for dataflow block rewriting. - * Given the map of patterns and corresponding variables (bound variables or parameters), - * it should return a map that specifies new values for matched bound variables. It can refer + * + * - (Map<DFPattern, Var>, Map<Var, Expr>) -> Map<Var, Expr> + * + * Given the map of patterns and corresponding variables (bound + * variables or parameters), it should return a map that + * specifies new values for matched bound variables. It can refer * to the passed bindings to create the replacement expressions. */ - PackedFunc rewriter_func_; - std::unordered_set<const VarNode*> params_; + TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> rewriter_func_; +}; + +/*! + * \brief Apply pattern matching to each expression, replacing + * matches with the output of a user-provided rewriter function. + */ +class ExprPatternRewriter : ExprMutator { + public: + using ExprMutator::VisitBindingBlock_; + using ExprMutator::VisitExpr_; + + ExprPatternRewriter(DFPattern pat, + TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter_func) + : pattern_(pat), rewriter_func_(rewriter_func) {} + + template <typename PatternType> + static Function Run(PatternType pat, + TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter_func, + Function func) { + ExprPatternRewriter rewriter(pat, rewriter_func); + func = Downcast<Function>(rewriter(func)); + func = Downcast<Function>(RemoveAllUnused(func)); + return func; + } + + Expr VisitExpr_(const SeqExprNode* seq) override { + auto cache = bindings_; + SeqExpr prev = GetRef<SeqExpr>(seq); + + StructuralEqual struct_equal; + + while (true) { + SeqExpr next = Downcast<SeqExpr>(builder_->Normalize(ExprMutator::VisitExpr_(prev.get()))); + if (struct_equal(prev, next)) { + return std::move(next); + } + + // Canonicalization may result in two previously-different + // expressions being recognized as identical. Elimination of + // common subexpressions may result in trival var-to-var + // bindings that can be canonicalized. Therefore, iterate the + // simplification steps until converged. + while (true) { + auto start_of_loop = next; + next = Downcast<SeqExpr>(CanonicalizeBindings(next)); + next = Downcast<SeqExpr>(EliminateCommonSubexpr(next)); + next = Downcast<SeqExpr>(RemoveAllUnused(next)); + if (struct_equal(start_of_loop, next)) { + break; + } + } + + if (struct_equal(prev, next)) { + return std::move(next); + } + + // Reset all knowledge of bindings that were collected from + // this SeqExpr. The collected bindings are only after + // the point where they were collected, and we are repeating + // the mutation of this SeqExpr. + bindings_ = cache; + prev = next; + } + } + + void VisitBinding_(const VarBindingNode* binding) override { + auto expr = VisitExpr(binding->value); + bindings_.Set(binding->var, expr); + ReEmitBinding(binding, expr); + } + + Expr VisitExpr(const Expr& expr) override { + auto node = ExprMutator::VisitExpr(expr); + + std::vector<DFPattern> matches_top_level; + if (auto rewritten = TryRewrite(node, pattern_, &matches_top_level)) { + return builder_->Normalize(rewritten.value()); + } + + return node; + } + + private: + Optional<Expr> TryRewrite(const Expr& expr, const DFPattern& pattern, + std::vector<DFPattern>* matches_top_level) { + ICHECK(matches_top_level); + + // Special handling if the user-supplied pattern is a `OrPattern`. + // While the `ExtractMatchedExpr` can handle match the Review Comment: Looks like a typo. I assume it's supposed to be "handle matching," correct? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org