gemini-code-assist[bot] commented on code in PR #18641:
URL: https://github.com/apache/tvm/pull/18641#discussion_r2664438267


##########
src/relax/op/nn/pooling.cc:
##########
@@ -111,7 +111,13 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) {
   if (attrs->ceil_mode) {
     numerator_w += attrs->strides[0] - 1;
   }
-  out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[0]) + 1);
+  PrimExpr raw_out_w = floordiv(numerator_w, attrs->strides[0]) + 1;
+  if (attrs->ceil_mode) {
+    PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[0] >= input_w + attrs->padding[0];
+    out_NCW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w));
+  } else {
+    out_NCW_shape[2] = analyzer->Simplify(raw_out_w);
+  }

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The logic to handle `ceil_mode` is duplicated across 
`InferStructInfoPool1D`, `InferStructInfoPool2D`, and `InferStructInfoPool3D`. 
To improve code maintainability and reduce redundancy, consider extracting this 
logic into a common helper function. For example:
   
   ```cpp
   // Placed within the relax namespace
   inline PrimExpr InferPoolOutputSize(arith::Analyzer* analyzer, PrimExpr input_size,
                                       PrimExpr raw_out_size, PrimExpr stride,
                                       PrimExpr padding_before, bool ceil_mode) {
     if (ceil_mode) {
       PrimExpr invalid_last = (raw_out_size - 1) * stride >= input_size + padding_before;
       return analyzer->Simplify(if_then_else(invalid_last, raw_out_size - 1, raw_out_size));
     } else {
       return analyzer->Simplify(raw_out_size);
     }
   }
   ```
   
   Then, this block and similar blocks in `InferStructInfoPool2D` and 
`InferStructInfoPool3D` can be simplified to a single call, for instance:
   
   ```cpp
   out_NCW_shape[2] = InferPoolOutputSize(analyzer, input_w, raw_out_w, attrs->strides[0],
                                          attrs->padding[0], attrs->ceil_mode);
   ```
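
   To make the `ceil_mode` adjustment concrete, here is a minimal integer-only sketch of the same check, independent of TVM. The names `FloorDiv` and `PoolOutputSize` are hypothetical, and the numerator formula is an assumption (it is not shown in this hunk); only the `raw_out` computation and the invalid-last-window test are taken from the diff above.

   ```cpp
   #include <cassert>
   #include <cstdint>

   // Floor division for a positive divisor, mirroring floordiv on plain integers.
   inline int64_t FloorDiv(int64_t a, int64_t b) {
     assert(b > 0);
     int64_t q = a / b;  // C++ truncates toward zero
     if (a % b != 0 && a < 0) --q;  // step down for negative, inexact quotients
     return q;
   }

   // Hypothetical integer analogue of the suggested helper (illustration only).
   inline int64_t PoolOutputSize(int64_t input_size, int64_t kernel, int64_t stride,
                                 int64_t pad_before, int64_t pad_after, int64_t dilation,
                                 bool ceil_mode) {
     // ASSUMPTION: numerator = padded extent minus the dilated kernel reach.
     int64_t numerator = input_size + pad_before + pad_after - dilation * (kernel - 1) - 1;
     if (ceil_mode) numerator += stride - 1;  // turns the floor below into a ceiling
     int64_t raw_out = FloorDiv(numerator, stride) + 1;
     // From the diff: drop the last window if it would start at or beyond the
     // end of the input plus the leading padding.
     if (ceil_mode && (raw_out - 1) * stride >= input_size + pad_before) {
       return raw_out - 1;
     }
     return raw_out;
   }
   ```

   For example, with `input_size = 5`, `kernel = 2`, `stride = 2`, padding of 1 on each side, and `ceil_mode = true`, the raw size is 4, but the fourth window would start at offset 6, which equals `input_size + pad_before`, so the sketch returns 3.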



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

