huenwei-arch opened a new issue, #18751:
URL: https://github.com/apache/tvm/issues/18751
### Expected behavior
The TVM ONNX frontend should correctly implement the "uneven split" logic
for the `Split` operator as defined in Opset 18+. When the `num_outputs`
attribute is provided:
1. It should calculate `block_size = ceil(dimension / num_outputs)`.
2. The first $N-1$ outputs (where $N$ is `num_outputs`) should each have size `block_size`.
3. The last output should contain the remainder, i.e. `dimension - (N - 1) * block_size` elements.
For an input length of 10 and `num_outputs=3`, the expected output shapes
are `[4, 4, 2]`.
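The size rule above can be sketched in plain Python. This is a minimal illustration under the stated opset 18 rule, not TVM or ONNX Runtime code, and the helper name `onnx_split_sizes` is hypothetical. For contrast, it also prints what `numpy.array_split` produces for the same input. NumPy puts the extra element first (`[4, 3, 3]` for 10/3), so "as even as possible" splitting is not the same as the ONNX semantics.

```python
import numpy as np


def onnx_split_sizes(dim: int, num_outputs: int) -> list:
    """Chunk sizes for ONNX Split (opset 18+) when only num_outputs is given.

    `onnx_split_sizes` is a hypothetical helper used for illustration only.
    """
    # Integer ceil(dim / num_outputs) without going through floats.
    block_size = -(-dim // num_outputs)
    # Every output except the last gets block_size elements.
    sizes = [block_size] * (num_outputs - 1)
    # The last output gets whatever is left over.
    last = dim - block_size * (num_outputs - 1)
    if last < 0:
        # e.g. dim=5, num_outputs=4 would need 3 * 2 = 6 elements; this sketch
        # rejects such inputs instead of guessing at their semantics.
        raise ValueError(f"cannot split {dim} into {num_outputs} chunks of {block_size}")
    sizes.append(last)
    return sizes


print(onnx_split_sizes(9, 3))   # [3, 3, 3]
print(onnx_split_sizes(10, 3))  # [4, 4, 2]

# NumPy's array_split distributes the remainder from the front instead.
print([len(c) for c in np.array_split(np.arange(10), 3)])  # [4, 3, 3]
```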
### Actual behavior
TVM correctly handles uniform splits (e.g., 9/3) but fails to convert the
model when an uneven split is required (e.g., 10/3). The frontend raises an
error while converting the `Split` operator. This suggests that it cannot handle
dimensions that are not evenly divisible by the number of outputs.
**Reproduction Log:**
```text
>>> Testing Split: Input Length 9 / 3 parts
ONNX Runtime shapes: [3, 3, 3]
TVM shapes: [3, 3, 3]
Result: PASS
>>> Testing Split: Input Length 10 / 3 parts
ONNX Runtime shapes: [4, 4, 2]
Error converting operator Split, with inputs: [X]
Result: FAIL (Conversion or Runtime Error)
Error: Traceback (most recent call last): ...
src/relax/ir/block_builder.cc:65: Warning: BlockBuilder destroyed with remaining blocks!
```
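The following is a hedged sketch of one possible direction, not a description of how TVM's converter is or should be written. It turns the ONNX chunk sizes into explicit cut indices and assumes the target split operation accepts a list of indices, as `numpy.split` does. The helper `split_points` is hypothetical. The sizes `[4, 4, 2]` are taken from the ONNX Runtime output in the log above.

```python
import itertools

import numpy as np


def split_points(sizes):
    """Turn chunk sizes into the cut indices expected by an indices-based split.

    `split_points` is a hypothetical helper used for illustration only.
    """
    # Cumulative sums of every size except the last give the cut positions.
    return list(itertools.accumulate(sizes[:-1]))


sizes = [4, 4, 2]             # ONNX Split result for length 10, num_outputs=3
points = split_points(sizes)  # [4, 8]
chunks = np.split(np.arange(10, dtype=np.float32), points)
print(points, [c.shape[0] for c in chunks])  # [4, 8] [4, 4, 2]
```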
### Environment
* **OS**: Ubuntu 20.04.6 LTS (Focal Fossa)
* **TVM Version**: 0.19.0 (Relax)
* **ONNX Version**: 1.18.0
* **ONNX Runtime Version**: 1.24.1
* **NumPy Version**: 2.4.2
### Steps to reproduce
```python
import onnx
from onnx import helper, TensorProto
import numpy as np
import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx
import onnxruntime as ort


def run_split_test(input_len, num_outputs):
    print(f"\n>>> Testing Split: Input Length {input_len} / {num_outputs} parts")

    # 1. Construct an ONNX model with a single Split node (opset 18, using the
    #    num_outputs attribute rather than an explicit split input)
    x_np = np.arange(input_len).astype(np.float32)
    node = helper.make_node(
        'Split',
        inputs=['X'],
        outputs=[f'Y{i}' for i in range(num_outputs)],
        axis=0,
        num_outputs=num_outputs
    )
    graph = helper.make_graph(
        [node],
        'split_test',
        [helper.make_tensor_value_info('X', TensorProto.FLOAT, [input_len])],
        [helper.make_tensor_value_info(f'Y{i}', TensorProto.FLOAT, [None])
         for i in range(num_outputs)]
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])

    # 2. Reference output (ORT)
    sess = ort.InferenceSession(model.SerializeToString())
    ort_outs = sess.run(None, {'X': x_np})
    ort_shapes = [o.shape[0] for o in ort_outs]
    print(f" ONNX Runtime shapes: {ort_shapes}")

    # 3. TVM output: import into Relax, build for CPU, run on the Relax VM
    try:
        tvm_mod = from_onnx(model)
        target = tvm.target.Target("llvm")
        exe = relax.build(tvm_mod, target)
        vm = relax.VirtualMachine(exe, tvm.cpu())
        tvm_outs = vm["main"](tvm.nd.array(x_np))
        tvm_shapes = [o.numpy().shape[0] for o in tvm_outs]
        print(f" TVM shapes: {tvm_shapes}")
        if tvm_shapes == ort_shapes:
            print(" Result: PASS")
        else:
            print(" Result: FAIL (Shape Mismatch)")
    except Exception as e:
        print(" Result: FAIL (Conversion or Runtime Error)")
        print(f" Error: {str(e)[:100]}...")


if __name__ == "__main__":
    # Case 1: Uniform split (should PASS)
    run_split_test(input_len=9, num_outputs=3)
    # Case 2: Non-uniform split (currently FAILS)
    run_split_test(input_len=10, num_outputs=3)
```
### Triage
* relax:frontend:onnx
* needs-triage