slyubomirsky commented on code in PR #16599: URL: https://github.com/apache/tvm/pull/16599#discussion_r1498553103
########## src/relax/transform/eliminate_common_subexpr.cc: ########## @@ -20,223 +20,162 @@ /*! * \file tvm/relax/transform/eliminate_common_subexpr.cc - * \brief Eliminrate common subexpression pass. + * \brief Eliminate common subexpression pass. * * Currently it removes common subexpressions within a Function. */ +#include <tvm/relax/analysis.h> #include <tvm/relax/expr_functor.h> #include <tvm/relax/transform.h> #include <tvm/relax/utils.h> -#include "utils.h" +#include "../../support/utils.h" namespace tvm { namespace relax { - -// Checks if a given expression contains an impure subexpression -// Caches the results of checks to avoid revisiting subexpressions -class ImpurityDetector : public ExprVisitor { - public: - bool Detect(const Expr& expr) { - impure_found_ = false; - VisitExpr(expr); - return impure_found_; +namespace { +/* \brief Lookup key for subexpression replacements + * + * The lookup key must contain the expression being bound, along with + * the struct info used for a match cast, if applicable. Using + * `MatchCast` with StructuralEqual and StructuralHash would be almost + * correct, but acts as a point of definition for symbolic variables + * within the output struct info. As a result, it would erroneously + * de-duplicate `R.match_cast(A, R.Tensor([m,n]))` and + * `R.match_cast(A, R.Tensor([p,q]))`, even though they define + * different symbolic variables. 
+ */ +struct ReplacementKey { + tvm::relax::Expr bound_value; + tvm::Optional<tvm::relax::StructInfo> match_cast = tvm::NullOpt; + + explicit ReplacementKey(const tvm::relax::Binding& binding) + : bound_value(GetBoundValue(binding)) { + if (const auto* ptr = binding.as<tvm::relax::MatchCastNode>()) { + match_cast = ptr->struct_info; + } } - void VisitExpr(const Expr& expr) { - // already checked: do not revisit - if (purity_map_.count(expr)) { - impure_found_ = impure_found_ || !purity_map_.at(expr); - return; - } + friend bool operator==(const ReplacementKey& a, const ReplacementKey& b) { + tvm::StructuralEqual eq; + return eq(a.bound_value, b.bound_value) && eq(a.match_cast, b.match_cast); + } +}; - // in principle, we could stop checking once we find an impurity, - // but not doing so lets us fully populate the cache +} // namespace +} // namespace relax +} // namespace tvm - // store the previous state so we could assess the purity of this subexpression alone - bool prev_state = impure_found_; - impure_found_ = false; - ExprVisitor::VisitExpr(expr); - // if impure_found_ remains false, then the expression is pure - purity_map_[expr] = !impure_found_; - impure_found_ = prev_state || impure_found_; +/* \brief Definition of std::hash<ReplacementKey> + * + * Specialization of std::hash must occur outside of tvm::relax + * namespace, and before its usage in the constructor of + * `CommonSubexprEliminator`. 
+ */ +template <> +struct std::hash<tvm::relax::ReplacementKey> { + std::size_t operator()(const tvm::relax::ReplacementKey& key) const { + tvm::StructuralHash hasher; + return tvm::support::HashCombine(hasher(key.bound_value), hasher(key.match_cast)); } +}; - void VisitExpr_(const CallNode* call) { - // the only possible impurities can come from call nodes - bool is_impure = IsImpureCall(GetRef<Call>(call)); - impure_found_ = impure_found_ || is_impure; - ExprVisitor::VisitExpr_(call); - } +namespace tvm { +namespace relax { - private: - bool impure_found_ = false; - std::unordered_map<Expr, bool, StructuralHash, StructuralEqual> purity_map_; -}; +namespace { -class SubexprCounter : public ExprVisitor { +class CommonSubexprEliminator : public ExprMutator { public: - static std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(const Expr& expr) { - SubexprCounter visitor; - visitor(expr); - return visitor.count_map_; + explicit CommonSubexprEliminator(bool call_only = false) : call_only_(call_only) {} + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override { + auto cache_exprs = expr_replacements_; + auto cache_vars = var_remap_; + auto output = ExprMutator::VisitBindingBlock_(block); + expr_replacements_ = cache_exprs; + var_remap_ = cache_vars; + return output; } - // overriding VisitExpr ensures we do this for every subexpression - void VisitExpr(const Expr& e) override { - // Cases we ignore because we will not substitute them: - // 1. Vars of all kinds - // 2. Op nodes (nothing we can do) - // 3. PrimValue nodes (not much benefit from binding to a var) - // 4. StringImm nodes (not much benefit from binding to a var) - // 5. Scalar constants (not much benefit from binding to a var) - // 6. Shape expressions (exist to hold several PrimValue objects) - // 7. 
DataType nodes (no need to modify dtype nodes) - if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() || - e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() || - e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() || - e->IsInstance<ShapeExprNode>() || e->IsInstance<ExternFuncNode>() || - e->IsInstance<ConstantNode>() || e->IsInstance<DataTypeImmNode>())) { - // also if e has an impure subexpression, we will not deduplicate it - if (!impurity_detector_.Detect(e)) { - int count = 0; - if (count_map_.count(e)) { - count = count_map_.at(e); - } - count_map_[e] = count + 1; + void VisitBinding(const Binding& binding) override { + Expr bound_value = VisitExpr(GetBoundValue(binding)); + + Binding output_binding = [&]() -> Binding { + if (binding.as<VarBindingNode>()) { + return VarBinding(binding->var, bound_value); + } else if (auto match_cast = binding.as<MatchCastNode>()) { + return MatchCast(binding->var, bound_value, match_cast->struct_info); + } else { + LOG(FATAL) << "Binding must be either VarBinding or MatchCast, " + << "but was " << binding->GetTypeKey(); } - } + }(); - // Only visit the interior of objects that we might still keep - // around. Otherwise, double-counting these would lead to extra - // variable bindings. - // - // Before: - // y = f(a+b) - // z = f(a+b) - // - // Expected: - // y = f(a+b) // De-duped from (y==z) - // z = y - // - // Erroneous output: - // c = a+b // Incorrect, a+b only has a single usage. 
- // y = f(c) // De-duped from - // z = y - // - if (auto it = count_map_.find(e); it == count_map_.end() || it->second < 2) { - ExprVisitor::VisitExpr(e); - } - } + ReplacementKey lookup_key(output_binding); - // do not visit inner functions: we will do CSE within those - void VisitExpr_(const FunctionNode* func) override {} + if (call_only_ && !bound_value->IsInstance<relax::CallNode>()) { + VLOG(1) << "Since call_only_ is true, it is forbidden to de-duplicate " << bound_value; - // we are not going to do replacements inside struct info to avoid binding lots of reused shapes - void VisitExprDepStructInfoField(const StructInfo& struct_info) override {} + } else if (ContainsImpureCall(bound_value)) { + VLOG(1) << "Since the expression is impure, cannot de-duplicate " << bound_value; - private: - std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_; - ImpurityDetector impurity_detector_; -}; + } else if (auto it = expr_replacements_.find(lookup_key); it != expr_replacements_.end()) { + VLOG(1) << "Value " << bound_value << " has previously been bound as " << it->second + << ". The duplicate binding of this value to " << binding->var + << " will be replaced with a trivial binding, " + << "and occurrences of " << binding->var << " will be replaced with " << it->second; + output_binding = VarBinding(binding->var, it->second); Review Comment: Why do we still need a binding if we will be remapping the uses of the var? Not a big deal since `CanonicalizeBindings` and `DeadCodeElimination` would clean it up. ########## src/relax/transform/eliminate_common_subexpr.cc: ########## @@ -20,223 +20,162 @@ /*! * \file tvm/relax/transform/eliminate_common_subexpr.cc - * \brief Eliminrate common subexpression pass. + * \brief Eliminate common subexpression pass. * * Currently it removes common subexpressions within a Function. 
*/ +#include <tvm/relax/analysis.h> #include <tvm/relax/expr_functor.h> #include <tvm/relax/transform.h> #include <tvm/relax/utils.h> -#include "utils.h" +#include "../../support/utils.h" namespace tvm { namespace relax { - -// Checks if a given expression contains an impure subexpression -// Caches the results of checks to avoid revisiting subexpressions -class ImpurityDetector : public ExprVisitor { - public: - bool Detect(const Expr& expr) { - impure_found_ = false; - VisitExpr(expr); - return impure_found_; +namespace { +/* \brief Lookup key for subexpression replacements + * + * The lookup key must contain the expression being bound, along with + * the struct info used for a match cast, if applicable. Using + * `MatchCast` with StructuralEqual and StructuralHash would be almost + * correct, but acts as a point of definition for symbolic variables + * within the output struct info. As a result, it would erroneously + * de-duplicate `R.match_cast(A, R.Tensor([m,n]))` and + * `R.match_cast(A, R.Tensor([p,q]))`, even though they define + * different symbolic variables. 
+ */ +struct ReplacementKey { + tvm::relax::Expr bound_value; + tvm::Optional<tvm::relax::StructInfo> match_cast = tvm::NullOpt; + + explicit ReplacementKey(const tvm::relax::Binding& binding) + : bound_value(GetBoundValue(binding)) { + if (const auto* ptr = binding.as<tvm::relax::MatchCastNode>()) { + match_cast = ptr->struct_info; + } } - void VisitExpr(const Expr& expr) { - // already checked: do not revisit - if (purity_map_.count(expr)) { - impure_found_ = impure_found_ || !purity_map_.at(expr); - return; - } + friend bool operator==(const ReplacementKey& a, const ReplacementKey& b) { + tvm::StructuralEqual eq; + return eq(a.bound_value, b.bound_value) && eq(a.match_cast, b.match_cast); + } +}; - // in principle, we could stop checking once we find an impurity, - // but not doing so lets us fully populate the cache +} // namespace +} // namespace relax +} // namespace tvm - // store the previous state so we could assess the purity of this subexpression alone - bool prev_state = impure_found_; - impure_found_ = false; - ExprVisitor::VisitExpr(expr); - // if impure_found_ remains false, then the expression is pure - purity_map_[expr] = !impure_found_; - impure_found_ = prev_state || impure_found_; +/* \brief Definition of std::hash<ReplacementKey> + * + * Specialization of std::hash must occur outside of tvm::relax + * namespace, and before its usage in the constructor of + * `CommonSubexprEliminator`. 
+ */ +template <> +struct std::hash<tvm::relax::ReplacementKey> { + std::size_t operator()(const tvm::relax::ReplacementKey& key) const { + tvm::StructuralHash hasher; + return tvm::support::HashCombine(hasher(key.bound_value), hasher(key.match_cast)); } +}; - void VisitExpr_(const CallNode* call) { - // the only possible impurities can come from call nodes - bool is_impure = IsImpureCall(GetRef<Call>(call)); - impure_found_ = impure_found_ || is_impure; - ExprVisitor::VisitExpr_(call); - } +namespace tvm { +namespace relax { - private: - bool impure_found_ = false; - std::unordered_map<Expr, bool, StructuralHash, StructuralEqual> purity_map_; -}; +namespace { -class SubexprCounter : public ExprVisitor { +class CommonSubexprEliminator : public ExprMutator { public: - static std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(const Expr& expr) { - SubexprCounter visitor; - visitor(expr); - return visitor.count_map_; + explicit CommonSubexprEliminator(bool call_only = false) : call_only_(call_only) {} + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override { + auto cache_exprs = expr_replacements_; + auto cache_vars = var_remap_; + auto output = ExprMutator::VisitBindingBlock_(block); + expr_replacements_ = cache_exprs; + var_remap_ = cache_vars; + return output; } - // overriding VisitExpr ensures we do this for every subexpression - void VisitExpr(const Expr& e) override { - // Cases we ignore because we will not substitute them: - // 1. Vars of all kinds - // 2. Op nodes (nothing we can do) - // 3. PrimValue nodes (not much benefit from binding to a var) - // 4. StringImm nodes (not much benefit from binding to a var) - // 5. Scalar constants (not much benefit from binding to a var) - // 6. Shape expressions (exist to hold several PrimValue objects) - // 7. 
DataType nodes (no need to modify dtype nodes) - if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() || - e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() || - e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() || - e->IsInstance<ShapeExprNode>() || e->IsInstance<ExternFuncNode>() || - e->IsInstance<ConstantNode>() || e->IsInstance<DataTypeImmNode>())) { - // also if e has an impure subexpression, we will not deduplicate it - if (!impurity_detector_.Detect(e)) { - int count = 0; - if (count_map_.count(e)) { - count = count_map_.at(e); - } - count_map_[e] = count + 1; + void VisitBinding(const Binding& binding) override { + Expr bound_value = VisitExpr(GetBoundValue(binding)); + + Binding output_binding = [&]() -> Binding { + if (binding.as<VarBindingNode>()) { + return VarBinding(binding->var, bound_value); + } else if (auto match_cast = binding.as<MatchCastNode>()) { + return MatchCast(binding->var, bound_value, match_cast->struct_info); + } else { + LOG(FATAL) << "Binding must be either VarBinding or MatchCast, " + << "but was " << binding->GetTypeKey(); } - } + }(); - // Only visit the interior of objects that we might still keep - // around. Otherwise, double-counting these would lead to extra - // variable bindings. - // - // Before: - // y = f(a+b) - // z = f(a+b) - // - // Expected: - // y = f(a+b) // De-duped from (y==z) - // z = y - // - // Erroneous output: - // c = a+b // Incorrect, a+b only has a single usage. 
- // y = f(c) // De-duped from - // z = y - // - if (auto it = count_map_.find(e); it == count_map_.end() || it->second < 2) { - ExprVisitor::VisitExpr(e); - } - } + ReplacementKey lookup_key(output_binding); - // do not visit inner functions: we will do CSE within those - void VisitExpr_(const FunctionNode* func) override {} + if (call_only_ && !bound_value->IsInstance<relax::CallNode>()) { + VLOG(1) << "Since call_only_ is true, it is forbidden to de-duplicate " << bound_value; - // we are not going to do replacements inside struct info to avoid binding lots of reused shapes - void VisitExprDepStructInfoField(const StructInfo& struct_info) override {} + } else if (ContainsImpureCall(bound_value)) { + VLOG(1) << "Since the expression is impure, cannot de-duplicate " << bound_value; - private: - std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_; - ImpurityDetector impurity_detector_; -}; + } else if (auto it = expr_replacements_.find(lookup_key); it != expr_replacements_.end()) { + VLOG(1) << "Value " << bound_value << " has previously been bound as " << it->second + << ". 
The duplicate binding of this value to " << binding->var + << " will be replaced with a trivial binding, " + << "and occurrences of " << binding->var << " will be replaced with " << it->second; + output_binding = VarBinding(binding->var, it->second); + var_remap_.insert({binding->var->vid, it->second}); -class CommonSubexprEliminator : public ExprMutator { - public: - explicit CommonSubexprEliminator( - std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map, - bool call_only = false) - : count_map_(std::move(count_map)), call_only_(call_only) {} + } else { + VLOG(1) << "Value " << bound_value << " is bound to " << binding->var + << " and may be de-duplicated if it occurs again."; - // overriding here ensures we visit every subexpression - Expr VisitExpr(const Expr& e) override { - if (call_only_ && !e->IsInstance<CallNode>()) { - return ExprMutator::VisitExpr(e); - } - if (count_map_.count(e) && count_map_.at(e) > 1) { - // if we already have a mapping for it, get it - if (replacements_.count(e)) { - return replacements_.at(e); - } - // Otherwise, insert a new binding for the current expression. 
- // Visit before emitting to do inner replacements - Expr new_e = ExprMutator::VisitExpr(e); - Var v = builder_->Emit(new_e); - replacements_[e] = v; - return v; + expr_replacements_.insert({lookup_key, binding->var}); } - return ExprMutator::VisitExpr(e); - } - // we are not going to do replacements inside struct info to avoid binding lots of reused shapes - StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) override { - return struct_info; + builder_->EmitNormalized(output_binding); } Expr VisitExpr_(const FunctionNode* op) override { - Function func = GetRef<Function>(op); - - auto cache = SubexprCounter::Count(op->body); - std::swap(cache, count_map_); - Expr output = ExprMutator::VisitExpr_(op); - std::swap(cache, count_map_); - - return output; - } - - void VisitBinding_(const VarBindingNode* binding) override { - // no need to visit var def because the struct info isn't going to change - Expr new_value = RegisterBoundValue(binding->var, binding->value); - - if (new_value.same_as(binding->value)) { - builder_->EmitNormalized(GetRef<VarBinding>(binding)); + // If we have accumulated any state, visit the function in a fresh + // copy of the mutator, to avoid replacing a child-scope + // expression with a parent-scope binding, or vice versa. 
+ if (expr_replacements_.size() || var_remap_.size()) { + return VisitWithCleanScope(GetRef<Expr>(op)); } else { - // no need to renormalize new_value because all replacements are with vars - builder_->EmitNormalized(VarBinding(binding->var, new_value, binding->span)); + return ExprMutator::VisitExpr_(op); } } - void VisitBinding_(const MatchCastNode* binding) override { - // no need to visit var def because the struct info isn't going to change - Expr new_value = RegisterBoundValue(binding->var, binding->value); - - // re-emit old binding if nothing changes - if (new_value.same_as(binding->value)) { - builder_->EmitNormalized(GetRef<MatchCast>(binding)); + Expr VisitExpr_(const IfNode* op) override { + Expr cond = VisitExpr(op->cond); + Expr true_branch = VisitWithCleanScope(op->true_branch); Review Comment: I don't think visiting with a clean scope is the correct approach, since there are potential correct replacements that could come from the outer scope. I think the right approach would be to cache the current state and restore it after visiting. ########## src/relax/transform/eliminate_common_subexpr.cc: ########## @@ -20,223 +20,162 @@ /*! * \file tvm/relax/transform/eliminate_common_subexpr.cc - * \brief Eliminrate common subexpression pass. + * \brief Eliminate common subexpression pass. * * Currently it removes common subexpressions within a Function. 
*/ +#include <tvm/relax/analysis.h> #include <tvm/relax/expr_functor.h> #include <tvm/relax/transform.h> #include <tvm/relax/utils.h> -#include "utils.h" +#include "../../support/utils.h" namespace tvm { namespace relax { - -// Checks if a given expression contains an impure subexpression -// Caches the results of checks to avoid revisiting subexpressions -class ImpurityDetector : public ExprVisitor { - public: - bool Detect(const Expr& expr) { - impure_found_ = false; - VisitExpr(expr); - return impure_found_; +namespace { +/* \brief Lookup key for subexpression replacements + * + * The lookup key must contain the expression being bound, along with + * the struct info used for a match cast, if applicable. Using + * `MatchCast` with StructuralEqual and StructuralHash would be almost + * correct, but acts as a point of definition for symbolic variables + * within the output struct info. As a result, it would erroneously + * de-duplicate `R.match_cast(A, R.Tensor([m,n]))` and + * `R.match_cast(A, R.Tensor([p,q]))`, even though they define + * different symbolic variables. 
+ */ +struct ReplacementKey { + tvm::relax::Expr bound_value; + tvm::Optional<tvm::relax::StructInfo> match_cast = tvm::NullOpt; + + explicit ReplacementKey(const tvm::relax::Binding& binding) + : bound_value(GetBoundValue(binding)) { + if (const auto* ptr = binding.as<tvm::relax::MatchCastNode>()) { + match_cast = ptr->struct_info; + } } - void VisitExpr(const Expr& expr) { - // already checked: do not revisit - if (purity_map_.count(expr)) { - impure_found_ = impure_found_ || !purity_map_.at(expr); - return; - } + friend bool operator==(const ReplacementKey& a, const ReplacementKey& b) { + tvm::StructuralEqual eq; + return eq(a.bound_value, b.bound_value) && eq(a.match_cast, b.match_cast); + } +}; - // in principle, we could stop checking once we find an impurity, - // but not doing so lets us fully populate the cache +} // namespace +} // namespace relax +} // namespace tvm - // store the previous state so we could assess the purity of this subexpression alone - bool prev_state = impure_found_; - impure_found_ = false; - ExprVisitor::VisitExpr(expr); - // if impure_found_ remains false, then the expression is pure - purity_map_[expr] = !impure_found_; - impure_found_ = prev_state || impure_found_; +/* \brief Definition of std::hash<ReplacementKey> + * + * Specialization of std::hash must occur outside of tvm::relax + * namespace, and before its usage in the constructor of + * `CommonSubexprEliminator`. 
+ */ +template <> +struct std::hash<tvm::relax::ReplacementKey> { + std::size_t operator()(const tvm::relax::ReplacementKey& key) const { + tvm::StructuralHash hasher; + return tvm::support::HashCombine(hasher(key.bound_value), hasher(key.match_cast)); } +}; - void VisitExpr_(const CallNode* call) { - // the only possible impurities can come from call nodes - bool is_impure = IsImpureCall(GetRef<Call>(call)); - impure_found_ = impure_found_ || is_impure; - ExprVisitor::VisitExpr_(call); - } +namespace tvm { +namespace relax { - private: - bool impure_found_ = false; - std::unordered_map<Expr, bool, StructuralHash, StructuralEqual> purity_map_; -}; +namespace { -class SubexprCounter : public ExprVisitor { +class CommonSubexprEliminator : public ExprMutator { public: - static std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(const Expr& expr) { - SubexprCounter visitor; - visitor(expr); - return visitor.count_map_; + explicit CommonSubexprEliminator(bool call_only = false) : call_only_(call_only) {} + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override { + auto cache_exprs = expr_replacements_; + auto cache_vars = var_remap_; + auto output = ExprMutator::VisitBindingBlock_(block); + expr_replacements_ = cache_exprs; + var_remap_ = cache_vars; + return output; } - // overriding VisitExpr ensures we do this for every subexpression - void VisitExpr(const Expr& e) override { - // Cases we ignore because we will not substitute them: - // 1. Vars of all kinds - // 2. Op nodes (nothing we can do) - // 3. PrimValue nodes (not much benefit from binding to a var) - // 4. StringImm nodes (not much benefit from binding to a var) - // 5. Scalar constants (not much benefit from binding to a var) - // 6. Shape expressions (exist to hold several PrimValue objects) - // 7. 
DataType nodes (no need to modify dtype nodes) - if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() || - e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() || - e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() || - e->IsInstance<ShapeExprNode>() || e->IsInstance<ExternFuncNode>() || - e->IsInstance<ConstantNode>() || e->IsInstance<DataTypeImmNode>())) { - // also if e has an impure subexpression, we will not deduplicate it - if (!impurity_detector_.Detect(e)) { - int count = 0; - if (count_map_.count(e)) { - count = count_map_.at(e); - } - count_map_[e] = count + 1; + void VisitBinding(const Binding& binding) override { + Expr bound_value = VisitExpr(GetBoundValue(binding)); + + Binding output_binding = [&]() -> Binding { + if (binding.as<VarBindingNode>()) { + return VarBinding(binding->var, bound_value); + } else if (auto match_cast = binding.as<MatchCastNode>()) { + return MatchCast(binding->var, bound_value, match_cast->struct_info); + } else { + LOG(FATAL) << "Binding must be either VarBinding or MatchCast, " + << "but was " << binding->GetTypeKey(); } - } + }(); - // Only visit the interior of objects that we might still keep - // around. Otherwise, double-counting these would lead to extra - // variable bindings. - // - // Before: - // y = f(a+b) - // z = f(a+b) - // - // Expected: - // y = f(a+b) // De-duped from (y==z) - // z = y - // - // Erroneous output: - // c = a+b // Incorrect, a+b only has a single usage. 
- // y = f(c) // De-duped from - // z = y - // - if (auto it = count_map_.find(e); it == count_map_.end() || it->second < 2) { - ExprVisitor::VisitExpr(e); - } - } + ReplacementKey lookup_key(output_binding); - // do not visit inner functions: we will do CSE within those - void VisitExpr_(const FunctionNode* func) override {} + if (call_only_ && !bound_value->IsInstance<relax::CallNode>()) { + VLOG(1) << "Since call_only_ is true, it is forbidden to de-duplicate " << bound_value; - // we are not going to do replacements inside struct info to avoid binding lots of reused shapes - void VisitExprDepStructInfoField(const StructInfo& struct_info) override {} + } else if (ContainsImpureCall(bound_value)) { + VLOG(1) << "Since the expression is impure, cannot de-duplicate " << bound_value; - private: - std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_; - ImpurityDetector impurity_detector_; -}; + } else if (auto it = expr_replacements_.find(lookup_key); it != expr_replacements_.end()) { + VLOG(1) << "Value " << bound_value << " has previously been bound as " << it->second + << ". 
The duplicate binding of this value to " << binding->var + << " will be replaced with a trivial binding, " + << "and occurrences of " << binding->var << " will be replaced with " << it->second; + output_binding = VarBinding(binding->var, it->second); + var_remap_.insert({binding->var->vid, it->second}); -class CommonSubexprEliminator : public ExprMutator { - public: - explicit CommonSubexprEliminator( - std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map, - bool call_only = false) - : count_map_(std::move(count_map)), call_only_(call_only) {} + } else { + VLOG(1) << "Value " << bound_value << " is bound to " << binding->var + << " and may be de-duplicated if it occurs again."; - // overriding here ensures we visit every subexpression - Expr VisitExpr(const Expr& e) override { - if (call_only_ && !e->IsInstance<CallNode>()) { - return ExprMutator::VisitExpr(e); - } - if (count_map_.count(e) && count_map_.at(e) > 1) { - // if we already have a mapping for it, get it - if (replacements_.count(e)) { - return replacements_.at(e); - } - // Otherwise, insert a new binding for the current expression. 
- // Visit before emitting to do inner replacements - Expr new_e = ExprMutator::VisitExpr(e); - Var v = builder_->Emit(new_e); - replacements_[e] = v; - return v; + expr_replacements_.insert({lookup_key, binding->var}); } - return ExprMutator::VisitExpr(e); - } - // we are not going to do replacements inside struct info to avoid binding lots of reused shapes - StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) override { - return struct_info; + builder_->EmitNormalized(output_binding); } Expr VisitExpr_(const FunctionNode* op) override { - Function func = GetRef<Function>(op); - - auto cache = SubexprCounter::Count(op->body); - std::swap(cache, count_map_); - Expr output = ExprMutator::VisitExpr_(op); - std::swap(cache, count_map_); - - return output; - } - - void VisitBinding_(const VarBindingNode* binding) override { - // no need to visit var def because the struct info isn't going to change - Expr new_value = RegisterBoundValue(binding->var, binding->value); - - if (new_value.same_as(binding->value)) { - builder_->EmitNormalized(GetRef<VarBinding>(binding)); + // If we have accumulated any state, visit the function in a fresh + // copy of the mutator, to avoid replacing a child-scope + // expression with a parent-scope binding, or vice versa. + if (expr_replacements_.size() || var_remap_.size()) { + return VisitWithCleanScope(GetRef<Expr>(op)); Review Comment: I'm not sure using a clean scope in all cases is necessary, since an inner function could capture vars from the outer scope and there might be legitimate substitutions that are possible. That said, this may not necessarily be desirable behavior since it could result in bigger closures (capturing more vars than you might want). We would also have to make sure we don't capture any DataflowVars, since that's not permitted. Any thoughts?
########## tests/python/relax/test_transform_cse.py: ########## @@ -90,6 +88,12 @@ def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32" def test_repeated_inner_tuples(): + """CSE is only applied at variable bindings + + To remain consistent with the behavior of the normalizer, tuples + are kept as-is, even if they contain repeated sub-tuples. + """ Review Comment: Good observation, probably a good choice. ########## src/relax/transform/eliminate_common_subexpr.cc: ########## @@ -20,223 +20,162 @@ /*! * \file tvm/relax/transform/eliminate_common_subexpr.cc - * \brief Eliminrate common subexpression pass. + * \brief Eliminate common subexpression pass. * * Currently it removes common subexpressions within a Function. */ +#include <tvm/relax/analysis.h> #include <tvm/relax/expr_functor.h> #include <tvm/relax/transform.h> #include <tvm/relax/utils.h> -#include "utils.h" +#include "../../support/utils.h" namespace tvm { namespace relax { - -// Checks if a given expression contains an impure subexpression -// Caches the results of checks to avoid revisiting subexpressions -class ImpurityDetector : public ExprVisitor { - public: - bool Detect(const Expr& expr) { - impure_found_ = false; - VisitExpr(expr); - return impure_found_; +namespace { +/* \brief Lookup key for subexpression replacements + * + * The lookup key must contain the expression being bound, along with + * the struct info used for a match cast, if applicable. Using + * `MatchCast` with StructuralEqual and StructuralHash would be almost + * correct, but acts as a point of definition for symbolic variables + * within the output struct info. As a result, it would erroneously + * de-duplicate `R.match_cast(A, R.Tensor([m,n]))` and + * `R.match_cast(A, R.Tensor([p,q]))`, even though they define + * different symbolic variables. Review Comment: Does it still do the wrong thing even if you set `map_free_vars` to false? 
########## src/relax/transform/eliminate_common_subexpr.cc: ########## @@ -20,223 +20,162 @@ /*! * \file tvm/relax/transform/eliminate_common_subexpr.cc - * \brief Eliminrate common subexpression pass. + * \brief Eliminate common subexpression pass. * * Currently it removes common subexpressions within a Function. */ +#include <tvm/relax/analysis.h> #include <tvm/relax/expr_functor.h> #include <tvm/relax/transform.h> #include <tvm/relax/utils.h> -#include "utils.h" +#include "../../support/utils.h" namespace tvm { namespace relax { - -// Checks if a given expression contains an impure subexpression -// Caches the results of checks to avoid revisiting subexpressions -class ImpurityDetector : public ExprVisitor { - public: - bool Detect(const Expr& expr) { - impure_found_ = false; - VisitExpr(expr); - return impure_found_; +namespace { +/* \brief Lookup key for subexpression replacements + * + * The lookup key must contain the expression being bound, along with + * the struct info used for a match cast, if applicable. Using + * `MatchCast` with StructuralEqual and StructuralHash would be almost + * correct, but acts as a point of definition for symbolic variables + * within the output struct info. As a result, it would erroneously + * de-duplicate `R.match_cast(A, R.Tensor([m,n]))` and + * `R.match_cast(A, R.Tensor([p,q]))`, even though they define + * different symbolic variables. 
+ */ +struct ReplacementKey { + tvm::relax::Expr bound_value; + tvm::Optional<tvm::relax::StructInfo> match_cast = tvm::NullOpt; + + explicit ReplacementKey(const tvm::relax::Binding& binding) + : bound_value(GetBoundValue(binding)) { + if (const auto* ptr = binding.as<tvm::relax::MatchCastNode>()) { + match_cast = ptr->struct_info; + } } - void VisitExpr(const Expr& expr) { - // already checked: do not revisit - if (purity_map_.count(expr)) { - impure_found_ = impure_found_ || !purity_map_.at(expr); - return; - } + friend bool operator==(const ReplacementKey& a, const ReplacementKey& b) { + tvm::StructuralEqual eq; + return eq(a.bound_value, b.bound_value) && eq(a.match_cast, b.match_cast); + } +}; - // in principle, we could stop checking once we find an impurity, - // but not doing so lets us fully populate the cache +} // namespace +} // namespace relax +} // namespace tvm - // store the previous state so we could assess the purity of this subexpression alone - bool prev_state = impure_found_; - impure_found_ = false; - ExprVisitor::VisitExpr(expr); - // if impure_found_ remains false, then the expression is pure - purity_map_[expr] = !impure_found_; - impure_found_ = prev_state || impure_found_; +/* \brief Definition of std::hash<ReplacementKey> + * + * Specialization of std::hash must occur outside of tvm::relax + * namespace, and before its usage in the constructor of + * `CommonSubexprEliminator`. 
+ */ +template <> +struct std::hash<tvm::relax::ReplacementKey> { + std::size_t operator()(const tvm::relax::ReplacementKey& key) const { + tvm::StructuralHash hasher; + return tvm::support::HashCombine(hasher(key.bound_value), hasher(key.match_cast)); } +}; - void VisitExpr_(const CallNode* call) { - // the only possible impurities can come from call nodes - bool is_impure = IsImpureCall(GetRef<Call>(call)); - impure_found_ = impure_found_ || is_impure; - ExprVisitor::VisitExpr_(call); - } +namespace tvm { +namespace relax { - private: - bool impure_found_ = false; - std::unordered_map<Expr, bool, StructuralHash, StructuralEqual> purity_map_; -}; +namespace { -class SubexprCounter : public ExprVisitor { +class CommonSubexprEliminator : public ExprMutator { public: - static std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(const Expr& expr) { - SubexprCounter visitor; - visitor(expr); - return visitor.count_map_; + explicit CommonSubexprEliminator(bool call_only = false) : call_only_(call_only) {} + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override { + auto cache_exprs = expr_replacements_; + auto cache_vars = var_remap_; + auto output = ExprMutator::VisitBindingBlock_(block); + expr_replacements_ = cache_exprs; + var_remap_ = cache_vars; + return output; } - // overriding VisitExpr ensures we do this for every subexpression - void VisitExpr(const Expr& e) override { - // Cases we ignore because we will not substitute them: - // 1. Vars of all kinds - // 2. Op nodes (nothing we can do) - // 3. PrimValue nodes (not much benefit from binding to a var) - // 4. StringImm nodes (not much benefit from binding to a var) - // 5. Scalar constants (not much benefit from binding to a var) - // 6. Shape expressions (exist to hold several PrimValue objects) - // 7. 
DataType nodes (no need to modify dtype nodes) - if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() || - e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() || - e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() || - e->IsInstance<ShapeExprNode>() || e->IsInstance<ExternFuncNode>() || - e->IsInstance<ConstantNode>() || e->IsInstance<DataTypeImmNode>())) { - // also if e has an impure subexpression, we will not deduplicate it - if (!impurity_detector_.Detect(e)) { - int count = 0; - if (count_map_.count(e)) { - count = count_map_.at(e); - } - count_map_[e] = count + 1; + void VisitBinding(const Binding& binding) override { + Expr bound_value = VisitExpr(GetBoundValue(binding)); + + Binding output_binding = [&]() -> Binding { + if (binding.as<VarBindingNode>()) { + return VarBinding(binding->var, bound_value); + } else if (auto match_cast = binding.as<MatchCastNode>()) { + return MatchCast(binding->var, bound_value, match_cast->struct_info); + } else { + LOG(FATAL) << "Binding must be either VarBinding or MatchCast, " + << "but was " << binding->GetTypeKey(); } - } + }(); - // Only visit the interior of objects that we might still keep - // around. Otherwise, double-counting these would lead to extra - // variable bindings. - // - // Before: - // y = f(a+b) - // z = f(a+b) - // - // Expected: - // y = f(a+b) // De-duped from (y==z) - // z = y - // - // Erroneous output: - // c = a+b // Incorrect, a+b only has a single usage. 
- // y = f(c) // De-duped from - // z = y - // - if (auto it = count_map_.find(e); it == count_map_.end() || it->second < 2) { - ExprVisitor::VisitExpr(e); - } - } + ReplacementKey lookup_key(output_binding); - // do not visit inner functions: we will do CSE within those - void VisitExpr_(const FunctionNode* func) override {} + if (call_only_ && !bound_value->IsInstance<relax::CallNode>()) { + VLOG(1) << "Since call_only_ is true, it is forbidden to de-duplicate " << bound_value; - // we are not going to do replacements inside struct info to avoid binding lots of reused shapes - void VisitExprDepStructInfoField(const StructInfo& struct_info) override {} + } else if (ContainsImpureCall(bound_value)) { + VLOG(1) << "Since the expression is impure, cannot de-duplicate " << bound_value; - private: - std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_; - ImpurityDetector impurity_detector_; -}; + } else if (auto it = expr_replacements_.find(lookup_key); it != expr_replacements_.end()) { + VLOG(1) << "Value " << bound_value << " has previously been bound as " << it->second + << ". The duplicate binding of this value to " << binding->var + << " will be replaced with a trivial binding, " + << "and occurrences of " << binding->var << " will be replaced with " << it->second; + output_binding = VarBinding(binding->var, it->second); + var_remap_.insert({binding->var->vid, it->second}); -class CommonSubexprEliminator : public ExprMutator { - public: - explicit CommonSubexprEliminator( - std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map, - bool call_only = false) - : count_map_(std::move(count_map)), call_only_(call_only) {} + } else { + VLOG(1) << "Value " << bound_value << " is bound to " << binding->var + << " and may be de-duplicated if it occurs again."; Review Comment: Should we keep in all the logging statements for non-error cases? As long as the log is off by default, it's probably fine to keep them for debugging. 
-- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org