This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 118e3b1316 [Relax][Frontend][ONNX] Error converting operator Expand: TVMError: broadcast_to expects the input tensor shape is broadcastable to the target shape (#18329)
118e3b1316 is described below

commit 118e3b1316413841033a6f9ca0857002287b5a1d
Author: Neo Chien <[email protected]>
AuthorDate: Tue Sep 23 09:21:40 2025 +0800

    [Relax][Frontend][ONNX] Error converting operator Expand: TVMError: broadcast_to expects the input tensor shape is broadcastable to the target shape (#18329)
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py |  85 ++++++++++++++++++--
 tests/python/relax/test_frontend_onnx.py        | 100 ++++++++++++++++++++++++
 2 files changed, 177 insertions(+), 8 deletions(-)
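
For context, here is a minimal sketch of the broadcasting rule this patch enforces, using plain numpy as a stand-in for the Relax lowering (an illustration only, not the code path in this commit). numpy's broadcast_to applies the same right-aligned rule as ONNX Expand; the ONNX-specific cases (-1 meaning "keep the input dim", and a target dim of 1 against a larger input dim) are handled separately in the patch below. The shapes mirror the test cases added in this commit:

    import numpy as np

    # Valid: (1, 25) right-aligns against (2, 25); the 1 stretches to 2.
    np.broadcast_to(np.zeros((1, 25)), (2, 25))

    # Invalid: (25,) right-aligns against (1, 1, 1, 56); 25 != 56 and
    # neither is 1, so numpy raises ValueError, analogous to the new
    # frontend-side check added in this commit.
    try:
        np.broadcast_to(np.zeros((25,)), (1, 1, 1, 56))
    except ValueError as err:
        print(err)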

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 5470c911d3..7a4a65df6e 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1910,15 +1910,47 @@ class Expand(OnnxOpConverter):
         if isinstance(shape, relax.ShapeExpr):
             data_shape = list(data.struct_info.shape)
             target_shape = list(shape.values)
+            original_data_shape = [
+                dim.value if hasattr(dim, "value") else str(dim) for dim in data_shape
+            ]
+            original_target_shape = [
+                dim.value if hasattr(dim, "value") else str(dim) for dim in target_shape
+            ]
             data_shape = [1] * (len(target_shape) - len(data_shape)) + data_shape
             assert len(data_shape) == len(target_shape)
-            # Fix small target shapes or target shapes assigned to -1
+            # Apply ONNX v13 Expand broadcasting rules
             for i, s in enumerate(target_shape):
-                if isinstance(s, tvm.tir.IntImm) and (
-                    (isinstance(data_shape[i], tvm.tir.IntImm) and s < data_shape[i])
-                    or s.value == -1
-                ):
-                    target_shape[i] = data_shape[i]
+                if isinstance(s, tvm.tir.IntImm):
+                    if s.value == -1:
+                        # -1 means preserve the input dimension
+                        target_shape[i] = data_shape[i]
+                    elif isinstance(data_shape[i], tvm.tir.IntImm) and data_shape[i].value == 1:
+                        # Input dimension is 1, can broadcast to any target dimension >= 1
+                        if s.value < 1:
+                            raise ValueError(
+                                f"ONNX Expand: Invalid target dimension 
{s.value} "
+                                f"at possition {i}. Target dimensions must be 
>= 1."
+                            )
+                    elif (
+                        isinstance(data_shape[i], tvm.tir.IntImm) and s.value == data_shape[i].value
+                    ):
+                        # Dimensions match, no change needed
+                        pass
+                    elif s.value == 1:
+                        # Target dimension is 1 but input dimension is not 1
+                        # This would "squeeze" the dimension - preserve input for safety
+                        target_shape[i] = data_shape[i]
+                    else:
+                        if isinstance(data_shape[i], tvm.tir.IntImm):
+                            raise ValueError(
+                                f"ONNX Expand: Cannot broadcast input shape 
{original_data_shape} "
+                                f"to target shape {original_target_shape}. "
+                                f"At dimension {i}: input size 
{data_shape[i].value} is "
+                                f"incompatible with target size {s.value}. "
+                                f"ONNX broadcasting requires corresponding 
dimensions to have "
+                                f"the same value or one of them to be 1."
+                            )
+                        # For dynamic shapes, let broadcast_to handle it
             if target_shape == data_shape:
                 return data
             return relax.op.broadcast_to(data, relax.ShapeExpr(target_shape))
@@ -1929,6 +1961,8 @@ class Expand(OnnxOpConverter):
             # ONNX Expand operator requires preserving target rank and broadcasting
             # according to standard rules. Dimensions are right-aligned.
             data_shape = [dim.value for dim in data.struct_info.shape]
+            original_data_shape = data_shape.copy()
+            original_new_shape = new_shape.copy()
 
             # Right-align the shapes
             if len(new_shape) > len(data_shape):
@@ -1938,8 +1972,32 @@ class Expand(OnnxOpConverter):
             # Fix small target shapes - if target dim is smaller than input dim
             # use the input dim (ONNX-specific behavior).
             for i in range(len(new_shape)):
-                if new_shape[i] < data_shape[i]:
+                if new_shape[i] == -1:
+                    # -1 means preserve the input dimension
+                    new_shape[i] = data_shape[i]
+                elif data_shape[i] == 1:
+                    # Input dimension is 1, can broadcast to any target dimension >= 1
+                    if new_shape[i] < 1:
+                        raise ValueError(
+                            f"ONNX Expand: Invalid target dimension 
{new_shape[i]} "
+                            f"at possition {i}. Target dimensions must be >= 
1."
+                        )
+                elif new_shape[i] == data_shape[i]:
+                    # Dimensions match, no change needed
+                    pass
+                elif new_shape[i] == 1:
+                    # Target dimension is 1 but input dimension is not 1
+                    # This would "squeeze" the dimension - preserve input for safety
                     new_shape[i] = data_shape[i]
+                else:
+                    raise ValueError(
+                        f"ONNX Expand: Cannot broadcast input shape 
{original_data_shape} "
+                        f"to target shape {original_new_shape}. "
+                        f"At dimension {i}: input size {data_shape[i]} is 
incompatible "
+                        f"with target size {new_shape[i]}. "
+                        f"ONNX broadcasting requires corresponding dimensions 
to have the same "
+                        f"value or one of them to be 1."
+                    )
             return relax.op.broadcast_to(data, relax.ShapeExpr(new_shape))
 
         # Otherwise handle dynamic shapes.
@@ -1956,7 +2014,18 @@ class Expand(OnnxOpConverter):
         for i in range(shape_ndim):
             shape_vars.append(tvm.tir.Var("x_%d" % i, "int64"))
         bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars))
-        return bb.normalize(relax.op.broadcast_to(data, relax.ShapeExpr(shape_vars)))
+
+        # Apply broadcasting rules for dynamic shapes
+        data_shape = list(data.struct_info.shape)
+        data_ndim = len(data_shape)
+        target_ndim = shape_ndim
+        padded_data = data
+
+        if target_ndim > data_ndim:
+            padded_data_shape = [tir.IntImm("int64", 1)] * (target_ndim - data_ndim) + data_shape
+            padded_data = bb.normalize(relax.op.reshape(data, relax.ShapeExpr(padded_data_shape)))
+
+        return bb.normalize(relax.op.broadcast_to(padded_data, relax.ShapeExpr(shape_vars)))
 
 
 class Attention(OnnxOpConverter):
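
A side note on the dynamic-shape path above: when the target shape has more dimensions than the data, the patch first reshapes the data with leading 1s so broadcast_to sees rank-aligned shapes. A minimal pure-Python sketch of that padding step, with concrete shapes assumed purely for illustration:

    # Hypothetical static input of rank 2 expanding to a dynamic rank-4 target.
    data_shape = [3, 5]                  # known input dims
    target_ndim = 4                      # rank of the dynamic target shape
    pad = target_ndim - len(data_shape)  # number of leading 1s to prepend
    padded = [1] * pad + data_shape      # -> [1, 1, 3, 5]
    # In the patch, relax.op.reshape produces this padded tensor, and
    # broadcast_to then expands it to the symbolic dims (x_0 .. x_3),
    # which are checked at runtime rather than at import time.
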
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 625cdebf7f..d2f5a65593 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1909,6 +1909,106 @@ def test_expand(dynamic):
         _test_expand_dynamic_shapeexpr("expand_with_dynamic_dim", data, shape_data, shape, ref_data)
 
 
+def test_expand_incompatible_broadcasting():
+    """
+    This test reproduces the error where input dim 1 (size 25) right-aligns
+    against target dim 3 (size 56), which violates ONNX broadcasting rules.
+    """
+
+    def _test_expand_error_case(name, data_shape, target_shape_vals):
+        data = np.random.uniform(size=data_shape).astype(np.float32)
+
+        shape_array = np.array(target_shape_vals, dtype=np.int64)
+        shape_node = onnx.helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=["shape"],
+            value=onnx.helper.make_tensor(
+                name="const_tensor",
+                data_type=onnx.TensorProto.INT64,
+                dims=shape_array.shape,
+                vals=shape_array.flatten(),
+            ),
+        )
+
+        expand_node = helper.make_node("Expand", ["in", "shape"], ["out"])
+
+        graph = helper.make_graph(
+            [shape_node, expand_node],
+            "expand_error_test",
+            inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(data.shape))],
+            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, target_shape_vals)],
+        )
+
+        model = helper.make_model(graph, producer_name=name)
+
+        with pytest.raises(ValueError) as exc_info:
+            from_onnx(model, keep_params_in_input=True)
+
+        error_msg = str(exc_info.value)
+        assert (
+            "broadcast" in error_msg.lower() or "incompatible" in 
error_msg.lower()
+        ), f"Expected broadcasting error, but got: {error_msg}"
+
+    # Test case 1: Reproduce the exact error from the issue-17769
+    # Input shape: (25,), target shape: (1, 1, 1, 56)
+    # This should fail because the input dim (25) aligns with target dim 3 (56) and neither is 1
+    _test_expand_error_case(
+        "expand_incompatible_25_to_56",
+        data_shape=(25,),
+        target_shape_vals=(1, 1, 1, 56),
+    )
+
+    # Test case 2: Another incompatible case
+    # Input shape: (1, 25), target shape: (1, 1, 1, 56)
+    # After right-alignment, input (1, 1, 1, 25) vs. target (1, 1, 1, 56)
+    # This should fail because 25 != 56 and neither is 1
+    _test_expand_error_case(
+        "expand_incompatible_aligned_25_to_56",
+        data_shape=(1, 25),
+        target_shape_vals=(1, 1, 1, 56),
+    )
+
+    # Test case 3: Valid case for comparison - should not raise error
+    def _test_expand_valid_case():
+        """Test a valid expand case to ensure our fix doesn't break valid 
operations"""
+        data_shape = (1, 25)
+        target_shape_vals = [2, 25]  # Valid: input (1, 25) can broadcast to (2, 25)
+
+        data = np.random.uniform(size=data_shape).astype(np.float32)
+        shape_array = np.array(target_shape_vals, dtype=np.int64)
+
+        shape_node = onnx.helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=["shape"],
+            value=onnx.helper.make_tensor(
+                name="const_tensor",
+                data_type=onnx.TensorProto.INT64,
+                dims=shape_array.shape,
+                vals=shape_array.flatten(),
+            ),
+        )
+
+        expand_node = helper.make_node("Expand", ["in", "shape"], ["out"])
+
+        graph = helper.make_graph(
+            [shape_node, expand_node],
+            "expand_valid_test",
+            inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(data.shape))],
+            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, target_shape_vals)],
+        )
+
+        model = helper.make_model(graph, producer_name="expand_valid_test_case")
+
+        try:
+            tvm_model = from_onnx(model, keep_params_in_input=True)
+        except Exception as e:
+            pytest.fail(f"Valid expand case should not fail, but got error: {e}")
+
+    _test_expand_valid_case()
+
+
 # TODO(jwfromm) Current approach to dynamic expand is technically not well formed. Reenable once fixed.
 @pytest.mark.skip("Produces ill-formed IR")
 def test_constantofshape():

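For anyone wanting to exercise just the new coverage locally, an invocation along these lines should work from a TVM checkout (the pytest flags here are a suggestion, not part of this commit):

    # select the new test added above by name substring
    pytest tests/python/relax/test_frontend_onnx.py -k expand_incompatible_broadcasting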