Lunderberg commented on code in PR #11287:
URL: https://github.com/apache/tvm/pull/11287#discussion_r880579967


##########
src/arith/iter_affine_map.cc:
##########
@@ -1260,140 +1300,132 @@ IterSumExpr 
IterMapRewriter::PreprocessDividend(IterMapExpr dividend) {
   }
 }
 
+PrimExpr NearLeastCommonMultiple(const PrimExpr& a, const PrimExpr& b, 
Analyzer* analyzer) {
+  auto fsplit = [](const PrimExpr& e) -> std::pair<PrimExpr, int64_t> {
+    if (const IntImmNode* imm = e.as<IntImmNode>()) {
+      return {1, imm->value};
+    }
+    PVar<PrimExpr> pv;
+    PVar<IntImm> pc;
+    if ((pv * pc).Match(e) || (pc * pv).Match(e)) {
+      return {pv.Eval(), pc.Eval()->value};
+    } else {
+      return {e, 1};
+    }
+  };
+
+  auto p1 = fsplit(a);
+  auto p2 = fsplit(b);
+  auto const_lcm = Integer(LeastCommonMultiple(p1.second, p2.second));
+  if (analyzer->CanProveEqual(p1.first, p2.first)) {
+    return p1.first * const_lcm;
+  } else {
+    return (p1.first * p2.first) * const_lcm;
+  }
+}
+
 std::pair<IterSplitExpr, PrimExpr> 
IterMapRewriter::PadDividendToDivisor(IterSplitExpr split,
                                                                          
PrimExpr base,
                                                                          
PrimExpr divisor) {
   // If FloorDiv: (((source//lower_factor) % extent) + base) // divisor
   // If FloorMod: (((source//lower_factor) % extent) + base) % divisor
 
-  PrimExpr lookup_key = split;
-
-  auto modified_divisor = [&]() {
-    if (update_iterator_padding_) {
-      return divisor;
-    }
-
-    auto it = padded_iter_map_.find(lookup_key);
-    if (it == padded_iter_map_.end()) {
-      return divisor;
-    }
-
-    const std::vector<PrimExpr>& divisors = it->second.divisors;
-    PrimExpr largest_divisor = divisor;
-    for (const auto& other : divisors) {
-      if (CanProveDivisible(other, largest_divisor)) {
-        // New one is bigger, use it
-        largest_divisor = other;
-      } else if (CanProveDivisible(largest_divisor, other)) {
-        // Current is bigger, keep it
-      } else {
-        ErrorLogger(this) << "Iterator appears in multiple terms with 
incompatible divisors "
-                          << tvm::PrettyPrint(largest_divisor) << " and "
-                          << tvm::PrettyPrint(other);
-      }
-    }
-    return largest_divisor;
-  }();
-
-  divisor = modified_divisor;
-
+  // Update current iteration split's padding.
   // First, adding any padding that is on the lower side of a
   // FloorDiv/FloorMod, such that floormod(iter-left_pad,divisor) == 0
   // when iter==0.
-
-  PrimExpr left_pad;
-
-  if (is_zero(base)) {
-    // Padding on the left is unnecessary if base is known to be zero.
-    left_pad = make_zero(base->dtype);
-  } else {
-    left_pad = analyzer_->Simplify(floormod(base, divisor));
-  }
+  PrimExpr left_pad = analyzer_->Simplify(floormod(base, divisor));
 
   // Next, adding any padding that is on the upper side of a
   // FloorDiv/FloorMod, such that floormod(left_pad + iter + right_pad, 
divisor) == 0
   // when iter==extent.
-
   PrimExpr right_edge = left_pad + split->extent;
   PrimExpr right_pad;
-
   if (CanProveDivisible(right_edge, divisor)) {
-    // Padding on the right is unnecessary if the extent is a multiple of
-    // the divisor.
     right_pad = 0;
   } else {
-    right_pad = analyzer_->Simplify(floormod(-right_edge, divisor));
-  }
-
-  if (is_zero(left_pad) && is_zero(right_pad)) {
-    return {split, left_pad};
+    right_pad = analyzer_->Simplify(floormod(-right_edge, divisor), 9);
   }
 
   if (update_iterator_padding_) {
+    IterMark mark = split->source;
+    auto& info = padded_iter_map_[mark];
+    info.padding_factor =
+        NearLeastCommonMultiple(info.padding_factor, divisor * 
split->lower_factor, analyzer_);
+
+    if (is_zero(left_pad) && is_zero(right_pad)) {
+      return {split, 0};
+    }
+
     // In the first pass, the primary goal is to collect all the divisors
     // that may be used for padding.  These will impact the divisor used
     // to determine padding in the second pass.
-    IterPaddingInfo& info = padded_iter_map_[lookup_key];
-
-    info.divisors.push_back(divisor);
+    PrimExpr padded_extent = analyzer_->Simplify(left_pad + split->extent + 
right_pad);
 
-    PrimExpr padded_extent = left_pad + split->extent + right_pad;
-
-    IterSumExpr as_sum({split}, left_pad);
-    IterMark mark(as_sum, padded_extent);
-    IterSplitExpr new_split(mark);
-
-    return {new_split, left_pad};
+    PrimExpr mark_left_pad = left_pad * split->lower_factor;
+    if (!is_zero(left_pad)) {
+      if (info.left_pad.defined()) {
+        info.left_pad = max(info.left_pad, mark_left_pad);
+      } else {
+        info.left_pad = mark_left_pad;
+      }
+    }
+    split.CopyOnWrite()->extent = padded_extent;
+    return {split, left_pad};
   }
 
-  // Any padding that is required during parsing should have been found
-  // during the first pass that determines the GCD.
-  auto it = padded_iter_map_.find(lookup_key);
+  // In the second pass, update iteration mark's to padded
+  const IterMark& mark = split->source;
+  auto it = padded_iter_map_.find(mark);
   if (it == padded_iter_map_.end()) {
-    ErrorLogger(this) << "Dividend has extent " << 
tvm::PrettyPrint(split->extent) << " and offset "
-                      << tvm::PrettyPrint(base) << ", which requires padding 
for divisor "
-                      << tvm::PrettyPrint(divisor) << ".";
-    return {IterSplitExpr(), left_pad};
+    return {split, left_pad};
   }
-  IterPaddingInfo& info = it->second;
-
-  if (info.padded.defined()) {
-    // A previous visit already applied padding to this iterator.
-    // (e.g. Visiting `(i+1)//4`, then visiting `(i+1)%4`).
-    ICHECK(analyzer_->CanProveEqual(info.left_pad, left_pad));
-    ICHECK(analyzer_->CanProveEqual(info.right_pad, right_pad));
-
-    return {info.padded, left_pad};
+  auto& info = it->second;
+  if (is_zero(info.left_pad.defined() ? info.left_pad : 0) &&
+      CanProveDivisible(mark->extent, info.padding_factor)) {
+    return {split, left_pad};
   }
 
-  // This is the first encounter with the iterator during the second pass.
-  IterSumExpr as_sum({split}, left_pad);
-  IterMark mark(as_sum, left_pad + split->extent + right_pad);
-  info.padded = IterSplitExpr(mark);
-  info.left_pad = left_pad;
-  info.right_pad = right_pad;
-
-  auto left_padding_introduced = (left_pad != 0);
-  // Equivalent to (0 <= split < left_pad), but easier to simplify in
-  // terms of the transformed variables.
-  auto left_padding_predicate =
-      left_padding_introduced && (floordiv(info.padded, divisor) == 
floordiv(base, divisor) &&
-                                  floormod(info.padded, divisor) < left_pad);
-
-  PrimExpr nparts = ceildiv(right_edge, divisor);
-
-  auto right_padding_introduced = (right_pad != 0);
-
-  // Equivalent to (right_edge <= split < right_edge+right_pad), but
-  // easier to simplify in terms of the transformed variables.
-  auto right_padding_predicate = right_padding_introduced &&
-                                 (floordiv(info.padded, divisor) == 
floordiv(right_edge, divisor) &&
-                                  floormod(info.padded, divisor) >= 
floormod(right_edge, divisor));
-
-  requires_padding_ = requires_padding_ || (left_padding_introduced || 
right_padding_introduced);
-  padding_predicate_ = padding_predicate_ || (left_padding_predicate || 
right_padding_predicate);
-
-  return {info.padded, left_pad};
+  if (!info.padded.defined()) {
+    PrimExpr mark_left_pad = info.left_pad.defined() ? info.left_pad : 0;
+    PrimExpr mark_right_pad;
+    if (CanProveDivisible(mark->extent + mark_left_pad, info.padding_factor)) {
+      mark_right_pad = 0;
+    } else {
+      mark_right_pad = floormod(-(mark->extent + mark_left_pad), 
info.padding_factor);
+    }
+    PrimExpr padded_extent = analyzer_->Simplify(mark_left_pad + mark->extent 
+ mark_right_pad);
+    info.right_pad = mark_right_pad;
+    info.padded = IterMark(IterSumExpr({IterSplitExpr(mark)}, mark_left_pad), 
padded_extent);
+
+    auto left_padding_introduced = (mark_left_pad != 0);
+    PrimExpr divisor = info.padding_factor;
+    PrimExpr right_edge = mark_left_pad + mark->extent;
+
+    // Equivalent to (0 <= split < left_pad), but easier to simplify in
+    // terms of the transformed variables.
+    auto left_padding_predicate =
+        left_padding_introduced && (floordiv(info.padded->source, divisor) == 
0 &&
+                                    floormod(info.padded->source, divisor) < 
mark_left_pad);
+
+    auto right_padding_introduced = (mark_right_pad != 0);
+
+    // Equivalent to (right_edge <= split < right_edge+right_pad), but
+    // easier to simplify in terms of the transformed variables.
+    auto right_padding_predicate =
+        right_padding_introduced &&
+        (floordiv(info.padded->source, divisor) == floordiv(right_edge, 
divisor) &&
+         floormod(info.padded->source, divisor) >= floormod(right_edge, 
divisor));
+
+    requires_padding_ = requires_padding_ || (left_padding_introduced || 
right_padding_introduced);
+    padding_predicate_ = padding_predicate_ || (left_padding_predicate || 
right_padding_predicate);
+  }
+  // ICHECK(CanProveDivisible(info.padded->extent, split->lower_factor));

Review Comment:
   Got it.  I noticed that there were also some simplification steps that 
needed to increase the number of iterations performed.  Is the failure to prove 
divisibility related, since `CanProveDivisible` only uses the default of 2 
steps?
   
   (I'm also wondering if the default for `Analyzer::Simplify` should be to 
iterate until the simplification converges, rather than using a fixed number 
of steps.)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to