This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0c7adc2fee [Relax] Add FRelaxInferLayout for scatter_elements operator (#18638)
0c7adc2fee is described below

commit 0c7adc2fee5971b5908a2940aac99d7cafb09022
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Jan 6 17:34:24 2026 +0800

    [Relax] Add FRelaxInferLayout for scatter_elements operator (#18638)
    
    ## Why
    
    The scatter_elements operator was missing FRelaxInferLayout support,
    which prevented proper layout transformation when used with operators
    like conv2d that require layout conversion.
    
    ## How
    
    - Implement an InferLayoutScatterElements function that handles layout
    inference for scatter_elements
    - Transform the axis attribute according to the inferred layout using
    FindAxis
    - Fall back to the initial layout when the inferred layout is sub-indexed
    - Add a test case for conv2d + scatter_elements layout conversion (a
    usage sketch follows this list)
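
    As a rough usage sketch (illustrative, not part of the patch: `Before`
    mirrors the new test case, and the desired-layout mapping follows the
    convention used throughout test_transform_convert_layout.py):

    ```python
    from tvm import relax
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Before:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"),
            w: R.Tensor((4, 3, 3, 3), "float32"),
            indices: R.Tensor((2, 4, 26, 26), "int64"),
        ) -> R.Tensor(None, "float32", ndim=4):
            with R.dataflow():
                data = R.nn.conv2d(x, w, out_dtype="float32")
                updates = R.nn.relu(data)
                gv = R.scatter_elements(data, indices, updates, axis=1)
                R.output(gv)
            return gv

    # With FRelaxInferLayout registered, ConvertLayout propagates NHWC through
    # scatter_elements, and FindAxis rewrites axis=1 (channel in NCHW) to axis=3.
    after = relax.transform.ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]})(Before)
    after.show()
    ```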
---
 src/relax/op/tensor/manipulate.cc                  | 31 +++++++++++-
 tests/python/contrib/test_msc/test_graph_build.py  | 14 +++---
 .../python/relax/test_transform_convert_layout.py  | 55 ++++++++++++++++++++++
 3 files changed, 92 insertions(+), 8 deletions(-)

diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 22636afb97..7c5682d462 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2613,7 +2613,35 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder&
   return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
 }
 
-// TODO(relax-team): implement FRelaxInferLayout for scatter_elements
+InferLayoutOutput InferLayoutScatterElements(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  const auto* attrs = call->attrs.as<ScatterElementsAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  LayoutDecision indices_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+  LayoutDecision updates_layout = GetLayoutDecision(var_layout_map, call->args[2]);
+
+  LayoutDecision layout = data_layout;
+  if (NLayoutEqual()(indices_layout, updates_layout)) {
+    layout = indices_layout;
+  }
+
+  if (layout->layout.ndim() != layout->layout.ndim_primal()) {
+    const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+    ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+    ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+    int ndim = tensor_sinfo->ndim;
+    layout = LayoutDecision(InitialLayout(ndim));
+  }
+
+  ObjectPtr<ScatterElementsAttrs> new_attrs = ffi::make_object<ScatterElementsAttrs>(*attrs);
+  new_attrs->axis = FindAxis(layout->layout, attrs->axis->value);
+  return InferLayoutOutput({layout, layout, layout}, {layout}, Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.scatter_elements")
     .set_attrs_type<ScatterElementsAttrs>()
     .set_num_inputs(3)
@@ -2621,6 +2649,7 @@ TVM_REGISTER_OP("relax.scatter_elements")
     .add_argument("indices", "Tensor", "The indices tensor.")
     .add_argument("updates", "Tensor", "The input tensor of updates.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterElements)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutScatterElements)
     .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.scatter_nd */
diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py
index 328fbf456e..3f70dce36e 100644
--- a/tests/python/contrib/test_msc/test_graph_build.py
+++ b/tests/python/contrib/test_msc/test_graph_build.py
@@ -2443,22 +2443,22 @@ def test_scatter(dynamic: bool):
 
     expected1 = {
         "inputs": [
-            {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": 
""},
-            {"name": "inp_1", "shape": [2, 5], "dtype": "float32", "layout": 
""},
+            {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": 
"AB"},
+            {"name": "inp_1", "shape": [2, 5], "dtype": "float32", "layout": 
"AB"},
         ],
         "outputs": [
-            {"name": "scatter_elements", "shape": [bz, 20], "dtype": 
"float32", "layout": ""}
+            {"name": "scatter_elements", "shape": [bz, 20], "dtype": 
"float32", "layout": "AB"}
         ],
         "nodes": {"total": 4, "input": 2, "constant": 1, "scatter_elements": 
1},
     }
     expected2 = {
         "inputs": [
-            {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": 
""},
-            {"name": "inp_1", "shape": [2, 5], "dtype": "int64", "layout": ""},
-            {"name": "inp_2", "shape": [2, 5], "dtype": "float32", "layout": 
""},
+            {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": 
"AB"},
+            {"name": "inp_1", "shape": [2, 5], "dtype": "int64", "layout": 
"AB"},
+            {"name": "inp_2", "shape": [2, 5], "dtype": "float32", "layout": 
"AB"},
         ],
         "outputs": [
-            {"name": "scatter_elements", "shape": [bz, 20], "dtype": 
"float32", "layout": ""}
+            {"name": "scatter_elements", "shape": [bz, 20], "dtype": 
"float32", "layout": "AB"}
         ],
         "nodes": {"total": 4, "input": 3, "scatter_elements": 1},
     }
diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py
index 8ae96e9c07..26990bc44d 100644
--- a/tests/python/relax/test_transform_convert_layout.py
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -5327,5 +5327,60 @@ def test_conv2d_flip():
     verify(Input, Expected)
 
 
+def test_conv2d_scatter_elements():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"),
+            w: R.Tensor((4, 3, 3, 3), "float32"),
+            indices: R.Tensor((2, 4, 26, 26), "int64"),
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                data: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                updates: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(data)
+                gv = R.scatter_elements(data, indices, updates, axis=1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            indices: R.Tensor((2, 4, 26, 26), dtype="int64"),
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                data: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                updates: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(data)
+                lv2: R.Tensor((2, 26, 26, 4), dtype="int64") = R.permute_dims(
+                    indices, axes=[0, 2, 3, 1]
+                )
+                lv3: R.Tensor((2, 26, 26, 4), dtype="float32") = R.scatter_elements(
+                    data, lv2, updates, axis=3, reduction="update"
+                )
+                gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
+                    lv3, axes=[0, 3, 1, 2]
+                )
+                R.output(gv)
+            return gv
+
+    verify(Input, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

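For reference, one way to run only the new test locally (standard pytest
selection; the file path is taken from the diff above):

```python
# Equivalent to: pytest tests/python/relax/test_transform_convert_layout.py \
#   -k test_conv2d_scatter_elements
import pytest

raise SystemExit(pytest.main([
    "tests/python/relax/test_transform_convert_layout.py",
    "-k", "test_conv2d_scatter_elements",
]))
```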