This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 00257f3  [Autodiff] Deterministic gradient compute (#7321)
00257f3 is described below

commit 00257f347faad0b3ec2e9624413015bef34d451f
Author: Haozheng Fan <hz...@apache.org>
AuthorDate: Thu Jan 28 08:32:04 2021 +0800

    [Autodiff] Deterministic gradient compute (#7321)
    
    * fix unstable compute
    
    * fix
    
    * fix
    
    * lint
    
    * sort linear equation
    
    * sort inequalities
    
    * fix
    
    * fix find
    
    * lint
    
    * fix find
    
    * lint
---
 src/arith/solve_linear_equation.cc   |  9 +++---
 src/arith/solve_linear_inequality.cc | 54 ++++++++++++++++++------------------
 src/te/autodiff/ad_simplify.cc       | 26 +++++++++--------
 3 files changed, 46 insertions(+), 43 deletions(-)

diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc
index 22bf736..d66e75d 100644
--- a/src/arith/solve_linear_equation.cc
+++ b/src/arith/solve_linear_equation.cc
@@ -427,11 +427,10 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol
 
   // We have to transform ranges of the old variables into relations over new variables because
   // new ranges are not enough usually.
-  for (const auto& p : system_to_solve->ranges) {
-    const Var& old_var = p.first;
-    const Range& old_range = p.second;
-    if (old_to_new_map.count(old_var)) {
-      PrimExpr express_by_new_vars = old_to_new_map[old_var];
+  for (const auto& old_var : system_to_solve->variables) {
+    if (system_to_solve->ranges.find(old_var) != system_to_solve->ranges.end()) {
+      const Range& old_range = system_to_solve->ranges.at(old_var);
+      PrimExpr express_by_new_vars = old_to_new_map.at(old_var);
       PrimExpr lower_cond = analyzer_solution.Simplify(old_range->min <= express_by_new_vars);
       PrimExpr upper_cond =
           analyzer_solution.Simplify(express_by_new_vars < old_range->min + old_range->extent);
diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc
index f4de9ff..dd90448 100644
--- a/src/arith/solve_linear_inequality.cc
+++ b/src/arith/solve_linear_inequality.cc
@@ -94,11 +94,10 @@ struct ExprLess {
   }
 };
 
-void DebugPrint(
-    const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& current_ineq_set,
-    const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& next_ineq_set,
-    const std::vector<PrimExpr>& rest, const std::vector<std::pair<int64_t, PrimExpr>>& coef_pos,
-    const std::vector<std::pair<int64_t, PrimExpr>>& coef_neg) {
+void DebugPrint(const std::vector<PrimExpr>& current_ineq_set,
+                const std::vector<PrimExpr>& next_ineq_set, const std::vector<PrimExpr>& rest,
+                const std::vector<std::pair<int64_t, PrimExpr>>& coef_pos,
+                const std::vector<std::pair<int64_t, PrimExpr>>& coef_neg) {
   std::cout << "Current ineq set:\n[";
   for (auto& ineq : current_ineq_set) {
     std::cout << ineq << ", ";
@@ -148,9 +147,12 @@ class NormalizeComparisons : public ExprMutator {
   arith::Analyzer analyzer_;
 };
 
-void AddInequality(std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* inequality_set,
-                   const PrimExpr& new_ineq, Analyzer* analyzer) {
-  if (analyzer->CanProve(new_ineq) || inequality_set->find(new_ineq) != inequality_set->end()) {
+void AddInequality(std::vector<PrimExpr>* inequality_set, const PrimExpr& new_ineq,
+                   Analyzer* analyzer) {
+  if (analyzer->CanProve(new_ineq) ||
+      std::find_if(inequality_set->begin(), inequality_set->end(), [&](const PrimExpr& e) {
+        return StructuralEqual()(e, new_ineq);
+      }) != inequality_set->end()) {
     // redundant: follows from the vranges
     // or has already been added
     return;
@@ -168,15 +170,13 @@ void AddInequality(std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>
     }
   }
 
-  inequality_set->insert(new_ineq);
+  inequality_set->push_back(new_ineq);
 }
 
-void ClassifyByPolarity(
-    const Var& var,
-    const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& current_ineq_set,
-    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* next_ineq_set,
-    std::vector<PrimExpr>* rest, std::vector<std::pair<int64_t, PrimExpr>>* coef_pos,
-    std::vector<std::pair<int64_t, PrimExpr>>* coef_neg, Analyzer* analyzer) {
+void ClassifyByPolarity(const Var& var, const std::vector<PrimExpr>& current_ineq_set,
+                        std::vector<PrimExpr>* next_ineq_set, std::vector<PrimExpr>* rest,
+                        std::vector<std::pair<int64_t, PrimExpr>>* coef_pos,
+                        std::vector<std::pair<int64_t, PrimExpr>>* coef_neg, Analyzer* analyzer) {
   // Take formulas from current_ineq_set and classify them according to polarity wrt var
   // and store to coef_pos and coef_neg respectively.
   for (const PrimExpr& ineq : current_ineq_set) {
@@ -218,14 +218,14 @@ void ClassifyByPolarity(
   }
 }
 
-void MoveEquality(std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* upper_bounds,
-                  std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* lower_bounds,
-                  std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* equalities) {
+void MoveEquality(std::vector<PrimExpr>* upper_bounds, std::vector<PrimExpr>* lower_bounds,
+                  std::vector<PrimExpr>* equalities) {
   // those exist in both upper & lower bounds will be moved to equalities
   for (auto ub = upper_bounds->begin(); ub != upper_bounds->end();) {
-    auto lb = lower_bounds->find(*ub);
+    auto lb = std::find_if(lower_bounds->begin(), lower_bounds->end(),
+                           [&](const PrimExpr& e) { return StructuralEqual()(e, *ub); });
     if (lb != lower_bounds->end()) {
-      equalities->insert(*lb);
+      equalities->push_back(*lb);
       lower_bounds->erase(lb);
       ub = upper_bounds->erase(ub);
     } else {
@@ -249,8 +249,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
   //   and move to the next variable.
 
   // normalized inequality
-  std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> current_ineq_set_to_solve;
-  std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> next_ineq_set_to_solve;
+  std::vector<PrimExpr> current_ineq_set_to_solve;
+  std::vector<PrimExpr> next_ineq_set_to_solve;
   // A vector of pairs (c, e), c > 0, representing formulas of the form c*v + e <= 0
   std::vector<std::pair<int64_t, PrimExpr>> coef_pos;
   // A vector of pairs (c, e), c < 0, representing formulas of the form c*v + e <= 0
@@ -321,8 +321,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
     }
 
     // The resulting lower and upper bounds
-    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> upper_bounds;
-    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> lower_bounds;
+    std::vector<PrimExpr> upper_bounds;
+    std::vector<PrimExpr> lower_bounds;
     upper_bounds.reserve(coef_pos.size());
     lower_bounds.reserve(coef_neg.size());
 
@@ -345,7 +345,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
         }
       }
       // Add the upper bound
-      upper_bounds.insert(bound);
+      upper_bounds.push_back(bound);
     }
     for (const auto& neg : coef_neg) {
       PrimExpr bound = make_const(v.dtype(), -coef_lcm / neg.first) * neg.second;
@@ -366,10 +366,10 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
         }
       }
       // Add the lower bound
-      lower_bounds.insert(bound);
+      lower_bounds.push_back(bound);
     }
 
-    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> equal;
+    std::vector<PrimExpr> equal;
     equal.reserve(std::min(upper_bounds.size(), lower_bounds.size()));
     MoveEquality(&upper_bounds, &lower_bounds, &equal);
     std::vector<PrimExpr> equal_list(equal.begin(), equal.end());
diff --git a/src/te/autodiff/ad_simplify.cc b/src/te/autodiff/ad_simplify.cc
index cc0e820..96f278e 100644
--- a/src/te/autodiff/ad_simplify.cc
+++ b/src/te/autodiff/ad_simplify.cc
@@ -413,15 +413,17 @@ class FactorOutAtomicFormulasFunctor
     auto res_b = VisitExpr(op->b);
 
     // For the And case we return the union of the sets of atomic formulas
-    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
-    res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size());
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set;
+    res_a_set.reserve(res_a.atomic_formulas.size());
     std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
-              std::inserter(res_set, res_set.end()));
-    std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
-              std::inserter(res_set, res_set.end()));
-
-    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+              std::inserter(res_a_set, res_a_set.end()));
 
+    std::vector<PrimExpr> res = res_a.atomic_formulas;
+    for (const auto& e : res_b.atomic_formulas) {
+      if (res_a_set.find(e) == res_a_set.end()) {
+        res.emplace_back(e);
+      }
+    }
     // And the residuals are combined with &&
     return {res, res_a.rest && res_b.rest};
   }
@@ -443,10 +445,13 @@ class FactorOutAtomicFormulasFunctor
 
     // For the Or case we intersect the sets of atomic formulas
     std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    std::vector<PrimExpr> res;
     res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
-    for (const auto& res_b_formula : res_b_set) {
+    res.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
+    for (const auto& res_b_formula : res_b.atomic_formulas) {
       if (res_a_set.count(res_b_formula)) {
         res_set.insert(res_b_formula);
+        res.push_back(res_b_formula);
       }
     }
 
@@ -454,13 +459,13 @@ class FactorOutAtomicFormulasFunctor
     // which are left behind, and then combine them with the residuals into the new residual.
     std::vector<PrimExpr> new_cond_a;
     new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
-    for (const auto& formula : res_a_set) {
+    for (const auto& formula : res_a.atomic_formulas) {
       if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
     }
 
     std::vector<PrimExpr> new_cond_b;
     new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
-    for (const auto& formula : res_b_set) {
+    for (const auto& formula : res_b.atomic_formulas) {
       if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
     }
 
@@ -468,7 +473,6 @@ class FactorOutAtomicFormulasFunctor
     res_b.atomic_formulas = std::move(new_cond_b);
 
     PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
-    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
 
     return {res, new_rest};
   }

Reply via email to