quic-sanirudh commented on code in PR #17219:
URL: https://github.com/apache/tvm/pull/17219#discussion_r1698068199


##########
src/tir/ir/buffer.cc:
##########
@@ -334,24 +334,38 @@ inline Array<PrimExpr> BufferOffset(const BufferNode* n, Array<PrimExpr> index,
   return offsets;
 }
 
-Buffer Buffer::GetFlattenedBuffer() const {
-  auto self = operator->();
-
+static void ValidateAxisSeparators(const Array<IntImm>& axis_separators, size_t buffer_dim) {
   // These checks ensure that all output axes contain at least one
   // input axis.
-  for (size_t i = 0; (i + 1) < self->axis_separators.size(); i++) {
-    auto sep = self->axis_separators[i]->value;
-    auto next_sep = self->axis_separators[i + 1]->value;
-    ICHECK_LT(sep, next_sep) << "Axis separators must be in strictly increasing order.";
-  }
-  if (self->axis_separators.size()) {
-    auto first_sep = self->axis_separators[0]->value;
-    ICHECK_GT(first_sep, 0) << "First axis separator must be strictly greater than 0, "
-                            << "so that first output axis contains at least one input axis";
-    auto last_sep = self->axis_separators[self->axis_separators.size() - 1]->value;
-    ICHECK_LT(last_sep, self->shape.size())
-        << "Last output axis must contain at least one input axis.";
+  for (size_t i = 0; (i + 1) < axis_separators.size(); i++) {
+    auto sep = axis_separators[i]->value;
+    auto next_sep = axis_separators[i + 1]->value;
+    CHECK_LT(sep, next_sep) << "ValueError: "
+                            << "Axis separators must be in strictly increasing order, "
+                            << "but axis_separators[" << i << "] = " << sep
+                            << " is greater than or equal to axis_separators[" << (i + 1)
+                            << "] = " << next_sep << ".";
+  }
+  if (axis_separators.size()) {
+    auto first_sep = axis_separators[0]->value;
+    CHECK_GT(first_sep, 0) << "ValueError: "
+                           << "First axis separator must be strictly greater than 0, "
+                           << "so that first output axis contains at least one input axis.  "
+                           << "However, the axis_separators[0] = " << first_sep;
+    auto last_sep = axis_separators[axis_separators.size() - 1]->value;
+    CHECK_LT(last_sep, buffer_dim)
+        << "ValueError: "
+        << "Last output axis must contain at least one input axis.  "
+        << "However, the axis_separators[" << (axis_separators.size() - 1) << "] = " << last_sep
+        << " does not leave any input axes between it and the buffer's dimensionality "
+        << buffer_dim;
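For anyone who wants to exercise these three rules outside of TVM, here is a minimal standalone sketch of the same checks. It is not the PR's code: `std::vector<int64_t>` stands in for `Array<IntImm>`, `std::invalid_argument` stands in for the `CHECK_*` macros, and the name `ValidateAxisSeparatorsSketch` is hypothetical.

```cpp
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <vector>

// Hypothetical standalone version of the checks in the diff (not TVM code).
// std::vector<int64_t> stands in for Array<IntImm>, and std::invalid_argument
// stands in for the CHECK_* macros.
void ValidateAxisSeparatorsSketch(const std::vector<int64_t>& axis_separators,
                                  size_t buffer_dim) {
  // Separators must be strictly increasing, so every output axis between two
  // separators contains at least one input axis.
  for (size_t i = 0; i + 1 < axis_separators.size(); ++i) {
    if (axis_separators[i] >= axis_separators[i + 1]) {
      std::ostringstream msg;
      msg << "Axis separators must be in strictly increasing order, but axis_separators["
          << i << "] = " << axis_separators[i]
          << " is greater than or equal to axis_separators[" << (i + 1)
          << "] = " << axis_separators[i + 1] << ".";
      throw std::invalid_argument(msg.str());
    }
  }
  // No separators: a single output axis, nothing else to check.
  if (axis_separators.empty()) return;
  // The first output axis must contain input axis 0.
  if (axis_separators.front() <= 0) {
    throw std::invalid_argument("First axis separator must be strictly greater than 0.");
  }
  // The last output axis must contain at least one input axis.
  if (axis_separators.back() >= static_cast<int64_t>(buffer_dim)) {
    throw std::invalid_argument(
        "Last axis separator must be less than the buffer's dimensionality.");
  }
}
```

For instance, with `buffer_dim = 4`, `{1, 2}` passes, while `{2, 2}`, `{0}`, and `{4}` are rejected by the increasing-order, first-separator, and last-separator checks respectively.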

Review Comment:
   For a case like `axis_separators=[1, 2]` where the buffer is, say, 4-D (NHWC/NCHW), both of these checks would pass, but the result might also be confusing: the user would expect 3 flattened dimensions with `axis_separators.size() == 2`, but we get only 2 flattened dimensions.
   
   Should we require at least one valid axis between separators? Or do we allow it to be flattened into a single axis separator in this case?
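To check which input axes end up in each flattened output axis for a given `axis_separators` list and buffer rank, the grouping can be sketched as below. This assumes, as the comments in the diff suggest, that each separator is the index of the first input axis of a new output axis; `OutputAxisGroups` is a hypothetical helper, not a TVM API, and it expects separators that already pass the checks in the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not a TVM API): returns, for each flattened output
// axis, the input axes fused into it. Assumes each separator is the index of
// the first input axis of a new output axis, and that the separators are
// already validated (strictly increasing, in the range [1, buffer_dim - 1]).
std::vector<std::vector<size_t>> OutputAxisGroups(
    const std::vector<int64_t>& axis_separators, size_t buffer_dim) {
  std::vector<std::vector<size_t>> groups(1);  // output axis 0
  size_t next_sep = 0;                         // index into axis_separators
  for (size_t axis = 0; axis < buffer_dim; ++axis) {
    // Reaching the next separator starts a new output axis.
    if (next_sep < axis_separators.size() &&
        static_cast<int64_t>(axis) == axis_separators[next_sep]) {
      groups.emplace_back();
      ++next_sep;
    }
    groups.back().push_back(axis);
  }
  // Holds whenever the separators passed the validation checks.
  assert(groups.size() == axis_separators.size() + 1);
  return groups;
}
```

For example, `OutputAxisGroups({3}, 4)` returns `{{0, 1, 2}, {3}}`: input axes 0 to 2 are fused into the first output axis and axis 3 forms the second.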
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
