This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 00257f3  [Autodiff] Deterministic gradient compute (#7321)
00257f3 is described below

commit 00257f347faad0b3ec2e9624413015bef34d451f
Author: Haozheng Fan <hz...@apache.org>
AuthorDate: Thu Jan 28 08:32:04 2021 +0800

    [Autodiff] Deterministic gradient compute (#7321)

    * fix unstable compute
    * fix
    * fix
    * lint
    * sort linear equation
    * sort inequalities
    * fix
    * fix find
    * lint
    * fix find
    * lint
---
 src/arith/solve_linear_equation.cc   |  9 +++---
 src/arith/solve_linear_inequality.cc | 54 ++++++++++++++++++------------------
 src/te/autodiff/ad_simplify.cc       | 26 +++++++++--------
 3 files changed, 46 insertions(+), 43 deletions(-)

diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc
index 22bf736..d66e75d 100644
--- a/src/arith/solve_linear_equation.cc
+++ b/src/arith/solve_linear_equation.cc
@@ -427,11 +427,10 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol
   // We have to transform ranges of the old variables into relations over new variables because
   // new ranges are not enough usually.
- for (const auto& p : system_to_solve->ranges) { - const Var& old_var = p.first; - const Range& old_range = p.second; - if (old_to_new_map.count(old_var)) { - PrimExpr express_by_new_vars = old_to_new_map[old_var]; + for (const auto& old_var : system_to_solve->variables) { + if (system_to_solve->ranges.find(old_var) != system_to_solve->ranges.end()) { + const Range& old_range = system_to_solve->ranges.at(old_var); + PrimExpr express_by_new_vars = old_to_new_map.at(old_var); PrimExpr lower_cond = analyzer_solution.Simplify(old_range->min <= express_by_new_vars); PrimExpr upper_cond = analyzer_solution.Simplify(express_by_new_vars < old_range->min + old_range->extent); diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index f4de9ff..dd90448 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -94,11 +94,10 @@ struct ExprLess { } }; -void DebugPrint( - const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& current_ineq_set, - const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& next_ineq_set, - const std::vector<PrimExpr>& rest, const std::vector<std::pair<int64_t, PrimExpr>>& coef_pos, - const std::vector<std::pair<int64_t, PrimExpr>>& coef_neg) { +void DebugPrint(const std::vector<PrimExpr>& current_ineq_set, + const std::vector<PrimExpr>& next_ineq_set, const std::vector<PrimExpr>& rest, + const std::vector<std::pair<int64_t, PrimExpr>>& coef_pos, + const std::vector<std::pair<int64_t, PrimExpr>>& coef_neg) { std::cout << "Current ineq set:\n["; for (auto& ineq : current_ineq_set) { std::cout << ineq << ", "; @@ -148,9 +147,12 @@ class NormalizeComparisons : public ExprMutator { arith::Analyzer analyzer_; }; -void AddInequality(std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* inequality_set, - const PrimExpr& new_ineq, Analyzer* analyzer) { - if (analyzer->CanProve(new_ineq) || inequality_set->find(new_ineq) != inequality_set->end()) { 
+void AddInequality(std::vector<PrimExpr>* inequality_set, const PrimExpr& new_ineq, + Analyzer* analyzer) { + if (analyzer->CanProve(new_ineq) || + std::find_if(inequality_set->begin(), inequality_set->end(), [&](const PrimExpr& e) { + return StructuralEqual()(e, new_ineq); + }) != inequality_set->end()) { // redundant: follows from the vranges // or has already been added return; @@ -168,15 +170,13 @@ void AddInequality(std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> } } - inequality_set->insert(new_ineq); + inequality_set->push_back(new_ineq); } -void ClassifyByPolarity( - const Var& var, - const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& current_ineq_set, - std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* next_ineq_set, - std::vector<PrimExpr>* rest, std::vector<std::pair<int64_t, PrimExpr>>* coef_pos, - std::vector<std::pair<int64_t, PrimExpr>>* coef_neg, Analyzer* analyzer) { +void ClassifyByPolarity(const Var& var, const std::vector<PrimExpr>& current_ineq_set, + std::vector<PrimExpr>* next_ineq_set, std::vector<PrimExpr>* rest, + std::vector<std::pair<int64_t, PrimExpr>>* coef_pos, + std::vector<std::pair<int64_t, PrimExpr>>* coef_neg, Analyzer* analyzer) { // Take formulas from current_ineq_set and classify them according to polarity wrt var // and store to coef_pos and coef_neg respectively. 
for (const PrimExpr& ineq : current_ineq_set) { @@ -218,14 +218,14 @@ void ClassifyByPolarity( } } -void MoveEquality(std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* upper_bounds, - std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* lower_bounds, - std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* equalities) { +void MoveEquality(std::vector<PrimExpr>* upper_bounds, std::vector<PrimExpr>* lower_bounds, + std::vector<PrimExpr>* equalities) { // those exist in both upper & lower bounds will be moved to equalities for (auto ub = upper_bounds->begin(); ub != upper_bounds->end();) { - auto lb = lower_bounds->find(*ub); + auto lb = std::find_if(lower_bounds->begin(), lower_bounds->end(), + [&](const PrimExpr& e) { return StructuralEqual()(e, *ub); }); if (lb != lower_bounds->end()) { - equalities->insert(*lb); + equalities->push_back(*lb); lower_bounds->erase(lb); ub = upper_bounds->erase(ub); } else { @@ -249,8 +249,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // and move to the next variable. 
// normalized inequality - std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> current_ineq_set_to_solve; - std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> next_ineq_set_to_solve; + std::vector<PrimExpr> current_ineq_set_to_solve; + std::vector<PrimExpr> next_ineq_set_to_solve; // A vector of pairs (c, e), c > 0, representing formulas of the form c*v + e <= 0 std::vector<std::pair<int64_t, PrimExpr>> coef_pos; // A vector of pairs (c, e), c < 0, representing formulas of the form c*v + e <= 0 @@ -321,8 +321,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t } // The resulting lower and upper bounds - std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> upper_bounds; - std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> lower_bounds; + std::vector<PrimExpr> upper_bounds; + std::vector<PrimExpr> lower_bounds; upper_bounds.reserve(coef_pos.size()); lower_bounds.reserve(coef_neg.size()); @@ -345,7 +345,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t } } // Add the upper bound - upper_bounds.insert(bound); + upper_bounds.push_back(bound); } for (const auto& neg : coef_neg) { PrimExpr bound = make_const(v.dtype(), -coef_lcm / neg.first) * neg.second; @@ -366,10 +366,10 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t } } // Add the lower bound - lower_bounds.insert(bound); + lower_bounds.push_back(bound); } - std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> equal; + std::vector<PrimExpr> equal; equal.reserve(std::min(upper_bounds.size(), lower_bounds.size())); MoveEquality(&upper_bounds, &lower_bounds, &equal); std::vector<PrimExpr> equal_list(equal.begin(), equal.end()); diff --git a/src/te/autodiff/ad_simplify.cc b/src/te/autodiff/ad_simplify.cc index cc0e820..96f278e 100644 --- a/src/te/autodiff/ad_simplify.cc +++ b/src/te/autodiff/ad_simplify.cc @@ -413,15 +413,17 @@ class FactorOutAtomicFormulasFunctor 
auto res_b = VisitExpr(op->b); // For the And case we return the union of the sets of atomic formulas - std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set; - res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size()); + std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set; + res_a_set.reserve(res_a.atomic_formulas.size()); std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), - std::inserter(res_set, res_set.end())); - std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), - std::inserter(res_set, res_set.end())); - - std::vector<PrimExpr> res{res_set.begin(), res_set.end()}; + std::inserter(res_a_set, res_a_set.end())); + std::vector<PrimExpr> res = res_a.atomic_formulas; + for (const auto& e : res_b.atomic_formulas) { + if (res_a_set.find(e) == res_a_set.end()) { + res.emplace_back(e); + } + } // And the residuals are combined with && return {res, res_a.rest && res_b.rest}; } @@ -443,10 +445,13 @@ class FactorOutAtomicFormulasFunctor // For the Or case we intersect the sets of atomic formulas std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set; + std::vector<PrimExpr> res; res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size())); - for (const auto& res_b_formula : res_b_set) { + res.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size())); + for (const auto& res_b_formula : res_b.atomic_formulas) { if (res_a_set.count(res_b_formula)) { res_set.insert(res_b_formula); + res.push_back(res_b_formula); } } @@ -454,13 +459,13 @@ class FactorOutAtomicFormulasFunctor // which are left behind, and then combine them with the residuals into the new residual. 
std::vector<PrimExpr> new_cond_a; new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size()); - for (const auto& formula : res_a_set) { + for (const auto& formula : res_a.atomic_formulas) { if (!res_set.count(formula)) new_cond_a.emplace_back(formula); } std::vector<PrimExpr> new_cond_b; new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size()); - for (const auto& formula : res_b_set) { + for (const auto& formula : res_b.atomic_formulas) { if (!res_set.count(formula)) new_cond_b.emplace_back(formula); } @@ -468,7 +473,6 @@ class FactorOutAtomicFormulasFunctor res_b.atomic_formulas = std::move(new_cond_b); PrimExpr new_rest = res_a.to_expr() || res_b.to_expr(); - std::vector<PrimExpr> res{res_set.begin(), res_set.end()}; return {res, new_rest}; }