This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 378c4f3043 [BugFix][Relax]: handle ONNX ScatterElements reduction (#19527)
378c4f3043 is described below

commit 378c4f3043a81f0a500c9d4685df0e82758df67d
Author: Sun <[email protected]>
AuthorDate: Tue May 12 12:17:28 2026 +0800

    [BugFix][Relax]: handle ONNX ScatterElements reduction (#19527)
    
    ### Summary
    
    - Respect the ONNX `reduction` attribute in the Relax ONNX frontend
    `ScatterElements` converter.
    - Preserve existing default behavior by mapping missing reduction and
    ONNX `none` to Relax `update`.
    - Add focused regression coverage for opset 11 default behavior, opset
    16 `add`/`mul`, and opset 18 `none`/`min`/`max`, plus a negative test
    for an unsupported `reduction` value.
    
    ### Changes
    
    - Added a shared helper to normalize and validate ONNX reduction
    attributes.
    - Implemented `ScatterElements` opset 16 and opset 18 converters.
    - Reused the existing `relax.op.scatter_elements(..., reduction=...)`
    API.
    - Reused the same reduction helper in `ScatterND` to keep behavior
    consistent; an unsupported `ScatterND` reduction now raises
    `ValueError` instead of failing an `assert`.
    
    ### Test Plan
    
    - `python -m py_compile python/tvm/relax/frontend/onnx/onnx_frontend.py
    tests/python/relax/test_frontend_onnx.py`
    - `python -m pytest
    tests/python/relax/test_frontend_onnx.py::test_gather_elements
    tests/python/relax/test_frontend_onnx.py::test_scatter
    tests/python/relax/test_frontend_onnx.py::test_scatter_elements_reduction
    tests/python/relax/test_frontend_onnx.py::test_scatter_nd -q`
    
    ### Issue
    
    Fixes #19435
    
    ## Local Verification Notes
    
    - WSL conda environment: `/home/thinker/.cache/tvm-conda-onnx`
    - TVM build directory: `/home/thinker/.cache/tvm-build-onnx`
    - LLVM runtime check: `tvm.runtime.enabled("llvm") == True`
    - Relevant ONNX frontend subset: `15 passed, 4 skipped, 2 warnings`
    - Full `tests/python/relax/test_frontend_onnx.py` was also attempted. It
    currently has 14 failures in unrelated `Reduce* axes input` and `TopK`
    tests; re-running those failing tests against `origin/main`
    reproduces them, so they are not introduced by this PR.
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py |  40 +++++++---
 tests/python/relax/test_frontend_onnx.py        | 100 ++++++++++++++++++++++++
 2 files changed, 131 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 7d85906cff..622e262cc4 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1159,6 +1159,20 @@ class Scatter(OnnxOpConverter):
         raise ValueError("Scatter is deprecated in ONNX 11")
 
 
+def _get_onnx_reduction(attr, valid_reductions: list[str]):
+    reduction = attr.get("reduction", None)
+    reduction = reduction or b"update"
+    if isinstance(reduction, bytes):
+        reduction = reduction.decode("utf-8")
+    reduction = "update" if reduction == "none" else reduction
+    if reduction not in valid_reductions:
+        raise ValueError(
+            f"Only {valid_reductions} reductions are supported, but got {reduction}"
+        )
+
+    return reduction
+
+
 class ScatterElements(OnnxOpConverter):
     """Convert an onnx ScatterElements node into an equivalent Relax expression."""

@@ -1167,21 +1181,29 @@ class ScatterElements(OnnxOpConverter):
         axis = attr.get("axis", 0)
         return relax.op.scatter_elements(inputs[0], inputs[1], inputs[2], axis=axis)
 
+    @classmethod
+    def _impl_v16(cls, bb, inputs, attr, params):
+        axis = attr.get("axis", 0)
+        reduction = _get_onnx_reduction(attr, ["update", "add", "mul"])
+        return relax.op.scatter_elements(
+            inputs[0], inputs[1], inputs[2], axis=axis, reduction=reduction
+        )
+
+    @classmethod
+    def _impl_v18(cls, bb, inputs, attr, params):
+        axis = attr.get("axis", 0)
+        reduction = _get_onnx_reduction(attr, ["update", "add", "mul", "min", "max"])
+        return relax.op.scatter_elements(
+            inputs[0], inputs[1], inputs[2], axis=axis, reduction=reduction
+        )
+
 
 class ScatterND(OnnxOpConverter):
     """Convert an onnx ScatterND node into an equivalent Relax expression."""

     @staticmethod
     def _reduction_check(attr, valid_reductions: list[str]):
-        reduction = attr.get("reduction", None)
-        reduction = reduction or b"update"
-        reduction = reduction.decode("utf-8")
-        reduction = "update" if reduction == "none" else reduction
-        assert reduction in valid_reductions, (
-            f"Only {valid_reductions} reductions are supported, but {reduction} is gotten"
-        )
-
-        return reduction
+        return _get_onnx_reduction(attr, valid_reductions)
 
     @classmethod
     def _impl_v11(cls, bb, inputs, attr, params):
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 52a4064cc8..94b85ab95a 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1023,6 +1023,106 @@ def test_scatter(axis: int, name: str, opset: int):
     check_correctness(model, inputs={"indices": indices}, opset=opset)
 
 
+@pytest.mark.parametrize(
+    "reduction, opset, data, indices, updates",
+    [
+        (
+            None,
+            11,
+            np.array([[1, 2, 3], [4, 5, 6]], dtype="float32"),
+            np.array([[2, 0, 1], [1, 2, 0]], dtype="int64"),
+            np.array([[30, 10, 20], [50, 60, 40]], dtype="float32"),
+        ),
+        (
+            "none",
+            18,
+            np.array([[1, 2, 3], [4, 5, 6]], dtype="float32"),
+            np.array([[2, 0, 1], [1, 2, 0]], dtype="int64"),
+            np.array([[30, 10, 20], [50, 60, 40]], dtype="float32"),
+        ),
+        (
+            "add",
+            16,
+            np.full((2, 3), 10, dtype="float32"),
+            np.array([[0, 0, 2], [1, 1, 2]], dtype="int64"),
+            np.array([[2, 5, 7], [20, 3, 4]], dtype="float32"),
+        ),
+        (
+            "mul",
+            16,
+            np.full((2, 3), 10, dtype="float32"),
+            np.array([[0, 0, 2], [1, 1, 2]], dtype="int64"),
+            np.array([[2, 5, 7], [20, 3, 4]], dtype="float32"),
+        ),
+        (
+            "min",
+            18,
+            np.full((2, 3), 10, dtype="float32"),
+            np.array([[0, 0, 2], [1, 1, 2]], dtype="int64"),
+            np.array([[2, 5, 7], [20, 3, 4]], dtype="float32"),
+        ),
+        (
+            "max",
+            18,
+            np.full((2, 3), 10, dtype="float32"),
+            np.array([[0, 0, 2], [1, 1, 2]], dtype="int64"),
+            np.array([[2, 5, 7], [20, 3, 4]], dtype="float32"),
+        ),
+    ],
+)
+def test_scatter_elements_reduction(reduction, opset, data, indices, updates):
+    attrs = {"axis": 1}
+    if reduction is not None:
+        attrs["reduction"] = reduction
+    scatter_elements_node = helper.make_node(
+        "ScatterElements", ["data", "indices", "updates"], ["output"], **attrs
+    )
+
+    graph = helper.make_graph(
+        [scatter_elements_node],
+        "scatter_elements_reduction_test",
+        inputs=[
+            helper.make_tensor_value_info("data", TensorProto.FLOAT, list(data.shape)),
+            helper.make_tensor_value_info("indices", TensorProto.INT64, list(indices.shape)),
+            helper.make_tensor_value_info("updates", TensorProto.FLOAT, list(updates.shape)),
+        ],
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, list(data.shape))],
+    )
+    model = helper.make_model(graph, producer_name="scatter_elements_reduction_test")
+
+    check_correctness(
+        model,
+        inputs={"data": data, "indices": indices, "updates": updates},
+        opset=opset,
+    )
+
+
+def test_scatter_elements_invalid_reduction():
+    data_shape = [2, 3]
+    scatter_elements_node = helper.make_node(
+        "ScatterElements",
+        ["data", "indices", "updates"],
+        ["output"],
+        axis=1,
+        reduction="unsupported",
+    )
+
+    graph = helper.make_graph(
+        [scatter_elements_node],
+        "scatter_elements_invalid_reduction_test",
+        inputs=[
+            helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape),
+            helper.make_tensor_value_info("indices", TensorProto.INT64, data_shape),
+            helper.make_tensor_value_info("updates", TensorProto.FLOAT, data_shape),
+        ],
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, data_shape)],
+    )
+    model = helper.make_model(graph, producer_name="scatter_elements_invalid_reduction_test")
+
+    with pytest.raises(ValueError, match="Only .* reductions are supported, but got unsupported"):
+        from_onnx(model, opset=18, keep_params_in_input=True)
+
+
 @pytest.mark.parametrize("reduction", ["none", "add", "mul"])
 def test_scatter_nd(reduction):
     def verify_scatter_nd(data_shape, indices_shape, updates_shape):

Reply via email to