gemini-code-assist[bot] commented on code in PR #18642:
URL: https://github.com/apache/tvm/pull/18642#discussion_r2667065587


##########
tests/python/relax/test_transform_convert_layout.py:
##########
@@ -5382,5 +5382,58 @@ def main(
     verify(Input, Expected)
 
 
+def test_conv2d_gather_elements():

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   This is a good test case. However, it only covers the scenario where the `data` tensor has a transformed layout and `indices` has an initial layout.
   
   To ensure the layout inference is robust, please consider adding another test case that covers the opposite scenario, where `indices` has a transformed layout and `data` has an initial layout. This would verify that the layout propagation logic correctly prioritizes transformed layouts.
   
   For example, you could construct a test where `indices` is derived from an operator that has a desired layout (such as `conv2d`, followed by a cast to `int64`), while `data` is a direct function input.



##########
src/relax/op/tensor/manipulate.cc:
##########
@@ -2150,12 +2150,41 @@ StructInfo InferStructInfoGatherElements(const Call& call, const BlockBuilder& c
   return TensorStructInfo(data_sinfo->dtype, indices_sinfo->ndim, data_sinfo->vdevice);
 }
 
+InferLayoutOutput InferLayoutGatherElements(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  const auto* attrs = call->attrs.as<GatherElementsAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  LayoutDecision indices_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+
+  LayoutDecision layout = data_layout;
+  if (NLayoutEqual()(data_layout, indices_layout)) {
+    layout = indices_layout;
+  }

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The current logic for selecting the layout for `gather_elements` can be improved. It effectively always uses `data_layout`. The branch reassigns `layout` to `indices_layout` only when the two layouts are already equal, so the assignment never changes the result. This is suboptimal when `data` has an initial layout (e.g., `NCHW`) while `indices` has a transformed layout (e.g., `NHWC`). It forces an unnecessary layout transformation of `indices` back to the initial layout, which negates the benefits of layout optimization.
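   A minimal Python model of the selection logic makes the no-op branch easier to see. This is a sketch under stated assumptions, not TVM code. `Decision` is a hypothetical stand-in for `LayoutDecision` that only records a layout string and whether it is the initial layout. `same_layout` is a hypothetical stand-in for `NLayoutEqual`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Decision:
    """Hypothetical stand-in for relax's LayoutDecision (not the real TVM type)."""

    layout: str       # e.g. "NCHW" or "NHWC"
    is_initial: bool  # True if this is the operator's original layout


def same_layout(a: Decision, b: Decision) -> bool:
    """Hypothetical stand-in for NLayoutEqual: full structural equality."""
    return a == b


def select_current(data: Decision, indices: Decision) -> Decision:
    """Mirrors the PR's current logic for choosing the gather_elements layout."""
    layout = data
    if same_layout(data, indices):
        # Only reached when indices already equals data, so this
        # assignment never changes which layout is chosen.
        layout = indices
    return layout


NCHW = Decision("NCHW", is_initial=True)
NHWC = Decision("NHWC", is_initial=False)

# data initial, indices transformed: the transformed layout is discarded.
print(select_current(NCHW, NHWC).layout)  # NCHW
```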
   
   A more robust approach, seen in other binary operators, is to prioritize non-initial (transformed) layouts. This ensures that layout transformations propagate correctly.
   
   Please consider changing the layout selection logic to prioritize a transformed layout over an initial one.
   
   ```suggestion
     LayoutDecision layout = data_layout;
     // If data_layout is initial and indices_layout is not, prefer indices_layout.
     if (data_layout->IsInitial() && !indices_layout->IsInitial()) {
       layout = indices_layout;
     }
   ```
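   The suggested rule can be checked against both scenarios from the test-coverage comment on `test_conv2d_gather_elements`. This is again a hedged standalone sketch, not TVM code. `Decision` is a hypothetical stand-in for `LayoutDecision`, and its `is_initial` field plays the role of the `IsInitial()` check in the suggestion.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Decision:
    """Hypothetical stand-in for relax's LayoutDecision (not the real TVM type)."""

    layout: str       # e.g. "NCHW" or "NHWC"
    is_initial: bool  # True if this is the operator's original layout


def select_proposed(data: Decision, indices: Decision) -> Decision:
    """Prefer a transformed layout: switch to indices only when data is initial."""
    layout = data
    if data.is_initial and not indices.is_initial:
        layout = indices
    return layout


NCHW = Decision("NCHW", is_initial=True)
NHWC = Decision("NHWC", is_initial=False)

# data transformed, indices initial (the case the existing test covers).
print(select_proposed(NHWC, NCHW).layout)  # NHWC
# indices transformed, data initial (the additional case suggested for testing).
print(select_proposed(NCHW, NHWC).layout)  # NHWC
```

   If both inputs carry different transformed layouts, this rule keeps `data`'s layout, so `indices` would still need a transform in that case.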



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

