This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0c7adc2fee [Relax] Add FRelaxInferLayout for scatter_elements operator (#18638)
0c7adc2fee is described below
commit 0c7adc2fee5971b5908a2940aac99d7cafb09022
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Jan 6 17:34:24 2026 +0800
[Relax] Add FRelaxInferLayout for scatter_elements operator (#18638)
## Why
The scatter_elements operator was missing FRelaxInferLayout support,
which prevented proper layout transformation when it was used with
operators like conv2d that require layout conversion.
## How
- Implement an InferLayoutScatterElements function that handles layout
inference for scatter_elements
- Transform the axis attribute according to the inferred layout using
FindAxis
- Fall back to the initial layout when the inferred layout is sub-indexed
- Add a test case for conv2d + scatter_elements layout conversion (see
the usage sketch below)
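As a usage sketch (mirroring the new test case; `ConvertLayout` and the
desired-layout mapping below are the same ones the test suite uses):

```python
from tvm import relax
from tvm.script import ir as I, relax as R


@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((2, 3, 28, 28), "float32"),
        w: R.Tensor((4, 3, 3, 3), "float32"),
        indices: R.Tensor((2, 4, 26, 26), "int64"),
    ):
        with R.dataflow():
            data = R.nn.conv2d(x, w, out_dtype="float32")
            updates = R.nn.relu(data)
            # axis=1 addresses the channel dimension of the NCHW conv output.
            gv = R.scatter_elements(data, indices, updates, axis=1)
            R.output(gv)
        return gv


# Convert conv2d to NHWC/OHWI. With the new FRelaxInferLayout hook,
# scatter_elements follows the inferred NHWC layout instead of forcing
# layout conversions around it, and its axis is remapped from 1 to 3.
mod = relax.transform.ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]})(Module)
mod.show()
```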
---
src/relax/op/tensor/manipulate.cc | 31 +++++++++++-
tests/python/contrib/test_msc/test_graph_build.py | 14 +++---
.../python/relax/test_transform_convert_layout.py | 55 ++++++++++++++++++++++
3 files changed, 92 insertions(+), 8 deletions(-)
diff --git a/src/relax/op/tensor/manipulate.cc
b/src/relax/op/tensor/manipulate.cc
index 22636afb97..7c5682d462 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2613,7 +2613,35 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder&
   return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
}
-// TODO(relax-team): implement FRelaxInferLayout for scatter_elements
+InferLayoutOutput InferLayoutScatterElements(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+ const auto* attrs = call->attrs.as<ScatterElementsAttrs>();
+ ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  LayoutDecision indices_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+  LayoutDecision updates_layout = GetLayoutDecision(var_layout_map, call->args[2]);
+
+ LayoutDecision layout = data_layout;
+ if (NLayoutEqual()(indices_layout, updates_layout)) {
+ layout = indices_layout;
+ }
+
+ if (layout->layout.ndim() != layout->layout.ndim_primal()) {
+    const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+    ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+    ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+ int ndim = tensor_sinfo->ndim;
+ layout = LayoutDecision(InitialLayout(ndim));
+ }
+
+  ObjectPtr<ScatterElementsAttrs> new_attrs = ffi::make_object<ScatterElementsAttrs>(*attrs);
+  new_attrs->axis = FindAxis(layout->layout, attrs->axis->value);
+  return InferLayoutOutput({layout, layout, layout}, {layout}, Attrs(new_attrs));
+}
+
TVM_REGISTER_OP("relax.scatter_elements")
.set_attrs_type<ScatterElementsAttrs>()
.set_num_inputs(3)
@@ -2621,6 +2649,7 @@ TVM_REGISTER_OP("relax.scatter_elements")
.add_argument("indices", "Tensor", "The indices tensor.")
.add_argument("updates", "Tensor", "The input tensor of updates.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterElements)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutScatterElements)
.set_attr<Bool>("FPurity", Bool(true));
/* relax.scatter_nd */
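A note on the axis remapping above: Relax layout inference names the axes
of an n-dim tensor "A", "B", "C", ... (the MSC expectations below show "AB"
for 2-D tensors), and FindAxis locates a primal axis inside the permuted
layout. A toy re-implementation for intuition only (hypothetical helper,
not TVM's actual code):

```python
def find_axis(layout: str, axis: int) -> int:
    # Position of the letter-named primal axis inside the permuted layout.
    return layout.index(chr(ord("A") + axis))


# NCHW -> NHWC corresponds to "ABCD" -> "ACDB": the channel axis (1)
# lands at position 3, matching axis=3 in the expected test IR below.
assert find_axis("ACDB", 1) == 3
```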
diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py
index 328fbf456e..3f70dce36e 100644
--- a/tests/python/contrib/test_msc/test_graph_build.py
+++ b/tests/python/contrib/test_msc/test_graph_build.py
@@ -2443,22 +2443,22 @@ def test_scatter(dynamic: bool):
     expected1 = {
         "inputs": [
-            {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": ""},
-            {"name": "inp_1", "shape": [2, 5], "dtype": "float32", "layout": ""},
+            {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": "AB"},
+            {"name": "inp_1", "shape": [2, 5], "dtype": "float32", "layout": "AB"},
         ],
         "outputs": [
-            {"name": "scatter_elements", "shape": [bz, 20], "dtype": "float32", "layout": ""}
+            {"name": "scatter_elements", "shape": [bz, 20], "dtype": "float32", "layout": "AB"}
         ],
         "nodes": {"total": 4, "input": 2, "constant": 1, "scatter_elements": 1},
     }
     expected2 = {
         "inputs": [
-            {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": ""},
-            {"name": "inp_1", "shape": [2, 5], "dtype": "int64", "layout": ""},
-            {"name": "inp_2", "shape": [2, 5], "dtype": "float32", "layout": ""},
+            {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": "AB"},
+            {"name": "inp_1", "shape": [2, 5], "dtype": "int64", "layout": "AB"},
+            {"name": "inp_2", "shape": [2, 5], "dtype": "float32", "layout": "AB"},
         ],
         "outputs": [
-            {"name": "scatter_elements", "shape": [bz, 20], "dtype": "float32", "layout": ""}
+            {"name": "scatter_elements", "shape": [bz, 20], "dtype": "float32", "layout": "AB"}
         ],
         "nodes": {"total": 4, "input": 3, "scatter_elements": 1},
     }
diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py
index 8ae96e9c07..26990bc44d 100644
--- a/tests/python/relax/test_transform_convert_layout.py
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -5327,5 +5327,60 @@ def test_conv2d_flip():
verify(Input, Expected)
+def test_conv2d_scatter_elements():
+ @I.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), "float32"),
+ w: R.Tensor((4, 3, 3, 3), "float32"),
+ indices: R.Tensor((2, 4, 26, 26), "int64"),
+ ) -> R.Tensor(None, "float32", ndim=4):
+ with R.dataflow():
+                data: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+ updates: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(data)
+ gv = R.scatter_elements(data, indices, updates, axis=1)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+ w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+ indices: R.Tensor((2, 4, 26, 26), dtype="int64"),
+ ) -> R.Tensor(None, dtype="float32", ndim=4):
+ with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+ data: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+ lv,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ out_layout="NHWC",
+ out_dtype="float32",
+ )
+                updates: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(data)
+ lv2: R.Tensor((2, 26, 26, 4), dtype="int64") = R.permute_dims(
+ indices, axes=[0, 2, 3, 1]
+ )
+                lv3: R.Tensor((2, 26, 26, 4), dtype="float32") = R.scatter_elements(
+ data, lv2, updates, axis=3, reduction="update"
+ )
+ gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
+ lv3, axes=[0, 3, 1, 2]
+ )
+ R.output(gv)
+ return gv
+
+ verify(Input, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()
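To run just the new test locally, a standard pytest invocation works
(path as in this commit):

```python
import pytest

pytest.main(
    ["tests/python/relax/test_transform_convert_layout.py::test_conv2d_scatter_elements", "-v"]
)
```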