This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b975db9b28 [Relax] Add FRelaxInferLayout for scatter_nd operator (#18643)
b975db9b28 is described below
commit b975db9b28959503e471da9c78b41df9a16d738e
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Jan 7 17:53:27 2026 +0800
[Relax] Add FRelaxInferLayout for scatter_nd operator (#18643)
## Why

The scatter_nd operator was missing the FRelaxInferLayout attribute, which is needed for proper layout transformation during model optimization (for example, when ConvertLayout rewrites a preceding conv2d to NHWC); a usage sketch follows below.

## How

- Added an InferLayoutScatterND function that reuses the data tensor's layout for the output, since scatter_nd preserves the input shape.
- Registered the FRelaxInferLayout attribute on relax.scatter_nd.
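As a rough, hedged sketch (not part of this commit), the change is exercised through the standard `relax.transform.ConvertLayout` pass, the same pass driven by the test added below. The module mirrors the test's `Input`; the extra `Normalize` call and the `mod.show()` at the end are illustrative assumptions, not part of the test.

```python
# Sketch: how FRelaxInferLayout on scatter_nd is exercised by ConvertLayout.
# Assumes the layout config used in tests/python/relax/test_transform_convert_layout.py.
from tvm import relax
from tvm.script import ir as I, relax as R


@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((2, 3, 28, 28), "float32"),
        w: R.Tensor((4, 3, 3, 3), "float32"),
        indices: R.Tensor((2, 1), "int64"),
    ) -> R.Tensor(None, "float32", ndim=4):
        with R.dataflow():
            data = R.nn.conv2d(x, w, out_dtype="float32")
            updates = R.nn.relu(data)
            gv = R.scatter_nd(data, indices, updates)
            R.output(gv)
        return gv


# Convert conv2d to NHWC. With FRelaxInferLayout registered, scatter_nd now
# propagates the converted layout of its data/updates operands instead of
# falling back to the original layout around the op.
mod = relax.transform.ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]})(Module)
mod = relax.transform.Normalize()(mod)  # optional tidy-up, an assumption here
mod.show()
```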
---
src/relax/op/tensor/manipulate.cc | 40 +++++++++++++++++
.../python/relax/test_transform_convert_layout.py | 52 ++++++++++++++++++++++
2 files changed, 92 insertions(+)
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 7c5682d462..3170b28eeb 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2780,6 +2780,45 @@ StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) {
  return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
}

+InferLayoutOutput InferLayoutScatterND(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+
+  LayoutDecision data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  LayoutDecision indices_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+  LayoutDecision updates_layout = GetLayoutDecision(var_layout_map, call->args[2]);
+
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* updates_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
+  ICHECK(data_sinfo != nullptr) << "Invalid Call";
+  ICHECK(updates_sinfo != nullptr) << "Invalid Call";
+  ICHECK(!data_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+  ICHECK(!updates_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+
+  LayoutDecision layout = data_layout;
+  LayoutDecision out_updates_layout = updates_layout;
+
+  // Check if data has a sub-indexed layout
+  bool has_sub_indexed_layout = layout->layout.ndim() != layout->layout.ndim_primal();
+
+  if (has_sub_indexed_layout) {
+    // Fall back to initial layouts for both data and updates
+    layout = LayoutDecision(InitialLayout(data_sinfo->ndim));
+    out_updates_layout = LayoutDecision(InitialLayout(updates_sinfo->ndim));
+  } else if (data_sinfo->ndim == updates_sinfo->ndim) {
+    // When data and updates have the same rank, apply the same layout to both
+    out_updates_layout = layout;
+  } else {
+    // Different ranks - fall back to initial layouts for both
+    layout = LayoutDecision(InitialLayout(data_sinfo->ndim));
+    out_updates_layout = LayoutDecision(InitialLayout(updates_sinfo->ndim));
+  }
+
+  return InferLayoutOutput({layout, indices_layout, out_updates_layout}, {layout},
+                           Attrs(call->attrs));
+}
+
TVM_REGISTER_OP("relax.scatter_nd")
.set_attrs_type<ScatterNDAttrs>()
.set_num_inputs(3)
@@ -2787,6 +2826,7 @@ TVM_REGISTER_OP("relax.scatter_nd")
.add_argument("indices", "Tensor", "The indices tensor.")
.add_argument("updates", "Tensor", "The input tensor of updates.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterND)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutScatterND)
.set_attr<Bool>("FPurity", Bool(true));
/* relax.scatter_nd */
diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py
index 26990bc44d..221d680ebc 100644
--- a/tests/python/relax/test_transform_convert_layout.py
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -5382,5 +5382,57 @@ def test_conv2d_scatter_elements():
    verify(Input, Expected)


+def test_conv2d_scatter_nd():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"),
+            w: R.Tensor((4, 3, 3, 3), "float32"),
+            indices: R.Tensor((2, 1), "int64"),
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                data: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                updates: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(data)
+                gv = R.scatter_nd(data, indices, updates)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            indices: R.Tensor((2, 1), dtype="int64"),
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                data: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                updates: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(data)
+                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.scatter_nd(
+                    data, indices, updates, reduction="update"
+                )
+                gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv)
+            return gv
+
+    verify(Input, Expected)
+
+
if __name__ == "__main__":
    tvm.testing.main()