Lunderberg commented on code in PR #16599:
URL: https://github.com/apache/tvm/pull/16599#discussion_r1499585431


##########
src/relax/transform/eliminate_common_subexpr.cc:
##########
@@ -20,223 +20,162 @@
 
 /*!
  * \file tvm/relax/transform/eliminate_common_subexpr.cc
- * \brief Eliminrate common subexpression pass.
+ * \brief Eliminate common subexpression pass.
  *
  * Currently it removes common subexpressions within a Function.
  */
+#include <tvm/relax/analysis.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/transform.h>
 #include <tvm/relax/utils.h>
 
-#include "utils.h"
+#include "../../support/utils.h"
 
 namespace tvm {
 namespace relax {
-
-// Checks if a given expression contains an impure subexpression
-// Caches the results of checks to avoid revisiting subexpressions
-class ImpurityDetector : public ExprVisitor {
- public:
-  bool Detect(const Expr& expr) {
-    impure_found_ = false;
-    VisitExpr(expr);
-    return impure_found_;
+namespace {
+/* \brief Lookup key for subexpression replacements
+ *
+ * The lookup key must contain the expression being bound, along with
+ * the struct info used for a match cast, if applicable.  Using
+ * `MatchCast` with StructuralEqual and StructuralHash would be almost
+ * correct, but acts as a point of definition for symbolic variables
+ * within the output struct info.  As a result, it would erroneously
+ * de-duplicate `R.match_cast(A, R.Tensor([m,n]))` and
+ * `R.match_cast(A, R.Tensor([p,q]))`, even though they define
+ * different symbolic variables.
+ */
+struct ReplacementKey {
+  tvm::relax::Expr bound_value;
+  tvm::Optional<tvm::relax::StructInfo> match_cast = tvm::NullOpt;
+
+  explicit ReplacementKey(const tvm::relax::Binding& binding)
+      : bound_value(GetBoundValue(binding)) {
+    if (const auto* ptr = binding.as<tvm::relax::MatchCastNode>()) {
+      match_cast = ptr->struct_info;
+    }
   }
 
-  void VisitExpr(const Expr& expr) {
-    // already checked: do not revisit
-    if (purity_map_.count(expr)) {
-      impure_found_ = impure_found_ || !purity_map_.at(expr);
-      return;
-    }
+  friend bool operator==(const ReplacementKey& a, const ReplacementKey& b) {
+    tvm::StructuralEqual eq;
+    return eq(a.bound_value, b.bound_value) && eq(a.match_cast, b.match_cast);
+  }
+};
 
-    // in principle, we could stop checking once we find an impurity,
-    // but not doing so lets us fully populate the cache
+}  // namespace
+}  // namespace relax
+}  // namespace tvm
 
-    // store the previous state so we could assess the purity of this 
subexpression alone
-    bool prev_state = impure_found_;
-    impure_found_ = false;
-    ExprVisitor::VisitExpr(expr);
-    // if impure_found_ remains false, then the expression is pure
-    purity_map_[expr] = !impure_found_;
-    impure_found_ = prev_state || impure_found_;
+/* \brief Definition of std::hash<ReplacementKey>
+ *
+ * Specialization of std::hash must occur outside of tvm::relax
+ * namespace, and before its usage in the constructor of
+ * `CommonSubexprEliminator`.
+ */
+template <>
+struct std::hash<tvm::relax::ReplacementKey> {
+  std::size_t operator()(const tvm::relax::ReplacementKey& key) const {
+    tvm::StructuralHash hasher;
+    return tvm::support::HashCombine(hasher(key.bound_value), 
hasher(key.match_cast));
   }
+};
 
-  void VisitExpr_(const CallNode* call) {
-    // the only possible impurities can come from call nodes
-    bool is_impure = IsImpureCall(GetRef<Call>(call));
-    impure_found_ = impure_found_ || is_impure;
-    ExprVisitor::VisitExpr_(call);
-  }
+namespace tvm {
+namespace relax {
 
- private:
-  bool impure_found_ = false;
-  std::unordered_map<Expr, bool, StructuralHash, StructuralEqual> purity_map_;
-};
+namespace {
 
-class SubexprCounter : public ExprVisitor {
+class CommonSubexprEliminator : public ExprMutator {
  public:
-  static std::unordered_map<Expr, int, StructuralHash, StructuralEqual> 
Count(const Expr& expr) {
-    SubexprCounter visitor;
-    visitor(expr);
-    return visitor.count_map_;
+  explicit CommonSubexprEliminator(bool call_only = false) : 
call_only_(call_only) {}
+
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override {
+    auto cache_exprs = expr_replacements_;
+    auto cache_vars = var_remap_;
+    auto output = ExprMutator::VisitBindingBlock_(block);
+    expr_replacements_ = cache_exprs;
+    var_remap_ = cache_vars;
+    return output;
   }
 
-  // overriding VisitExpr ensures we do this for every subexpression
-  void VisitExpr(const Expr& e) override {
-    // Cases we ignore because we will not substitute them:
-    // 1. Vars of all kinds
-    // 2. Op nodes (nothing we can do)
-    // 3. PrimValue nodes (not much benefit from binding to a var)
-    // 4. StringImm nodes (not much benefit from binding to a var)
-    // 5. Scalar constants (not much benefit from binding to a var)
-    // 6. Shape expressions (exist to hold several PrimValue objects)
-    // 7. DataType nodes (no need to modify dtype nodes)
-    if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
-          e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
-          e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() ||
-          e->IsInstance<ShapeExprNode>() || e->IsInstance<ExternFuncNode>() ||
-          e->IsInstance<ConstantNode>() || e->IsInstance<DataTypeImmNode>())) {
-      // also if e has an impure subexpression, we will not deduplicate it
-      if (!impurity_detector_.Detect(e)) {
-        int count = 0;
-        if (count_map_.count(e)) {
-          count = count_map_.at(e);
-        }
-        count_map_[e] = count + 1;
+  void VisitBinding(const Binding& binding) override {
+    Expr bound_value = VisitExpr(GetBoundValue(binding));
+
+    Binding output_binding = [&]() -> Binding {
+      if (binding.as<VarBindingNode>()) {
+        return VarBinding(binding->var, bound_value);
+      } else if (auto match_cast = binding.as<MatchCastNode>()) {
+        return MatchCast(binding->var, bound_value, match_cast->struct_info);
+      } else {
+        LOG(FATAL) << "Binding must be either VarBinding or MatchCast, "
+                   << "but was " << binding->GetTypeKey();
       }
-    }
+    }();
 
-    // Only visit the interior of objects that we might still keep
-    // around.  Otherwise, double-counting these would lead to extra
-    // variable bindings.
-    //
-    // Before:
-    //     y = f(a+b)
-    //     z = f(a+b)
-    //
-    // Expected:
-    //     y = f(a+b)  // De-duped from (y==z)
-    //     z = y
-    //
-    // Erroneous output:
-    //     c = a+b    // Incorrect, a+b only has a single usage.
-    //     y = f(c)   // De-duped from
-    //     z = y
-    //
-    if (auto it = count_map_.find(e); it == count_map_.end() || it->second < 
2) {
-      ExprVisitor::VisitExpr(e);
-    }
-  }
+    ReplacementKey lookup_key(output_binding);
 
-  // do not visit inner functions: we will do CSE within those
-  void VisitExpr_(const FunctionNode* func) override {}
+    if (call_only_ && !bound_value->IsInstance<relax::CallNode>()) {
+      VLOG(1) << "Since call_only_ is true, it is forbidden to de-duplicate " 
<< bound_value;
 
-  // we are not going to do replacements inside struct info to avoid binding 
lots of reused shapes
-  void VisitExprDepStructInfoField(const StructInfo& struct_info) override {}
+    } else if (ContainsImpureCall(bound_value)) {
+      VLOG(1) << "Since the expression is impure, cannot de-duplicate " << 
bound_value;
 
- private:
-  std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_;
-  ImpurityDetector impurity_detector_;
-};
+    } else if (auto it = expr_replacements_.find(lookup_key); it != 
expr_replacements_.end()) {
+      VLOG(1) << "Value " << bound_value << " has previously been bound as " 
<< it->second
+              << ".  The duplicate binding of this value to " << binding->var
+              << " will be replaced with a trivial binding, "
+              << "and occurrences of " << binding->var << " will be replaced 
with " << it->second;
+      output_binding = VarBinding(binding->var, it->second);
+      var_remap_.insert({binding->var->vid, it->second});
 
-class CommonSubexprEliminator : public ExprMutator {
- public:
-  explicit CommonSubexprEliminator(
-      std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map,
-      bool call_only = false)
-      : count_map_(std::move(count_map)), call_only_(call_only) {}
+    } else {
+      VLOG(1) << "Value " << bound_value << " is bound to " << binding->var
+              << " and may be de-duplicated if it occurs again.";
 
-  // overriding here ensures we visit every subexpression
-  Expr VisitExpr(const Expr& e) override {
-    if (call_only_ && !e->IsInstance<CallNode>()) {
-      return ExprMutator::VisitExpr(e);
-    }
-    if (count_map_.count(e) && count_map_.at(e) > 1) {
-      // if we already have a mapping for it, get it
-      if (replacements_.count(e)) {
-        return replacements_.at(e);
-      }
-      // Otherwise, insert a new binding for the current expression.
-      // Visit before emitting to do inner replacements
-      Expr new_e = ExprMutator::VisitExpr(e);
-      Var v = builder_->Emit(new_e);
-      replacements_[e] = v;
-      return v;
+      expr_replacements_.insert({lookup_key, binding->var});
     }
-    return ExprMutator::VisitExpr(e);
-  }
 
-  // we are not going to do replacements inside struct info to avoid binding 
lots of reused shapes
-  StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) 
override {
-    return struct_info;
+    builder_->EmitNormalized(output_binding);
   }
 
   Expr VisitExpr_(const FunctionNode* op) override {
-    Function func = GetRef<Function>(op);
-
-    auto cache = SubexprCounter::Count(op->body);
-    std::swap(cache, count_map_);
-    Expr output = ExprMutator::VisitExpr_(op);
-    std::swap(cache, count_map_);
-
-    return output;
-  }
-
-  void VisitBinding_(const VarBindingNode* binding) override {
-    // no need to visit var def because the struct info isn't going to change
-    Expr new_value = RegisterBoundValue(binding->var, binding->value);
-
-    if (new_value.same_as(binding->value)) {
-      builder_->EmitNormalized(GetRef<VarBinding>(binding));
+    // If we have accumulated any state, visit the function in a fresh
+    // copy of the mutator, to avoid replacing a child-scope
+    // expression with a parent-scope binding, or vice versa.
+    if (expr_replacements_.size() || var_remap_.size()) {
+      return VisitWithCleanScope(GetRef<Expr>(op));

Review Comment:
   Hmm.  From a lowering standpoint, my main concern would be to avoid 
introducing a captured variable.  If a user has an inner function without 
closure variables, it could be surprising if CSE prevents it from being hoisted 
out of the local function.
   
   (Selfishly, I have a few plans to simplify `LambdaLift` when I have the 
time.  Currently, `LambdaLift` handles lifting of inner functions regardless of 
whether there are closure variables, which is a bit tricky to track.  I'd like 
to split that out into separate `HoistClosureVariablesToParams` and 
`HoistInnerFunctions` passes.  The first would provide explicit arguments for 
any closure variables, and the second would hoist the inner function out to the 
`IRModule`.  If CSE is applied in-between those two passes, I'd like to avoid 
having it re-introduce closure variables.)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to