gemini-code-assist[bot] commented on code in PR #18643:
URL: https://github.com/apache/tvm/pull/18643#discussion_r2667071160
##########
src/relax/op/tensor/manipulate.cc:
##########
@@ -2780,13 +2780,36 @@ StructInfo InferStructInfoScatterND(const Call& call,
const BlockBuilder& ctx) {
return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim,
data_sinfo->vdevice);
}
+InferLayoutOutput InferLayoutScatterND(
+ const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+
+ LayoutDecision data_layout = GetLayoutDecision(var_layout_map,
call->args[0]);
+ LayoutDecision indices_layout = GetLayoutDecision(var_layout_map,
call->args[1]);
+ LayoutDecision updates_layout = GetLayoutDecision(var_layout_map,
call->args[2]);
+
+ LayoutDecision layout = data_layout;
+
+ if (layout->layout.ndim() != layout->layout.ndim_primal()) {
+ const auto* tensor_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+ ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for
now";
+ int ndim = tensor_sinfo->ndim;
+ layout = LayoutDecision(InitialLayout(ndim));
+ }
+
+ return InferLayoutOutput({layout, indices_layout, updates_layout}, {layout},
Attrs(call->attrs));
+}
Review Comment:

The current implementation of `InferLayoutScatterND` doesn't enforce any
layout on the `updates` tensor. This can lead to shape mismatches after layout
conversion if `updates` has a layout that is incompatible with the new layout
of `data`, as the `updates` tensor's shape is dependent on the `data` tensor's
shape. For instance, if `data` is transformed to NHWC, but `updates` remains in
NCHW, `InferStructInfo` will likely fail.
I suggest a more robust implementation that enforces the same layout for
`data` and `updates` when they have the same rank. For more complex cases, like
different ranks or when `data` has a sub-indexed layout, it safely falls back
to the initial layout for both tensors. This will prevent potential compilation
errors.
```cpp
/*!
 * \brief Layout inference for relax.scatter_nd (args: data, indices, updates).
 *
 * Because the shape of `updates` depends on the shape of `data`, the two
 * tensors are kept layout-compatible: with equal static ranks and a plain
 * (non-sub-indexed) data layout, `updates` adopts `data`'s layout; in every
 * other case both fall back to their initial layouts for safety.
 */
InferLayoutOutput InferLayoutScatterND(
    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
    const VarLayoutMap& var_layout_map) {
  // No desired layout may be requested for scatter_nd.
  ICHECK(NoDesiredLayout(call, desired_layouts));

  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
  ICHECK(data_sinfo != nullptr) << "Invalid Call";
  const auto* updates_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
  ICHECK(updates_sinfo != nullptr) << "Invalid Call";

  LayoutDecision data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
  LayoutDecision indices_layout = GetLayoutDecision(var_layout_map, call->args[1]);
  LayoutDecision updates_layout = GetLayoutDecision(var_layout_map, call->args[2]);

  const bool ranks_match = !data_sinfo->IsUnknownNdim() && !updates_sinfo->IsUnknownNdim() &&
                           data_sinfo->ndim == updates_sinfo->ndim;
  const bool data_sub_indexed =
      data_layout->layout.ndim() != data_layout->layout.ndim_primal();

  if (data_sub_indexed || !ranks_match) {
    // Sub-indexed data layout, or differing/unknown ranks: the different-rank
    // case is complex to handle, so reset to initial layouts for safety.
    ICHECK(!data_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
    data_layout = LayoutDecision(InitialLayout(data_sinfo->ndim));
    updates_layout = updates_sinfo->IsUnknownNdim()
                         ? LayoutDecision::InitUnknownDim()
                         : LayoutDecision(InitialLayout(updates_sinfo->ndim));
  } else {
    // Equal static ranks, plain layout: keep `updates` in lock-step with `data`.
    updates_layout = data_layout;
  }

  return InferLayoutOutput({data_layout, indices_layout, updates_layout}, {data_layout},
                           Attrs(call->attrs));
}
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]