This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e370fc7374 [Relax][ONNX] Normalize negative indices before the take
call for `Gather` operator (#19525)
e370fc7374 is described below
commit e370fc7374853a7dd28dab0bba4b1b2252292e29
Author: Neo Chien <[email protected]>
AuthorDate: Mon May 11 20:52:03 2026 +0800
[Relax][ONNX] Normalize negative indices before the take call for `Gather`
operator (#19525)
Hi Committers,
This PR attempts to fix issue
https://github.com/apache/tvm/issues/19436. Any suggestions would be
appreciated.
### Root Cause
1. ONNX `Gather` allows negative indices, which count back from the end of
the target axis (the valid range is `[-s, s-1]` for an axis of extent `s`).
2. In the Relax ONNX importer, `Gather` was lowered directly to
`relax.op.take` without normalizing negative indices first.
3. This caused a semantic mismatch and incorrect behavior in downstream
lowering paths that assume non-negative indices.
4. Test failures were also caused by two pytest parametrization issues:
- ONNX `TensorProto` enum values were used directly as NumPy dtypes, and
- tuple-style parametrization was misinterpreted by pytest as fixture
requests.
### Solutions
1. Added conditional negative-index normalization in `Gather._impl_v13`:
- it is applied only for signed index dtypes,
- it is computed as `where(idx < 0, idx + axis_extent, idx)`,
- `axis_extent` is read at runtime from `shape_of(data)`, so dynamic
shapes are supported. It is cast to the indices dtype when that dtype is
not `int64`.
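2. Skipped normalization for unsigned index dtypes, which cannot be
negative, to avoid emitting redundant graph ops.
3. Added regression tests:
- `test_gather_negative_indices` runs `check_correctness` with
`int32`/`int64` indices, mapping each `TensorProto` type to its NumPy
dtype explicitly.
- `test_gather_negative_indices_ir_normalization` checks that the imported
IR contains the `where`/`less`/`add` normalization ops ahead of `take`.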
---------
Co-authored-by: cchung100m <[email protected]>
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 19 ++++++++
tests/python/relax/test_frontend_onnx.py | 62 +++++++++++++++++++++++++
2 files changed, 81 insertions(+)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 268d91b750..7d85906cff 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1106,6 +1106,25 @@ class Gather(OnnxOpConverter):
shape_val = data[np_index]
return relax.PrimValue(shape_val)
+ indices_dtype = indices.struct_info.dtype
+ if not indices_dtype.startswith("uint"):
+ data_shape = bb.normalize(relax.op.shape_of(data))
+ data_shape_tensor =
bb.normalize(relax.op.shape_to_tensor(data_shape))
+ axis_extent = bb.normalize(
+ relax.op.take(data_shape_tensor, relax.const(axis, "int64"),
axis=0, mode="wrap")
+ )
+
+ if indices_dtype !="int64":
+ axis_extent = bb.normalize(relax.op.astype(axis_extent,
indices_dtype))
+
+ indices = bb.normalize(
+ relax.op.where(
+ relax.op.less(indices, relax.const(0, indices_dtype)),
+ relax.op.add(indices, axis_extent),
+ indices,
+ )
+ )
+
return relax.op.take(data, indices, axis)
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 5a8d84b090..52a4064cc8 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -874,6 +874,68 @@ def test_gather():
_verify_gather([3, 3], [[0, 2]], [3, 1, 2], 1)
[email protected](
+ "axis, indices, out_shape",
+ [
+ (0, [-1, 0], [2, 4]),
+ (1, [-1, 0], [3, 2]),
+ (
+ 1,
+ [[-1, 0], [1, -2]],
+ [3, 2, 2],
+ ),
+ ],
+)
[email protected]("indices_type", [TensorProto.INT64,
TensorProto.INT32])
+def test_gather_negative_indices(axis, indices, out_shape, indices_type):
+ gather_node = helper.make_node("Gather", ["data", "indices"], ["y"],
axis=axis)
+ indices_shape = np.asarray(indices).shape
+
+ graph = helper.make_graph(
+ [gather_node],
+ "gather_negative_indices_test",
+ inputs=[
+ helper.make_tensor_value_info("data", TensorProto.FLOAT, [3, 4]),
+ helper.make_tensor_value_info("indices", indices_type,
indices_shape),
+ ],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT,
out_shape)],
+ )
+
+ model = helper.make_model(graph,
producer_name="gather_negative_indices_test")
+ indices_np_dtype = {
+ TensorProto.INT64: np.int64,
+ TensorProto.INT32: np.int32,
+ }[indices_type]
+ input_values = {
+ "data": np.random.randn(3, 4).astype("float32"),
+ "indices": np.array(indices).astype(indices_np_dtype),
+ }
+ check_correctness(model, inputs=input_values)
+
+
[email protected]("indices_type", [TensorProto.INT64,
TensorProto.INT32])
+def test_gather_negative_indices_ir_normalization(indices_type):
+ gather_node = helper.make_node("Gather", ["data", "indices"], ["y"],
axis=1)
+ graph = helper.make_graph(
+ [gather_node],
+ "gather_negative_indices_ir_test",
+ inputs=[
+ helper.make_tensor_value_info("data", TensorProto.FLOAT, [3, 4]),
+ helper.make_tensor_value_info("indices", indices_type, [2]),
+ ],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3,
2])],
+ )
+
+ model = helper.make_model(graph,
producer_name="gather_negative_indices_ir_test")
+ tvm_model = from_onnx(model, opset=13, keep_params_in_input=True)
+ call_ops = collect_relax_call_ops(tvm_model["main"])
+
+ assert "relax.where" in call_ops
+ assert "relax.less" in call_ops
+ assert "relax.add" in call_ops
+ assert "relax.take" in call_ops
+
+
@pytest.mark.parametrize(
"data_shape, indices_shape, axis",
[