Lunderberg commented on code in PR #16599:
URL: https://github.com/apache/tvm/pull/16599#discussion_r1499598652


##########
src/relax/transform/eliminate_common_subexpr.cc:
##########
@@ -20,223 +20,162 @@
 
 /*!
  * \file tvm/relax/transform/eliminate_common_subexpr.cc
- * \brief Eliminrate common subexpression pass.
+ * \brief Eliminate common subexpression pass.
  *
  * Currently it removes common subexpressions within a Function.
  */
+#include <tvm/relax/analysis.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/transform.h>
 #include <tvm/relax/utils.h>
 
-#include "utils.h"
+#include "../../support/utils.h"
 
 namespace tvm {
 namespace relax {
-
-// Checks if a given expression contains an impure subexpression
-// Caches the results of checks to avoid revisiting subexpressions
-class ImpurityDetector : public ExprVisitor {
- public:
-  bool Detect(const Expr& expr) {
-    impure_found_ = false;
-    VisitExpr(expr);
-    return impure_found_;
+namespace {
+/* \brief Lookup key for subexpression replacements
+ *
+ * The lookup key must contain the expression being bound, along with
+ * the struct info used for a match cast, if applicable.  Using
+ * `MatchCast` with StructuralEqual and StructuralHash would be almost
+ * correct, but acts as a point of definition for symbolic variables
+ * within the output struct info.  As a result, it would erroneously
+ * de-duplicate `R.match_cast(A, R.Tensor([m,n]))` and
+ * `R.match_cast(A, R.Tensor([p,q]))`, even though they define
+ * different symbolic variables.
+ */
+struct ReplacementKey {
+  tvm::relax::Expr bound_value;
+  tvm::Optional<tvm::relax::StructInfo> match_cast = tvm::NullOpt;
+
+  explicit ReplacementKey(const tvm::relax::Binding& binding)
+      : bound_value(GetBoundValue(binding)) {
+    if (const auto* ptr = binding.as<tvm::relax::MatchCastNode>()) {
+      match_cast = ptr->struct_info;
+    }
   }
 
-  void VisitExpr(const Expr& expr) {
-    // already checked: do not revisit
-    if (purity_map_.count(expr)) {
-      impure_found_ = impure_found_ || !purity_map_.at(expr);
-      return;
-    }
+  friend bool operator==(const ReplacementKey& a, const ReplacementKey& b) {
+    tvm::StructuralEqual eq;
+    return eq(a.bound_value, b.bound_value) && eq(a.match_cast, b.match_cast);
+  }
+};
 
-    // in principle, we could stop checking once we find an impurity,
-    // but not doing so lets us fully populate the cache
+}  // namespace
+}  // namespace relax
+}  // namespace tvm
 
-    // store the previous state so we could assess the purity of this 
subexpression alone
-    bool prev_state = impure_found_;
-    impure_found_ = false;
-    ExprVisitor::VisitExpr(expr);
-    // if impure_found_ remains false, then the expression is pure
-    purity_map_[expr] = !impure_found_;
-    impure_found_ = prev_state || impure_found_;
+/* \brief Definition of std::hash<ReplacementKey>
+ *
+ * Specialization of std::hash must occur outside of tvm::relax
+ * namespace, and before its usage in the constructor of
+ * `CommonSubexprEliminator`.
+ */
+template <>
+struct std::hash<tvm::relax::ReplacementKey> {
+  std::size_t operator()(const tvm::relax::ReplacementKey& key) const {
+    tvm::StructuralHash hasher;
+    return tvm::support::HashCombine(hasher(key.bound_value), 
hasher(key.match_cast));
   }
+};
 
-  void VisitExpr_(const CallNode* call) {
-    // the only possible impurities can come from call nodes
-    bool is_impure = IsImpureCall(GetRef<Call>(call));
-    impure_found_ = impure_found_ || is_impure;
-    ExprVisitor::VisitExpr_(call);
-  }
+namespace tvm {
+namespace relax {
 
- private:
-  bool impure_found_ = false;
-  std::unordered_map<Expr, bool, StructuralHash, StructuralEqual> purity_map_;
-};
+namespace {
 
-class SubexprCounter : public ExprVisitor {
+class CommonSubexprEliminator : public ExprMutator {
  public:
-  static std::unordered_map<Expr, int, StructuralHash, StructuralEqual> 
Count(const Expr& expr) {
-    SubexprCounter visitor;
-    visitor(expr);
-    return visitor.count_map_;
+  explicit CommonSubexprEliminator(bool call_only = false) : 
call_only_(call_only) {}
+
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override {
+    auto cache_exprs = expr_replacements_;
+    auto cache_vars = var_remap_;
+    auto output = ExprMutator::VisitBindingBlock_(block);
+    expr_replacements_ = cache_exprs;
+    var_remap_ = cache_vars;
+    return output;
   }
 
-  // overriding VisitExpr ensures we do this for every subexpression
-  void VisitExpr(const Expr& e) override {
-    // Cases we ignore because we will not substitute them:
-    // 1. Vars of all kinds
-    // 2. Op nodes (nothing we can do)
-    // 3. PrimValue nodes (not much benefit from binding to a var)
-    // 4. StringImm nodes (not much benefit from binding to a var)
-    // 5. Scalar constants (not much benefit from binding to a var)
-    // 6. Shape expressions (exist to hold several PrimValue objects)
-    // 7. DataType nodes (no need to modify dtype nodes)
-    if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
-          e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
-          e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() ||
-          e->IsInstance<ShapeExprNode>() || e->IsInstance<ExternFuncNode>() ||
-          e->IsInstance<ConstantNode>() || e->IsInstance<DataTypeImmNode>())) {
-      // also if e has an impure subexpression, we will not deduplicate it
-      if (!impurity_detector_.Detect(e)) {
-        int count = 0;
-        if (count_map_.count(e)) {
-          count = count_map_.at(e);
-        }
-        count_map_[e] = count + 1;
+  void VisitBinding(const Binding& binding) override {
+    Expr bound_value = VisitExpr(GetBoundValue(binding));
+
+    Binding output_binding = [&]() -> Binding {
+      if (binding.as<VarBindingNode>()) {
+        return VarBinding(binding->var, bound_value);
+      } else if (auto match_cast = binding.as<MatchCastNode>()) {
+        return MatchCast(binding->var, bound_value, match_cast->struct_info);
+      } else {
+        LOG(FATAL) << "Binding must be either VarBinding or MatchCast, "
+                   << "but was " << binding->GetTypeKey();
       }
-    }
+    }();
 
-    // Only visit the interior of objects that we might still keep
-    // around.  Otherwise, double-counting these would lead to extra
-    // variable bindings.
-    //
-    // Before:
-    //     y = f(a+b)
-    //     z = f(a+b)
-    //
-    // Expected:
-    //     y = f(a+b)  // De-duped from (y==z)
-    //     z = y
-    //
-    // Erroneous output:
-    //     c = a+b    // Incorrect, a+b only has a single usage.
-    //     y = f(c)   // De-duped from
-    //     z = y
-    //
-    if (auto it = count_map_.find(e); it == count_map_.end() || it->second < 
2) {
-      ExprVisitor::VisitExpr(e);
-    }
-  }
+    ReplacementKey lookup_key(output_binding);
 
-  // do not visit inner functions: we will do CSE within those
-  void VisitExpr_(const FunctionNode* func) override {}
+    if (call_only_ && !bound_value->IsInstance<relax::CallNode>()) {
+      VLOG(1) << "Since call_only_ is true, it is forbidden to de-duplicate " 
<< bound_value;
 
-  // we are not going to do replacements inside struct info to avoid binding 
lots of reused shapes
-  void VisitExprDepStructInfoField(const StructInfo& struct_info) override {}
+    } else if (ContainsImpureCall(bound_value)) {
+      VLOG(1) << "Since the expression is impure, cannot de-duplicate " << 
bound_value;
 
- private:
-  std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_;
-  ImpurityDetector impurity_detector_;
-};
+    } else if (auto it = expr_replacements_.find(lookup_key); it != 
expr_replacements_.end()) {
+      VLOG(1) << "Value " << bound_value << " has previously been bound as " 
<< it->second
+              << ".  The duplicate binding of this value to " << binding->var
+              << " will be replaced with a trivial binding, "
+              << "and occurrences of " << binding->var << " will be replaced 
with " << it->second;
+      output_binding = VarBinding(binding->var, it->second);
+      var_remap_.insert({binding->var->vid, it->second});
 
-class CommonSubexprEliminator : public ExprMutator {
- public:
-  explicit CommonSubexprEliminator(
-      std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map,
-      bool call_only = false)
-      : count_map_(std::move(count_map)), call_only_(call_only) {}
+    } else {
+      VLOG(1) << "Value " << bound_value << " is bound to " << binding->var
+              << " and may be de-duplicated if it occurs again.";
 
-  // overriding here ensures we visit every subexpression
-  Expr VisitExpr(const Expr& e) override {
-    if (call_only_ && !e->IsInstance<CallNode>()) {
-      return ExprMutator::VisitExpr(e);
-    }
-    if (count_map_.count(e) && count_map_.at(e) > 1) {
-      // if we already have a mapping for it, get it
-      if (replacements_.count(e)) {
-        return replacements_.at(e);
-      }
-      // Otherwise, insert a new binding for the current expression.
-      // Visit before emitting to do inner replacements
-      Expr new_e = ExprMutator::VisitExpr(e);
-      Var v = builder_->Emit(new_e);
-      replacements_[e] = v;
-      return v;
+      expr_replacements_.insert({lookup_key, binding->var});
     }
-    return ExprMutator::VisitExpr(e);
-  }
 
-  // we are not going to do replacements inside struct info to avoid binding 
lots of reused shapes
-  StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) 
override {
-    return struct_info;
+    builder_->EmitNormalized(output_binding);
   }
 
   Expr VisitExpr_(const FunctionNode* op) override {
-    Function func = GetRef<Function>(op);
-
-    auto cache = SubexprCounter::Count(op->body);
-    std::swap(cache, count_map_);
-    Expr output = ExprMutator::VisitExpr_(op);
-    std::swap(cache, count_map_);
-
-    return output;
-  }
-
-  void VisitBinding_(const VarBindingNode* binding) override {
-    // no need to visit var def because the struct info isn't going to change
-    Expr new_value = RegisterBoundValue(binding->var, binding->value);
-
-    if (new_value.same_as(binding->value)) {
-      builder_->EmitNormalized(GetRef<VarBinding>(binding));
+    // If we have accumulated any state, visit the function in a fresh
+    // copy of the mutator, to avoid replacing a child-scope
+    // expression with a parent-scope binding, or vice versa.
+    if (expr_replacements_.size() || var_remap_.size()) {
+      return VisitWithCleanScope(GetRef<Expr>(op));
     } else {
-      // no need to renormalize new_value because all replacements are with 
vars
-      builder_->EmitNormalized(VarBinding(binding->var, new_value, 
binding->span));
+      return ExprMutator::VisitExpr_(op);
     }
   }
 
-  void VisitBinding_(const MatchCastNode* binding) override {
-    // no need to visit var def because the struct info isn't going to change
-    Expr new_value = RegisterBoundValue(binding->var, binding->value);
-
-    // re-emit old binding if nothing changes
-    if (new_value.same_as(binding->value)) {
-      builder_->EmitNormalized(GetRef<MatchCast>(binding));
+  Expr VisitExpr_(const IfNode* op) override {
+    Expr cond = VisitExpr(op->cond);
+    Expr true_branch = VisitWithCleanScope(op->true_branch);

Review Comment:
   Good call.  I've updated the PR to propagate bindings from before an `if/else` 
branch into the body of either branch.  This functionality has three new unit 
tests, to validate (1) bindings before a branch are de-duped inside the branch, 
(2) bindings within a branch are not de-duped after the branch, and (3) 
bindings within a branch are not de-duped into its sibling branch.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to