This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 87b37b4d5b Fix onnx expand op (#17900)
87b37b4d5b is described below
commit 87b37b4d5bd19ee4553e384cffa3aa9e8a1d72bd
Author: Taylor <[email protected]>
AuthorDate: Mon Apr 28 00:48:55 2025 +0800
Fix onnx expand op (#17900)
* [ONNX] Fix Expand operator to properly handle target shapes
This fixes issue #17746 where the ONNX Expand operator was not correctly
expanding tensors to higher dimensions. The issue manifested when a
downstream ArgMin operation received a tensor with fewer dimensions than
expected, causing an 'axis out of bounds' error.
Specifically:
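1. The Expand op incorrectly skipped the broadcast whenever the target shape
equaled the input shape after the input was left-padded with 1s, i.e. when
both shapes had the same trailing dimensions but the target had a higher rank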
2. This caused a tensor with shape [5,60] to remain [5,60] when it
should have been expanded to [1,1,5,60]
3. The subsequent ArgMin op with axis=2 then failed as the tensor only
had 2 dimensions instead of the expected 4
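The fix ensures that Expand always broadcasts to the target shape. The
shorter of the input and target shapes is right-aligned and left-padded with
1s, so the output rank is max(input rank, target rank) and the rank requested
by the ONNX model is preserved. This allows downstream operations to work with
the correct tensor dimensions.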
Fixes #17746
* add expand test case
* fix test case
* reformat
---------
Co-authored-by: Anurag Singh <[email protected]>
Co-authored-by: taylor <[email protected]>
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 22 ++++++++++++----------
tests/python/relax/test_frontend_onnx.py | 6 ++++++
2 files changed, 18 insertions(+), 10 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index dd4b8a4254..24217184b5 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1917,18 +1917,20 @@ class Expand(OnnxOpConverter):
# If possible, directly expand to constant shape.
if isinstance(shape, relax.Constant):
new_shape = shape.data.numpy().tolist()
- # For some reason, onnx allows target shapes to be smaller than
input shapes.
- # We need to go correct it.
+ # ONNX Expand operator requires preserving target rank and
broadcasting
+ # according to standard rules. Dimensions are right-aligned.
data_shape = [dim.value for dim in data.struct_info.shape]
- # Dimensions are right alignment.
- data_shape = [1] * (len(new_shape) - len(data_shape)) + data_shape
- # Fix small target shapes.
- for i, s in enumerate(new_shape):
- if i < len(data_shape) and s < data_shape[i]:
+
+ # Right-align the shapes
+ if len(new_shape) > len(data_shape):
+ data_shape = [1] * (len(new_shape) - len(data_shape)) +
data_shape
+ else:
+ new_shape = [1] * (len(data_shape) - len(new_shape)) +
new_shape
+ # Fix small target shapes - if target dim is smaller than input dim
+ # use the input dim (ONNX-specific behavior).
+ for i in range(len(new_shape)):
+ if new_shape[i] < data_shape[i]:
new_shape[i] = data_shape[i]
- # If the new shape matches the input shape, no transformation is
needed.
- if new_shape == data_shape:
- return data
return relax.op.broadcast_to(data, relax.ShapeExpr(new_shape))
# Otherwise handle dynamic shapes.
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 10c185ae09..ebc1454c23 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1692,6 +1692,12 @@ def test_expand(dynamic):
data = np.random.uniform(size=in_shape).astype(np.float32)
ref_data = np.tile(data, (1, 1, 4))
_test_expand("expand_with_diff_dim", data, shape, ref_data)
+
+ in_shape = (3, 1)
+ shape = (1, 1, 3, 1)
+ data = np.random.uniform(size=in_shape).astype(np.float32)
+ ref_data = np.tile(data, (1, 1, 1, 1))
+ _test_expand("expand_with_the_same_suffix_dims", data, shape, ref_data)
else:
in_shape = (1, 32, 32)
shape = ("batch", 32, 32)