gemini-code-assist[bot] commented on code in PR #18637:
URL: https://github.com/apache/tvm/pull/18637#discussion_r2662248588


##########
src/relax/op/tensor/manipulate.cc:
##########
@@ -2047,11 +2047,61 @@ StructInfo InferStructInfoFlip(const Call& call, const 
BlockBuilder& ctx) {
   return data_sinfo;
 }
 
+InferLayoutOutput InferLayoutFlip(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& 
desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+
+  const auto* attrs = call->attrs.as<FlipAttrs>();
+  ICHECK(attrs != nullptr) << "Invalid Call";
+  const auto* tensor_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+  ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+
+  LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, 
call->args[0]);
+  int ndim = tensor_sinfo->ndim;
+
+  if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) 
{
+    existing_layout = LayoutDecision(InitialLayout(ndim));
+  }
+
+  int axis = attrs->axis.IntValue();
+  if (axis < 0) {
+    axis += ndim;
+  }
+
+  std::string axis_str(ndim, '0');
+  axis_str[axis] = '1';
+  for (int i = 0, j = 0; i < ndim; ++i) {
+    if (axis_str[i] != '1') {
+      axis_str[i] = 'A' + j++;
+    }
+  }
+
+  ffi::String new_axis_str =
+      TransposeStrLike(axis_str, InitialLayout(ndim), existing_layout->layout);
+
+  int64_t new_axis = -1;
+  for (size_t i = 0; i < new_axis_str.size(); ++i) {
+    if (new_axis_str.at(i) == '1') {
+      new_axis = i;
+      break;
+    }
+  }
+  ICHECK_GE(new_axis, 0) << "Failed to find transformed axis";

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
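   The manual search for the transformed axis (building a marker string, 
transposing it with `TransposeStrLike`, then scanning for `'1'`) is more 
complex than necessary and can be replaced by the existing `FindAxis` utility 
function. This makes the code more readable and consistent with other layout 
inference functions such as `InferLayoutConcat` and `InferLayoutSplit`. If 
`axis_str` and `new_axis_str` are not used later in the function, their 
construction can be removed as well.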
   
   ```suggestion
     const int64_t new_axis = FindAxis(existing_layout->layout, axis);
     ICHECK_GE(new_axis, 0) << "Failed to find transformed axis";
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to