This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0754ad82d6 [FRONTEND][ONNX] Fix operator Transpose: TVMError:
PermuteDims expects the number of input axes to equal the ndim of the input
tensor (#18435)
0754ad82d6 is described below
commit 0754ad82d6669af048effcf019cb549ed342605c
Author: Neo Chien <[email protected]>
AuthorDate: Fri Nov 14 13:23:57 2025 +0800
[FRONTEND][ONNX] Fix operator Transpose: TVMError: PermuteDims expects the
number of input axes to equal the ndim of the input tensor (#18435)
* [#17737] Fix operator Transpose: TVMError: PermuteDims expects the number
of input axes to equal the ndim of the input tensor
* [#17737] Add test case: test_transpose_scalar
* [#17737] Add test case: test_transpose_axes_validation
---------
Co-authored-by: cchung100m <[email protected]>
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 29 +++++++++--
tests/python/relax/test_frontend_onnx.py | 68 +++++++++++++++++++++++++
2 files changed, 94 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 2e4e7a3125..24a4014f84 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -645,11 +645,34 @@ class Transpose(OnnxOpConverter):

     @classmethod
     def _impl_v13(cls, bb, inputs, attr, params):
+        data = inputs[0]
         axes = attr.get("perm", None)
-        if isinstance(inputs[0], relax.Constant):
-            output = _np.transpose(inputs[0].data.numpy(), axes)
+
+        if hasattr(data.struct_info, "ndim"):
+            input_ndim = data.struct_info.ndim
+        elif hasattr(data.struct_info, "shape") and data.struct_info.shape:
+            input_ndim = len(data.struct_info.shape)
+        else:
+            if isinstance(data, relax.Constant):
+                input_ndim = data.data.numpy().ndim
+            else:
+                input_ndim = None
+
+        if input_ndim == 0:
+            return data
+
+        if input_ndim is not None and axes is not None:
+            if len(axes) != input_ndim:
+                raise ValueError(
+                    f"Transpose: number of axes in perm attribute ({len(axes)}) "
+                    f"must equal the number of input tensor dimensions ({input_ndim})"
+                )
+
+        if isinstance(data, relax.Constant):
+            output = _np.transpose(data.data.numpy(), axes)
             return relax.const(output, output.dtype)
-        return relax.op.permute_dims(inputs[0], axes)
+
+        return relax.op.permute_dims(data, axes)


 class Unsqueeze(OnnxOpConverter):
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index a8d434e894..23348cf847 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -789,6 +789,74 @@ def test_transpose():
     verify_unary("Transpose", [32, 32, 32], attrs={"perm": [1, 2, 0]})


+def test_transpose_scalar():
+    """Test Transpose with scalar inputs - should return scalar unchanged."""
+    # Test scalar with no perm attribute (default behavior)
+    scalar_node = helper.make_node("Transpose", ["x"], ["y"])
+    graph = helper.make_graph(
+        [scalar_node],
+        "transpose_scalar_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [])],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [])],
+    )
+    model = helper.make_model(graph, producer_name="transpose_scalar_test")
+    check_correctness(model)
+
+    # Test with scalar constant and transpose without perm
+    scalar_constant = helper.make_node(
+        "Constant",
+        [],
+        ["scalar"],
+        value=helper.make_tensor("value", TensorProto.FLOAT, [], [5.0]),
+    )
+
+    transpose_node = helper.make_node("Transpose", ["scalar"], ["y"])
+    graph = helper.make_graph(
+        [scalar_constant, transpose_node],
+        "transpose_scalar_constant_test",
+        inputs=[],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [])],
+    )
+    model = helper.make_model(graph, producer_name="transpose_scalar_constant_test")
+    check_correctness(model)
+
+
+def test_transpose_axes_validation():
+    """Test Transpose validation - perm axes count must match tensor dimensions"""
+    # Test 1D tensor with correct perm
+    transpose_1d_valid = helper.make_node("Transpose", ["x"], ["y"], perm=[0])
+    graph_1d_valid = helper.make_graph(
+        [transpose_1d_valid],
+        "transpose_1d_valid_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [10])],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [10])],
+    )
+    model_1d_valid = helper.make_model(graph_1d_valid, producer_name="transpose_1d_valid_test")
+    check_correctness(model_1d_valid)
+
+    # Test 2D tensor with correct perm
+    transpose_2d_valid = helper.make_node("Transpose", ["x"], ["y"], perm=[1, 0])
+    graph_2d_valid = helper.make_graph(
+        [transpose_2d_valid],
+        "transpose_2d_valid_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 4])],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [4, 3])],
+    )
+    model_2d_valid = helper.make_model(graph_2d_valid, producer_name="transpose_2d_valid_test")
+    check_correctness(model_2d_valid)
+
+    # Test 3D tensor with correct perm
+    transpose_3d_valid = helper.make_node("Transpose", ["x"], ["y"], perm=[2, 0, 1])
+    graph_3d_valid = helper.make_graph(
+        [transpose_3d_valid],
+        "transpose_3d_valid_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 4])],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [4, 2, 3])],
+    )
+    model_3d_valid = helper.make_model(graph_3d_valid, producer_name="transpose_3d_valid_test")
+    check_correctness(model_3d_valid)
+
+
 def test_unsqueeze():
     unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"])