This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ec0026e0bc [Relax][PyTorch] Fix index_put with broadcast indices (#18533)
ec0026e0bc is described below
commit ec0026e0bc8b7904b29e167e39b252c7e2794d4a
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Dec 2 04:21:32 2025 +0800
[Relax][PyTorch] Fix index_put with broadcast indices (#18533)
## Related Issue
closes https://github.com/apache/tvm/issues/18355
## Why
Converting PyTorch operations like `M[:, rows, cols] = x` failed (minimal repro sketched below) because:
1. The TOPI index_put implementation called len() on TVM Tensor objects, which is unsupported
2. Index tensors with different shapes (e.g., (2,) and (10,)) could not broadcast together
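
A minimal repro sketch (shapes borrowed from the new test case added below; not part of the patch itself):

```python
import torch

# Works in eager PyTorch, but previously failed during Relax conversion:
# the leading slice becomes an implicit batch index that has to broadcast
# with the explicit 1-D index tensors rows/cols.
x = torch.randn(2, 10)
M = torch.zeros(2, 11, 11)
rows = torch.arange(10)   # shape (10,)
cols = rows + 1           # shape (10,)
M[:, rows, cols] = x      # batch index of shape (2,) must broadcast with (10,)
```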
## How
- Added broadcasting support following NumPy rules to handle
  multi-dimensional index tensors (see the NumPy sketch below)
- Added tests for the batched indexing pattern `M[:, rows, cols] = x`
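
For intuition, the rule in plain NumPy terms (illustrative sketch only; the TOPI change implements the same right-aligned broadcasting directly in the IR builder):

```python
import numpy as np

# The leading slice turns into an implicit batch index, reshaped to (2, 1)
# so it can broadcast against the explicit 1-D index tensors.
batch_idx = np.arange(2).reshape(2, 1)   # shape (2, 1)
rows = np.arange(10)                     # shape (10,)
cols = rows + 1                          # shape (10,)

# Right-align the shapes and take the max along each axis -> (2, 10)
# (np.broadcast_shapes requires NumPy >= 1.20)
assert np.broadcast_shapes(batch_idx.shape, rows.shape, cols.shape) == (2, 10)

# NumPy advanced indexing writes one value per broadcast position,
# which is the semantics index_put now follows.
M = np.zeros((2, 11, 11), dtype=np.float32)
M[batch_idx, rows, cols] = np.random.rand(2, 10).astype(np.float32)
```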
---
.../frontend/torch/base_fx_graph_translator.py | 3 +-
python/tvm/relax/op/manipulate.py | 2 +-
python/tvm/topi/index_put.py | 68 ++++++++++++++++++----
.../relax/test_frontend_from_exported_program.py | 49 ++++++++++++++++
4 files changed, 108 insertions(+), 14 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index e9a9cdd939..7ebb95c136 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1812,8 +1812,9 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
                     )
                 )
                 # Reshape to [dim_size, 1, 1, ...] for broadcasting
+                # Add an extra dimension so it broadcasts with other indices
                 arange_idx = self.block_builder.emit(
-                    relax.op.reshape(arange_idx, [data_shape[i]] + [1] * (max_ndim - 1))
+                    relax.op.reshape(arange_idx, [data_shape[i]] + [1] * max_ndim)
                 )
                 processed_indices.append(arange_idx)
             else:
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index bb134f1148..ee486b0ab6 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -642,7 +642,7 @@ def index_put(
[0.0, 3.0, 0.0],
]
"""
-    if not isinstance(indices, (list, tuple)):
+    if isinstance(indices, (list, tuple)):
         indices = RxTuple(indices)
     return _ffi_api.index_put(data, indices, values, accumulate)  # type: ignore
diff --git a/python/tvm/topi/index_put.py b/python/tvm/topi/index_put.py
index f51c6718ab..52406d402c 100644
--- a/python/tvm/topi/index_put.py
+++ b/python/tvm/topi/index_put.py
@@ -1,6 +1,6 @@
# Licensed to the Apache Software Foundation (ASF) under one
-# or more contrir_builderutor license agreements. See the NOTICE file
-# distrir_builderuted with this work for additional information
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
@@ -9,7 +9,7 @@
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
-# software distrir_builderuted under the License is distrir_builderuted on an
+# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
@@ -29,7 +29,8 @@ def index_put(data, indices, values, accumulate=False):
         The source array to be modified.

     indices : Tuple[tvm.te.Tensor]
-        Tuple of 1D index tensors (one for each dimension) specifying positions.
+        Tuple of index tensors (can be multi-dimensional) specifying positions.
+        Index tensors are broadcast together following NumPy broadcasting rules.

     values : tvm.te.Tensor
         The values to place at the specified indices.
@@ -60,11 +61,28 @@ def index_put(data, indices, values, accumulate=False):
     for dim in shape:
         full_range *= dim

-    # Check all indices have same length
-    index_len = len(indices[0])
-    for idx in indices[1:]:
-        if not utils.equal_const_int(len(idx), index_len):
-            raise ValueError("All index tensors must have same length")
+    index_shapes = [idx.shape for idx in indices]
+    broadcast_ndim = max(len(s) for s in index_shapes)
+    broadcast_shape = []
+
+    for i in range(broadcast_ndim):
+        max_dim = 1
+        for idx_shape in index_shapes:
+            # Right-align shapes
+            dim_idx = len(idx_shape) - broadcast_ndim + i
+            if dim_idx >= 0:
+                dim_size = idx_shape[dim_idx]
+                if not utils.equal_const_int(dim_size, 1):
+                    if utils.equal_const_int(max_dim, 1):
+                        max_dim = dim_size
+                    elif not utils.equal_const_int(dim_size, max_dim):
+                        raise ValueError(f"Cannot broadcast index shapes: {index_shapes}")
+        broadcast_shape.append(max_dim)
+
+    # Compute total number of elements after broadcasting
+    index_len = 1
+    for dim in broadcast_shape:
+        index_len *= dim

     def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func):
         ir_builder = tir.ir_builder.create()
@@ -78,12 +96,38 @@ def index_put(data, indices, values, accumulate=False):
             out[i] = data[i]

         with ir_builder.for_range(0, index_len, "k", kind="parallel") as k:
-            # Calculate multi-dimensional index
+            # Decompose k into multi-dimensional broadcast index
+            k_temp = k
+            broadcast_indices = []
+            for i in range(broadcast_ndim - 1, -1, -1):
+                broadcast_indices.insert(0, k_temp % broadcast_shape[i])
+                k_temp = k_temp // broadcast_shape[i]
+
             flat_index = 0
             stride = 1
             for dim in range(len(shape) - 1, -1, -1):
-                # Get index and shift to positive if needed
-                idx_val = indices[dim][k]
+                # Get the index for this dimension using broadcasting
+                idx_shape = index_shapes[dim]
+                idx_ndim = len(idx_shape)
+
+                # Compute the linear index into this index tensor
+                idx_offset = 0
+                idx_stride = 1
+                for i in range(broadcast_ndim - 1, -1, -1):
+                    # Right-align the index shape with broadcast shape
+                    dim_idx = idx_ndim - broadcast_ndim + i
+                    if dim_idx >= 0:
+                        dim_size = idx_shape[dim_idx]
+                        # Use broadcasting: if size is 1, use index 0
+                        # otherwise use broadcast_indices[i]
+                        if utils.equal_const_int(dim_size, 1):
+                            idx_in_dim = 0
+                        else:
+                            idx_in_dim = broadcast_indices[i]
+                        idx_offset += idx_in_dim * idx_stride
+                        idx_stride *= dim_size
+
+                idx_val = indices[dim][idx_offset]
                 shifted_idx = idx_val + (idx_val < 0) * shape[dim]
                 flat_index += shifted_idx * stride
                 stride *= shape[dim]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 0658dbfaf3..010bd026a8 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7133,6 +7133,54 @@ def test_index_put():
                 R.output(gv)
             return gv

+    # Test case 9: batched indexing with slice (e.g., M[:, rows, cols] = x)
+    class IndexPutBatchedWithNone(Module):
+        def forward(self, x):
+            B = x.size(0)
+            M = torch.zeros(B, 11, 11)
+            rows = torch.arange(10)
+            cols = rows + 1
+            M[:, rows, cols] = x  # Batched index assignment
+            return M
+
+    example_args_batched_none = (torch.randn(2, 10, dtype=torch.float32),)
+
+    @I.ir_module
+    class ExpectedBatchedWithNone:
+        @R.function
+        def main(
+            x: R.Tensor((2, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((2, 11, 11), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 11, 11), dtype="float32") = R.full(
+                    R.shape([2, 11, 11]), R.const(0.0, "float32"), dtype="float32"
+                )
+                lv1: R.Tensor((10,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
+                )
+                lv2: R.Tensor((10,), dtype="int64") = R.add(lv1, R.const(1, "int64"))
+                lv3: R.Tensor((2, 11, 11), dtype="float32") = R.strided_slice(
+                    lv,
+                    (R.prim_value(0),),
+                    (R.prim_value(0),),
+                    (R.prim_value(9223372036854775807),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv4: R.Tensor((2,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(2), R.prim_value(1), dtype="int64"
+                )
+                lv5: R.Tensor((2, 1), dtype="int64") = R.reshape(lv4, R.shape([2, 1]))
+                lv6: R.Tensor((2, 11, 11), dtype="float32") = R.index_put(
+                    lv3, (lv5, lv1, lv2), x, accumulate=False
+                )
+                lv7: R.Tensor((2, 11, 11), dtype="float32") = R.slice_scatter(
+                    lv, lv6, R.prim_value(0), R.prim_value(2), R.prim_value(1), axis=0
+                )
+                gv: R.Tuple(R.Tensor((2, 11, 11), dtype="float32")) = (lv7,)
+                R.output(gv)
+            return gv
+
     # Run verification for each case
     verify_model(IndexPut1D(), example_args_1d, {}, Expected1D)
     verify_model(IndexPut2D(), example_args_2d, {}, Expected2D)
@@ -7142,6 +7190,7 @@ def test_index_put():
     verify_model(IndexPutBroadcast1D(), example_args_broadcast1, {}, ExpectedBroadcast1D)
     verify_model(IndexPutBroadcast2D(), example_args_broadcast2, {}, ExpectedBroadcast2D)
     verify_model(IndexPutBroadcast3D(), example_args_broadcast3d, {}, ExpectedBroadcast3D)
+    verify_model(IndexPutBatchedWithNone(), example_args_batched_none, {}, ExpectedBatchedWithNone)


 def test_flip():