This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 701a7539aa [Relax][Onnx] Support Multi Input Ops with Multidirectional Broadcasting (#18673)
701a7539aa is described below
commit 701a7539aaf99e2855ba0274a9d45e4c7942d733
Author: Nguyen Duy Loc <[email protected]>
AuthorDate: Thu Jan 29 18:22:07 2026 +0700
[Relax][Onnx] Support Multi Input Ops with Multidirectional Broadcasting (#18673)
This PR adds support for Multi Input Ops with Multidirectional Broadcasting.
### Description
- Support Multi Input Ops with Multidirectional Broadcasting (Min, Max, Mean, Sum)
- Rework the handling workflow in MultiInputBase (see the NumPy sketch after this list):
  + Compute the target shape for Multidirectional Broadcasting
  + broadcast_to each input to the target shape
  + Stack the broadcast inputs along a new axis
  + Apply the reduce op over the stacked axis
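For readers unfamiliar with the converter internals, the following is a minimal NumPy sketch of that workflow, illustrating the semantics only; the actual converter emits Relax ops (`relax.op.broadcast_to`, `relax.op.stack`, then the per-op reduction), as shown in the diff below.

```python
import numpy as np


def multi_input_min(*inputs):
    # 1. Compute the multidirectional-broadcast target shape of all inputs.
    #    np.broadcast_shapes applies the same pairwise rule as the
    #    compute_broadcast_shape helper added in this commit.
    target_shape = np.broadcast_shapes(*(x.shape for x in inputs))
    # 2. Broadcast every input to the target shape.
    broadcast = [np.broadcast_to(x, target_shape) for x in inputs]
    # 3. Stack along a new leading axis and reduce over that axis
    #    (Min here; Max/Sum/Mean only change the reduction).
    return np.min(np.stack(broadcast, axis=0), axis=0)


a = np.random.rand(32, 32, 1, 1).astype("float32")
b = np.random.rand(1, 32, 32).astype("float32")
print(multi_input_min(a, b).shape)  # (32, 32, 32, 32)
```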
### Expected
- Example target shape (illustrated with an image in the original PR description):
  https://github.com/user-attachments/assets/f9569dff-588e-49c5-ae72-c5b6ea22b6f3
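As a textual substitute for the image, here is a worked example of the target-shape computation, folding the pairwise rule over the shapes of one of the new test cases (it assumes the `compute_broadcast_shape` helper added in this commit can be imported directly from the frontend module):

```python
import functools

from tvm.relax.frontend.onnx.onnx_frontend import compute_broadcast_shape

# Shapes taken from one of the new SHAPE_PARAMS test cases.
input_shapes = [(32, 32, 1, 1), (1, 32, 1), (32,)]

# Fold the pairwise rule over all inputs, exactly as MultiInputBase now does.
target_shape = functools.reduce(compute_broadcast_shape, input_shapes)
print(target_shape)  # (32, 32, 32, 32)
```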
### Reference
- Multidirectional Broadcasting:
https://onnx.ai/onnx/repo-docs/Broadcasting.html
- Fixed: #18592
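For completeness, the new behavior can also be exercised end to end outside the test suite. Below is a minimal sketch that assumes the standard `from_onnx` entry point of the Relax ONNX frontend; the in-tree `check_correctness` helper used in the updated test below additionally compares the result against ONNX Runtime.

```python
from onnx import TensorProto, helper

from tvm.relax.frontend.onnx import from_onnx

# Min over two inputs whose shapes only match after multidirectional broadcasting.
node = helper.make_node("Min", ["a", "b"], ["out"])
graph = helper.make_graph(
    [node],
    "min_broadcast_example",
    inputs=[
        helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32, 1, 1]),
        helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32, 32]),
    ],
    outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, [32, 32, 32, 32])],
)
model = helper.make_model(graph, producer_name="min_broadcast_example")

mod = from_onnx(model)  # returns a Relax IRModule
mod.show()  # the printed Min should lower to broadcast_to + stack + min
```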
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 26 ++++++++++--
tests/python/relax/test_frontend_onnx.py | 55 +++++++++++++++++++------
2 files changed, 66 insertions(+), 15 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index befe131a69..c71fd96caf 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -38,6 +38,7 @@ import math
import operator
import re
import warnings
+import functools
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as _np
@@ -1659,6 +1660,22 @@ class Sqrt(OnnxOpConverter):
return relax.op.sqrt(inputs[0])
+def compute_broadcast_shape(shape_a, shape_b):
+    """Compute target shape for Multidirectional Broadcasting"""
+    rank = max(len(shape_a), len(shape_b))
+
+    a = (1,) * (rank - len(shape_a)) + tuple(shape_a)
+    b = (1,) * (rank - len(shape_b)) + tuple(shape_b)
+
+    target = []
+    for ai, bi in zip(a, b):
+        if ai == bi or ai == 1 or bi == 1:
+            target.append(max(ai, bi))
+        else:
+            raise ValueError(f"Cannot broadcast {ai} and {bi}")
+    return tuple(target)
+
+
class MultiInputBase(OnnxOpConverter):
"""Converts an onnx MultiInputBase node into an equivalent Relax
expression."""
@@ -1674,9 +1691,12 @@ class MultiInputBase(OnnxOpConverter):
            output = cls.numpy_op(*np_inputs)  # pylint: disable=not-callable
            return relax.const(output, output.dtype)
-        # Expand inputs, stack them, then perform minimum over the new axis.
-        inputs = [bb.normalize(relax.op.expand_dims(i, axis=0)) for i in inputs]
-        stacked_tensor = relax.op.concat(inputs, axis=0)
+        input_shapes = [inp.struct_info.shape for inp in inputs]
+        target_shape = functools.reduce(compute_broadcast_shape, input_shapes)
+
+        # broadcast_to, stack them, then perform minimum over the new axis.
+        inputs = [bb.normalize(relax.op.broadcast_to(i, target_shape)) for i in inputs]
+        stacked_tensor = bb.normalize(relax.op.stack(inputs, axis=0))
         return cls.relax_op(stacked_tensor, axis=0)  # pylint: disable=not-callable
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 344bc26065..b4b3baeb4d 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -393,22 +393,53 @@ def test_mod(int_mode: bool):
verify_binary_scalar("Mod", attrs={"fmod": fmod}, dtype=dtype)
[email protected]("num_inputs", [1, 2, 4])
+SHAPE_PARAMS = [
+    ([[32, 32], [32, 32]], [32, 32]),
+    ([[32, 1], [1, 2]], [32, 2]),
+    (
+        [
+            [
+                32,
+            ],
+            [
+                1,
+            ],
+        ],
+        [
+            32,
+        ],
+    ),
+    ([[32, 32, 1, 1], [1, 32, 32]], [32, 32, 32, 32]),
+    (
+        [
+            [32, 32, 1, 1],
+            [1, 32, 1],
+            [
+                32,
+            ],
+        ],
+        [32, 32, 32, 32],
+    ),
+]
+
+
[email protected]("input_shapes, expected_output_shape", SHAPE_PARAMS)
@pytest.mark.parametrize("op_name", ["Min", "Max", "Sum", "Mean"])
-def test_multi_input(op_name: str, num_inputs: int):
-    input_shape = [32, 32]
-    input_var = ["i" + str(i) for i in range(num_inputs)]
-    input_values = [
-        helper.make_tensor_value_info(var, TensorProto.FLOAT, input_shape) for var in input_var
-    ]
-    test_node = helper.make_node(op_name, input_var, ["c"])
+def test_multi_input_broadcasting(op_name, input_shapes, expected_output_shape):
+    num_inputs = len(input_shapes)
+    input_names = [f"i{i}" for i in range(num_inputs)]
+
+    input_values_info = []
+    for name, shape in zip(input_names, input_shapes):
+        input_values_info.append(helper.make_tensor_value_info(name, TensorProto.FLOAT, shape))
+    test_node = helper.make_node(op_name, input_names, ["output"])
+    output_info = helper.make_tensor_value_info("output", TensorProto.FLOAT, expected_output_shape)
     graph = helper.make_graph(
         [test_node],
-        "multi_input_test",
-        inputs=input_values,
-        outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, input_shape)],
+        f"multi_input_{op_name}_test",
+        inputs=input_values_info,
+        outputs=[output_info],
     )
-
     model = helper.make_model(graph, producer_name="multi_input_test")
     check_correctness(model)