This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0149fcd3c6 Add support for index_put_ op (#17865)
0149fcd3c6 is described below
commit 0149fcd3c6e10226af03359fdfa4cc13ec24b73e
Author: Pratheesh-04-MCW <[email protected]>
AuthorDate: Thu Apr 24 21:37:23 2025 +0530
Add support for index_put_ op (#17865)
* add support for index_put_ op
* add test cases from 1D to 5D in both exported program and fx graph
* lint issues
* lint check
* lint issue
* whitespace issue
* whitespace issue
* whitespace issue
* line length error
* trailing space issue
* modified conditions for parameters
* modified base_fx
* resolved conflicts
* removed trailing whitespace
* lint issue
* unity issue
---
include/tvm/relax/attrs/manipulate.h | 14 ++
.../frontend/torch/base_fx_graph_translator.py | 17 ++
.../frontend/torch/exported_program_translator.py | 1 +
python/tvm/relax/frontend/torch/fx_translator.py | 1 +
python/tvm/relax/op/__init__.py | 1 +
python/tvm/relax/op/manipulate.py | 51 ++++++
python/tvm/relax/op/op_attrs.py | 5 +
.../tvm/relax/transform/legalize_ops/manipulate.py | 22 +++
python/tvm/script/ir_builder/relax/ir.py | 2 +
python/tvm/topi/__init__.py | 1 +
python/tvm/topi/index_put.py | 117 ++++++++++++++
src/relax/op/tensor/manipulate.cc | 123 +++++++++++++++
src/relax/op/tensor/manipulate.h | 13 ++
.../relax/test_frontend_from_exported_program.py | 175 +++++++++++++++++++++
tests/python/relax/test_frontend_from_fx.py | 171 ++++++++++++++++++++
15 files changed, 714 insertions(+)
diff --git a/include/tvm/relax/attrs/manipulate.h
b/include/tvm/relax/attrs/manipulate.h
index 67f99d9b41..943d2f4d0d 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -182,6 +182,20 @@ struct GatherNDAttrs : public
tvm::AttrsNode<GatherNDAttrs> {
}
}; // struct GatherNDAttrs
+/*!
+ * \brief Attributes used in the relax.index_put operator.
+ *
+ * Only the `accumulate` flag lives here; data, indices and values are passed
+ * as the call's arguments.
+ */
+struct IndexPutAttrs : public tvm::AttrsNode<IndexPutAttrs> {
+ // Selects `+=` (true) versus `=` (false) semantics; defaults to false.
+ bool accumulate;
+
+ TVM_DECLARE_ATTRS(IndexPutAttrs, "relax.attrs.IndexPutAttrs") {
+ TVM_ATTR_FIELD(accumulate)
+ .set_default(false)
+ .describe(
+ "Whether to accumulate (add) values rather than replace. "
+ "If true, performs tensor[indices] += values, "
+ "otherwise performs tensor[indices] = values.");
+ }
+}; // struct IndexPutAttrs
+
/*! \brief Attributes used in scatter_elements operators */
struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
Integer axis;
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 33f6ffc313..5dd78be483 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1118,6 +1118,23 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
index = self.env[node.args[2]]
return self.block_builder.emit(relax.op.gather_elements(x, index,
axis=dim))
+    def _index_put(self, node: fx.Node) -> relax.Var:
+        """Convert ``Tensor.index_put_`` into ``relax.op.index_put``.
+
+        ``indices``, ``values`` and ``accumulate`` are read positionally when
+        present, otherwise from ``node.kwargs``.
+
+        Raises
+        ------
+        ValueError
+            If ``indices`` or ``values`` is missing.
+        TypeError
+            If ``accumulate`` is not a bool.
+        NotImplementedError
+            If ``indices`` contains ``None`` entries.
+        """
+        args = self.retrieve_args(node)
+        tensor = args[0]
+        # NOTE(review): values taken from node.kwargs are not routed through
+        # self.env, so an fx.Node passed by keyword would reach relax
+        # unconverted; confirm whether this path can occur.
+        indices = args[1] if len(args) > 1 else node.kwargs.get("indices")
+        values = args[2] if len(args) > 2 else node.kwargs.get("values")
+        accumulate = args[3] if len(args) > 3 else node.kwargs.get("accumulate", False)
+
+        if indices is None or values is None:
+            raise ValueError(
+                "'indices' and 'values' arguments are required for index_put operation"
+            )
+
+        if not isinstance(accumulate, bool):
+            raise TypeError(f"'accumulate' must be a boolean value, got {type(accumulate)}")
+
+        if isinstance(indices, (list, tuple)):
+            # PyTorch permits None entries (select every position along that
+            # axis). relax.op.index_put needs one index tensor per dimension,
+            # so reject them here with a clear message instead of failing
+            # while building the relax.Tuple.
+            if any(index is None for index in indices):
+                raise NotImplementedError(
+                    "index_put with None entries in 'indices' is not supported"
+                )
+            indices = relax.Tuple(indices)
+        return self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate))
+
def _index_tensor(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
indices = args[1]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index aa4984c0ba..db5ca01399 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -433,6 +433,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"flip.default": self._flip,
"gather.default": self._gather,
"index.Tensor": self._index_tensor,
+ "index_put_.default": self._index_put,
"narrow.default": self._narrow,
"permute.default": self._permute,
"repeat.default": self._repeat,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index bab272bd78..18dba2d988 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -801,6 +801,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"flatten": self._flatten,
"flip": self._flip,
"gather": self._gather,
+ "index_put_": self._index_put,
"narrow": self._narrow,
"numel": self._numel,
"permute": self._permute,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 097313a33d..7b8c34b641 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -95,6 +95,7 @@ from .manipulate import (
flip,
gather_elements,
gather_nd,
+ index_put,
index_tensor,
layout_transform,
one_hot,
diff --git a/python/tvm/relax/op/manipulate.py
b/python/tvm/relax/op/manipulate.py
index a693adf432..13334d1479 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -595,6 +595,57 @@ def index_tensor(data: Expr, indices: Union[Expr,
List[Expr]]) -> Expr:
return _ffi_api.index_tensor(data, indices) # type: ignore
+def index_put(
+    data: Expr,
+    indices: Union[Expr, Tuple[Expr]],
+    values: Expr,
+    accumulate: bool = False,
+) -> Expr:
+    """Put values into `data` at the positions given by `indices`.
+
+    Position ``k`` is formed from the ``k``-th entry of every index tensor
+    (one tensor per dimension of `data`) and receives ``values[k]``. When
+    `accumulate` is True the value is added to the existing element instead of
+    replacing it.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input tensor to be modified.
+
+    indices : Union[relax.Expr, Tuple[relax.Expr]]
+        The index tensors, one per dimension of `data`. Either a relax
+        expression (a tuple expression or a single index tensor), or a Python
+        list/tuple of tensors, which is packed into a relax Tuple.
+
+    values : relax.Expr
+        Values to place at the specified indices.
+
+    accumulate : bool
+        Whether to accumulate (add) values rather than replace them
+        (default: False).
+
+    Returns
+    -------
+    result : relax.Expr
+        A new tensor with the same shape as `data` in which the indexed
+        positions are updated.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        # inputs (PyTorch notation)
+        data = torch.ones(3, 3)
+        indices = (torch.tensor([0, 2]), torch.tensor([1, 1]))
+        values = torch.tensor([1.0, 2.0])
+
+        # output with accumulate=False: positions (0, 1) and (2, 1) are overwritten
+        output = [
+            [1.0, 1.0, 1.0],
+            [1.0, 1.0, 1.0],
+            [1.0, 2.0, 1.0],
+        ]
+
+        # output with accumulate=True: values are added to the existing entries
+        output = [
+            [1.0, 2.0, 1.0],
+            [1.0, 1.0, 1.0],
+            [1.0, 3.0, 1.0],
+        ]
+    """
+    # Only Python sequences need packing. A relax expression (a relax.Tuple
+    # built by a frontend, or a single index tensor) is forwarded unchanged,
+    # since the op's struct-info inference accepts both forms.
+    if isinstance(indices, (list, tuple)):
+        indices = RxTuple(indices)
+    return _ffi_api.index_put(data, indices, values, accumulate)  # type: ignore
+
+
def scatter_elements(
data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str =
"update"
):
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index fda4258a09..fe527e38e8 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -144,6 +144,11 @@ class StackAttrs(Attrs):
"""Attributes for concat operator"""
+@tvm._ffi.register_object("relax.attrs.IndexPutAttrs")
+class IndexPutAttrs(Attrs):
+ """Attributes for index_put operator.
+
+ Mirrors the C++ ``relax.attrs.IndexPutAttrs`` node, whose single field
+ ``accumulate`` (bool, default False) selects add-into versus overwrite.
+ """
+
+
@tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs")
class LayoutTransformAttrs(Attrs):
"""Attributes used in layout_transform operator"""
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py
b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 84baa887d9..a66b60c013 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -193,6 +193,28 @@ def _index_tensor(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(topi.index_tensor, call.args[0], fields)
+@register_legalize("relax.index_put")
+def _index_put(bb: BlockBuilder, call: Call) -> Expr:
+    """Lower relax.index_put to topi.index_put via call_te.
+
+    topi.index_put expects one index tensor per data dimension, so the indices
+    argument is expanded into a Python list of tensor expressions.
+    """
+    data, indices, values = call.args
+    accumulate = call.attrs.accumulate
+
+    if isinstance(indices, relax.Tuple):
+        # Literal tuple: its fields already are the individual index tensors.
+        indices_list = list(indices.fields)
+    elif isinstance(indices.struct_info, relax.TupleStructInfo):
+        # Tuple-typed expression that is not a literal tuple (e.g. a Var bound
+        # to a tuple). Project each field so call_te receives tensors rather
+        # than a single opaque tuple argument.
+        indices_list = [
+            bb.emit(relax.TupleGetItem(indices, i))
+            for i in range(len(indices.struct_info.fields))
+        ]
+    else:
+        # A single index tensor.
+        indices_list = [indices]
+
+    return bb.call_te(
+        topi.index_put,
+        data,
+        indices_list,
+        values,
+        accumulate=accumulate,
+    )
+
+
@register_legalize("relax.scatter_elements")
def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index 22b00cd704..d2952ed8e0 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -100,6 +100,7 @@ from tvm.relax.op import (
greater,
greater_equal,
hint_on_device,
+ index_put,
image,
index_tensor,
invoke_closure,
@@ -785,6 +786,7 @@ __all__ = [
"greater_equal",
"hexagon",
"hint_on_device",
+ "index_put",
"image",
"index_tensor",
"invoke_closure",
diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py
index 1de6941c99..fa4e98a89a 100644
--- a/python/tvm/topi/__init__.py
+++ b/python/tvm/topi/__init__.py
@@ -33,6 +33,7 @@ from . import cpp
from .math import *
from .tensor import *
from .generic_op_impl import *
+from .index_put import *
from .reduction import *
from .transform import *
from .broadcast import *
diff --git a/python/tvm/topi/index_put.py b/python/tvm/topi/index_put.py
new file mode 100644
index 0000000000..f51c6718ab
--- /dev/null
+++ b/python/tvm/topi/index_put.py
@@ -0,0 +1,117 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""IndexPut operator"""
+from tvm import te
+from tvm import tir
+from . import utils
+
+
+def index_put(data, indices, values, accumulate=False):
+    """Put values into an array according to indices.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The source array to be modified.
+
+    indices : Tuple[tvm.te.Tensor]
+        Tuple of 1D index tensors (one for each dimension) specifying positions.
+        A negative index is wrapped by adding the extent of its dimension.
+
+    values : tvm.te.Tensor
+        The values to place at the specified indices; element ``k`` goes to the
+        position formed by the ``k``-th entry of every index tensor.
+
+    accumulate : bool, optional
+        Whether to accumulate (add) values rather than replace.
+        If True, performs tensor[indices] += values
+        If False, performs tensor[indices] = values
+        Default is False.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        A copy of ``data`` (same shape and dtype) with the indexed positions updated.
+
+    Raises
+    ------
+    ValueError
+        If the number of index tensors differs from ``data``'s rank, or if two
+        index tensors have different static lengths.
+    """
+    if not isinstance(indices, (list, tuple)):
+        indices = [indices]
+
+    # Check indices match data dimensions
+    if len(indices) != len(data.shape):
+        raise ValueError(
+            f"Number of index tensors ({len(indices)}) must match "
+            f"data dimensions ({len(data.shape)})"
+        )
+
+    shape = data.shape
+    full_range = 1
+    for dim in shape:
+        full_range *= dim
+
+    # The extent of a te.Tensor is exposed through ``shape``. Lengths can only
+    # be compared here when both are static; symbolic extents are accepted.
+    index_len = indices[0].shape[0]
+    for idx in indices[1:]:
+        other_len = idx.shape[0]
+        both_static = isinstance(index_len, (int, tir.IntImm)) and isinstance(
+            other_len, (int, tir.IntImm)
+        )
+        if both_static and utils.get_const_int(other_len) != utils.get_const_int(index_len):
+            raise ValueError("All index tensors must have same length")
+
+    def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func):
+        ir_builder = tir.ir_builder.create()
+
+        data_buf = ir_builder.buffer_ptr(data_ptr)
+        index_bufs = [ir_builder.buffer_ptr(idx) for idx in index_ptrs]
+        values_buf = ir_builder.buffer_ptr(values_ptr)
+        out = ir_builder.buffer_ptr(out_ptr)
+
+        # Start from a copy of the input. Elements are independent, so this
+        # loop can run in parallel.
+        with ir_builder.for_range(0, full_range, "i", kind="parallel") as i:
+            out[i] = data_buf[i]
+
+        # With accumulate=True, two iterations may target the same position and
+        # each performs a read-modify-write. The scatter loop must therefore be
+        # serial to avoid a data race. Plain replacement stays parallel.
+        scatter_kind = "serial" if accumulate else "parallel"
+        with ir_builder.for_range(0, index_len, "k", kind=scatter_kind) as k:
+            # Row-major flattening of the multi-dimensional index.
+            flat_index = 0
+            stride = 1
+            for dim in range(len(shape) - 1, -1, -1):
+                idx_val = index_bufs[dim][k]
+                # Wrap negative indices into [0, shape[dim]).
+                wrapped_idx = tir.if_then_else(idx_val < 0, idx_val + shape[dim], idx_val)
+                flat_index += wrapped_idx * stride
+                stride *= shape[dim]
+
+            reduce_func(out, flat_index, values_buf[k])
+
+        return ir_builder.get()
+
+    def update_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] = update
+
+    def add_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] += update
+
+    reduce_func = add_func if accumulate else update_func
+
+    # Input buffer order must match the unpacking in the extern lambda:
+    # data, then every index tensor, then values.
+    in_buffers = [data]
+    in_buffers.extend(indices)
+    in_buffers.append(values)
+
+    out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf")
+    return te.extern(
+        [data.shape],
+        in_buffers,
+        lambda ins, outs: gen_ir(ins[0], ins[1:-1], ins[-1], outs[0], reduce_func),
+        dtype=data.dtype,
+        out_buffers=[out_buf],
+        name="index_put.generic",
+        tag="index_put.generic",
+    )
diff --git a/src/relax/op/tensor/manipulate.cc
b/src/relax/op/tensor/manipulate.cc
index f56135a35b..482ebe5cac 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -1972,6 +1972,129 @@ TVM_REGISTER_OP("relax.gather_nd")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGatherND)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.index_put */
+TVM_REGISTER_NODE_TYPE(IndexPutAttrs);
+
+/*!
+ * \brief Construct a call to relax.index_put.
+ * `accumulate` is stored in IndexPutAttrs; data, indices and values become the three
+ * call arguments.
+ */
+Expr index_put(Expr data, Expr indices, Expr values, bool accumulate) {
+  static const Op& index_put_op = Op::Get("relax.index_put");
+  ObjectPtr<IndexPutAttrs> put_attrs = make_object<IndexPutAttrs>();
+  put_attrs->accumulate = accumulate;  // plain bool copy; std::move gains nothing here
+  Array<Expr> call_args{data, indices, values};
+  return Call(index_put_op, call_args, Attrs(put_attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.index_put").set_body_typed(index_put);
+
+/*!
+ * \brief Infer the struct info of relax.index_put.
+ *
+ * Validates that data and values are tensors with matching dtypes, and that indices are
+ * either one tensor or a tuple of tensors, each 1D and integer, with one per data
+ * dimension. The result has the shape, dtype and vdevice of data.
+ */
+StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) {
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* values_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
+
+  // Reject non-tensor data/values. The code below dereferences both pointers, so it
+  // relies on ReportFatal not returning.
+  auto diag_def = [&](const TensorStructInfoNode* sinfo, String name, String type_key) {
+    if (sinfo == nullptr) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "IndexPut requires the input " << name
+                       << " to be a Tensor. However, the given one is " << type_key);
+    }
+  };
+  diag_def(data_sinfo, "data", call->args[0]->struct_info_->GetTypeKey());
+  diag_def(values_sinfo, "values", call->args[2]->struct_info_->GetTypeKey());
+
+  // Indices are either a tuple of tensors (one per data dimension) or a single tensor.
+  Array<TensorStructInfo> indices_tensors;
+  if (const auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(call->args[1])) {
+    for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) {
+      const auto* tensor_sinfo = tuple_sinfo->fields[i].as<TensorStructInfoNode>();
+      if (tensor_sinfo == nullptr) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << "IndexPut requires each index in the indices tuple to be a Tensor. "
+                         << "However, element " << i << " is "
+                         << tuple_sinfo->fields[i]->GetTypeKey());
+      }
+      indices_tensors.push_back(GetRef<TensorStructInfo>(tensor_sinfo));
+    }
+  } else if (const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1])) {
+    indices_tensors.push_back(GetRef<TensorStructInfo>(tensor_sinfo));
+  } else {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "IndexPut requires indices to be a Tensor or a tuple of Tensors. "
+                     << "However, the given one is " << call->args[1]->struct_info_->GetTypeKey());
+  }
+
+  // Without a known rank nothing further can be checked against data.
+  if (data_sinfo->IsUnknownNdim()) {
+    return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice);
+  }
+
+  // Each index tensor must be 1D with an integer dtype.
+  for (size_t i = 0; i < indices_tensors.size(); ++i) {
+    const auto& tensor_sinfo = indices_tensors[i];
+    if (!tensor_sinfo->IsUnknownNdim() && tensor_sinfo->ndim != 1) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "IndexPut requires each index tensor to be 1D. "
+                       << "However, index tensor " << i << " has ndim=" << tensor_sinfo->ndim);
+    }
+    if (tensor_sinfo->IsUnknownDtype()) {
+      LOG(WARNING) << "Data type of index tensor " << i
+                   << " has not been specified. Assume it has an integer type.";
+    } else if (!(tensor_sinfo->dtype.is_int() || tensor_sinfo->dtype.is_uint())) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "IndexPut requires each index tensor to have integer dtype. "
+                       << "However, index tensor " << i << " has dtype=" << tensor_sinfo->dtype);
+    }
+  }
+
+  // One index tensor per data dimension (data ndim is known at this point).
+  if (indices_tensors.size() != static_cast<size_t>(data_sinfo->ndim)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "IndexPut requires the number of index tensors (" << indices_tensors.size()
+                     << ") to match the data tensor dimensions (" << data_sinfo->ndim << ")");
+  }
+
+  // data and values must share a dtype. When either is unknown the comparison is
+  // skipped; the warning must not claim an integer type for a data/values tensor.
+  if (data_sinfo->IsUnknownDtype() || values_sinfo->IsUnknownDtype()) {
+    auto diag_dtype = [&](const TensorStructInfoNode* sinfo, String name) {
+      if (sinfo->IsUnknownDtype()) {
+        LOG(WARNING) << "Data type of " << name
+                     << " has not been specified. Assume data and values share the same dtype.";
+      }
+    };
+    diag_dtype(data_sinfo, "data");
+    diag_dtype(values_sinfo, "values");
+  } else if (data_sinfo->dtype != values_sinfo->dtype) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "IndexPut requires the input data to have the same type as values. "
+                     << "However, the given types are data: " << data_sinfo->dtype
+                     << ", values: " << values_sinfo->dtype);
+  }
+
+  // values are expected to be 1D (one value per indexed position); other ranks only warn.
+  if (values_sinfo->shape.as<ShapeExprNode>() && values_sinfo->ndim != 1) {
+    LOG(WARNING) << "IndexPut typically expects values to be 1D, but got ndim="
+                 << values_sinfo->ndim;
+  }
+
+  // The result has exactly data's shape. Reuse data's shape expression whenever it is
+  // defined, including when it is a symbolic shape variable rather than a ShapeExpr.
+  if (data_sinfo->shape.defined()) {
+    return TensorStructInfo(data_sinfo->shape.value(), data_sinfo->dtype, data_sinfo->vdevice);
+  }
+  return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
+}
+
+// Register relax.index_put: three inputs (data, indices, values) plus IndexPutAttrs.
+// InferStructInfoIndexPut accepts `indices` either as one tensor or as a tuple of tensors.
+TVM_REGISTER_OP("relax.index_put")
+ .set_attrs_type<IndexPutAttrs>()
+ .set_num_inputs(3)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("indices", "Tensor", "The indices tensor(s).")
+ .add_argument("values", "Tensor", "The values to put.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoIndexPut)
+ .set_attr<Bool>("FPurity", Bool(true));
+
/* relax.scatter_elements */
TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs);
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 4580f9191b..2e4c92c150 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -218,6 +218,19 @@ Expr gather_nd(Expr data, Expr indices, int batch_dims =
0);
*/
Expr index_tensor(Expr data, Expr indices);
+/*!
+ * \brief Put values into an array according to indices.
+ * \param data The input tensor to be modified.
+ * \param indices The index positions where values should be placed: a tuple of 1D
+ *        integer tensors (one for each dimension of \p data), or a single 1D tensor.
+ * \param values The values to place at the specified indices.
+ * \param accumulate Whether to accumulate (add) values rather than replace.
+ *        If true, equivalent to tensor[indices] += values.
+ *        If false, equivalent to tensor[indices] = values.
+ * \return The computed result with values placed at specified indices.
+ */
+Expr index_put(Expr data, Expr indices, Expr values, bool accumulate = false);
+
/*!
* \brief Scatter updates into an array according to indices.
* \param data The input tensor.
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 5ef2c27e91..589d3f5bae 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4169,6 +4169,181 @@ def test_gather():
verify_model(Gather3(), example_args, {}, Expected3)
+def test_index_put():
+ """Check that Tensor.index_put_ (accumulate=False) on 1-D to 5-D inputs is
+ imported from an exported program as R.index_put with a tuple of indices."""
+ # Test case 1: 1D input
+ class IndexPut1D(Module):
+ def forward(self, data, indices_0, values):
+ indices_tuple = (indices_0,)
+ return data.index_put_(indices_tuple, values, accumulate=False)
+
+ example_args_1d = (
+ torch.randn(64, dtype=torch.float32),
+ torch.randint(0, 64, (128,), dtype=torch.int64),
+ torch.randn(128, dtype=torch.float32),
+ )
+
+ # The exported-program importer returns the result wrapped in a tuple.
+ @I.ir_module
+ class Expected1D:
+ @R.function
+ def main(
+ data: R.Tensor((64,), dtype="float32"),
+ indices_0: R.Tensor((128,), dtype="int64"),
+ values: R.Tensor((128,), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((64,), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((64,), dtype="float32") = R.index_put(
+ data, R.tuple(indices_0), values, accumulate=False
+ )
+ gv: R.Tuple(R.Tensor((64,), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ # Test case 2: 2D input
+ class IndexPut2D(Module):
+ def forward(self, data, indices_0, indices_1, values):
+ indices_tuple = (indices_0, indices_1)
+ return data.index_put_(indices_tuple, values, accumulate=False)
+
+ example_args_2d = (
+ torch.randn(32, 64, dtype=torch.float32),
+ torch.randint(0, 32, (128,), dtype=torch.int64),
+ torch.randint(0, 64, (128,), dtype=torch.int64),
+ torch.randn(128, dtype=torch.float32),
+ )
+
+ @I.ir_module
+ class Expected2D:
+ @R.function
+ def main(
+ data: R.Tensor((32, 64), dtype="float32"),
+ indices_0: R.Tensor((128,), dtype="int64"),
+ indices_1: R.Tensor((128,), dtype="int64"),
+ values: R.Tensor((128,), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((32, 64), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((32, 64), dtype="float32") = R.index_put(
+ data, R.tuple(indices_0, indices_1), values,
accumulate=False
+ )
+ gv: R.Tuple(R.Tensor((32, 64), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ # Test case 3: 3D input
+ class IndexPut3D(Module):
+ def forward(self, data, indices_0, indices_1, indices_2, values):
+ indices_tuple = (indices_0, indices_1, indices_2)
+ return data.index_put_(indices_tuple, values, accumulate=False)
+
+ example_args_3d = (
+ torch.randn(16, 32, 64, dtype=torch.float32),
+ torch.randint(0, 16, (128,), dtype=torch.int64),
+ torch.randint(0, 32, (128,), dtype=torch.int64),
+ torch.randint(0, 64, (128,), dtype=torch.int64),
+ torch.randn(128, dtype=torch.float32),
+ )
+
+ @I.ir_module
+ class Expected3D:
+ @R.function
+ def main(
+ data: R.Tensor((16, 32, 64), dtype="float32"),
+ indices_0: R.Tensor((128,), dtype="int64"),
+ indices_1: R.Tensor((128,), dtype="int64"),
+ indices_2: R.Tensor((128,), dtype="int64"),
+ values: R.Tensor((128,), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((16, 32, 64), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((16, 32, 64), dtype="float32") = R.index_put(
+ data, R.tuple(indices_0, indices_1, indices_2), values,
accumulate=False
+ )
+ gv: R.Tuple(R.Tensor((16, 32, 64), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ # Test case 4: 4D input
+ class IndexPut4D(Module):
+ def forward(self, data, indices_0, indices_1, indices_2, indices_3,
values):
+ indices_tuple = (indices_0, indices_1, indices_2, indices_3)
+ return data.index_put_(indices_tuple, values, accumulate=False)
+
+ example_args_4d = (
+ torch.randn(8, 16, 32, 64, dtype=torch.float32),
+ torch.randint(0, 8, (128,), dtype=torch.int64),
+ torch.randint(0, 16, (128,), dtype=torch.int64),
+ torch.randint(0, 32, (128,), dtype=torch.int64),
+ torch.randint(0, 64, (128,), dtype=torch.int64),
+ torch.randn(128, dtype=torch.float32),
+ )
+
+ @I.ir_module
+ class Expected4D:
+ @R.function
+ def main(
+ data: R.Tensor((8, 16, 32, 64), dtype="float32"),
+ indices_0: R.Tensor((128,), dtype="int64"),
+ indices_1: R.Tensor((128,), dtype="int64"),
+ indices_2: R.Tensor((128,), dtype="int64"),
+ indices_3: R.Tensor((128,), dtype="int64"),
+ values: R.Tensor((128,), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((8, 16, 32, 64), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((8, 16, 32, 64), dtype="float32") = R.index_put(
+ data,
+ R.tuple(indices_0, indices_1, indices_2, indices_3),
+ values,
+ accumulate=False,
+ )
+ gv: R.Tuple(R.Tensor((8, 16, 32, 64), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ # Test case 5: 5D input
+ class IndexPut5D(Module):
+ def forward(self, data, indices_0, indices_1, indices_2, indices_3,
indices_4, values):
+ indices_tuple = (indices_0, indices_1, indices_2, indices_3,
indices_4)
+ return data.index_put_(indices_tuple, values, accumulate=False)
+
+ example_args_5d = (
+ torch.randn(4, 8, 16, 32, 64, dtype=torch.float32),
+ torch.randint(0, 4, (128,), dtype=torch.int64),
+ torch.randint(0, 8, (128,), dtype=torch.int64),
+ torch.randint(0, 16, (128,), dtype=torch.int64),
+ torch.randint(0, 32, (128,), dtype=torch.int64),
+ torch.randint(0, 64, (128,), dtype=torch.int64),
+ torch.randn(128, dtype=torch.float32),
+ )
+
+ @I.ir_module
+ class Expected5D:
+ @R.function
+ def main(
+ data: R.Tensor((4, 8, 16, 32, 64), dtype="float32"),
+ indices_0: R.Tensor((128,), dtype="int64"),
+ indices_1: R.Tensor((128,), dtype="int64"),
+ indices_2: R.Tensor((128,), dtype="int64"),
+ indices_3: R.Tensor((128,), dtype="int64"),
+ indices_4: R.Tensor((128,), dtype="int64"),
+ values: R.Tensor((128,), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((4, 8, 16, 32, 64), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") =
R.index_put(
+ data,
+ R.tuple(indices_0, indices_1, indices_2, indices_3,
indices_4),
+ values,
+ accumulate=False,
+ )
+ gv: R.Tuple(R.Tensor((4, 8, 16, 32, 64), dtype="float32")) =
(lv,)
+ R.output(gv)
+ return gv
+
+ # Run verification for each case
+ verify_model(IndexPut1D(), example_args_1d, {}, Expected1D)
+ verify_model(IndexPut2D(), example_args_2d, {}, Expected2D)
+ verify_model(IndexPut3D(), example_args_3d, {}, Expected3D)
+ verify_model(IndexPut4D(), example_args_4d, {}, Expected4D)
+ verify_model(IndexPut5D(), example_args_5d, {}, Expected5D)
+
+
def test_flip():
class Flip0(Module):
def forward(self, data):
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index c5b95f6c39..4003202d4f 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4404,6 +4404,177 @@ def test_gather():
verify_model(Gather3(), [([2, 3], "float32"), ([2, 3], "int32")], {},
Expected3)
+def test_index_put():
+ """Check that Tensor.index_put_ (accumulate=False) on 1-D to 5-D inputs is
+ imported from an FX graph as R.index_put with a tuple of indices."""
+ # Test case 1: 1D input
+ class IndexPut1D(Module):
+ def forward(self, data, indices_0, values):
+ indices_tuple = (indices_0,)
+ return data.index_put_(indices_tuple, values, accumulate=False)
+
+ # Each entry is (shape, dtype) for one forward() argument, in order.
+ input_info_1d = [((64,), "float32"), ((128,), "int64"), ((128,),
"float32")]
+
+ @I.ir_module
+ class Expected1D:
+ @R.function
+ def main(
+ data: R.Tensor((64,), dtype="float32"),
+ indices_0: R.Tensor((128,), dtype="int64"),
+ values: R.Tensor((128,), dtype="float32"),
+ ) -> R.Tensor((64,), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((64,), dtype="float32") = R.index_put(
+ data, R.tuple(indices_0), values, accumulate=False
+ )
+ gv: R.Tensor((64,), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ # Test case 2: 2D input
+ class IndexPut2D(Module):
+ def forward(self, data, indices_0, indices_1, values):
+ indices_tuple = (indices_0, indices_1)
+ return data.index_put_(indices_tuple, values, accumulate=False)
+
+ input_info_2d = [
+ ((32, 64), "float32"),
+ ((128,), "int64"),
+ ((128,), "int64"),
+ ((128,), "float32"),
+ ]
+
+ @I.ir_module
+ class Expected2D:
+ @R.function
+ def main(
+ data: R.Tensor((32, 64), dtype="float32"),
+ indices_0: R.Tensor((128,), dtype="int64"),
+ indices_1: R.Tensor((128,), dtype="int64"),
+ values: R.Tensor((128,), dtype="float32"),
+ ) -> R.Tensor((32, 64), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((32, 64), dtype="float32") = R.index_put(
+ data, R.tuple(indices_0, indices_1), values,
accumulate=False
+ )
+ gv: R.Tensor((32, 64), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ # Test case 3: 3D input
+ class IndexPut3D(Module):
+ def forward(self, data, indices_0, indices_1, indices_2, values):
+ indices_tuple = (indices_0, indices_1, indices_2)
+ return data.index_put_(indices_tuple, values, accumulate=False)
+
+ input_info_3d = [
+ ((16, 32, 64), "float32"),
+ ((128,), "int64"),
+ ((128,), "int64"),
+ ((128,), "int64"),
+ ((128,), "float32"),
+ ]
+
+ @I.ir_module
+ class Expected3D:
+ @R.function
+ def main(
+ data: R.Tensor((16, 32, 64), dtype="float32"),
+ indices_0: R.Tensor((128,), dtype="int64"),
+ indices_1: R.Tensor((128,), dtype="int64"),
+ indices_2: R.Tensor((128,), dtype="int64"),
+ values: R.Tensor((128,), dtype="float32"),
+ ) -> R.Tensor((16, 32, 64), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((16, 32, 64), dtype="float32") = R.index_put(
+ data, R.tuple(indices_0, indices_1, indices_2), values,
accumulate=False
+ )
+ gv: R.Tensor((16, 32, 64), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ # Test case 4: 4D input
+ class IndexPut4D(Module):
+ def forward(self, data, indices_0, indices_1, indices_2, indices_3,
values):
+ indices_tuple = (indices_0, indices_1, indices_2, indices_3)
+ return data.index_put_(indices_tuple, values, accumulate=False)
+
+ input_info_4d = [
+ ((8, 16, 32, 64), "float32"),
+ ((128,), "int64"),
+ ((128,), "int64"),
+ ((128,), "int64"),
+ ((128,), "int64"),
+ ((128,), "float32"),
+ ]
+
+ @I.ir_module
+ class Expected4D:
+ @R.function
+ def main(
+ data: R.Tensor((8, 16, 32, 64), dtype="float32"),
+ indices_0: R.Tensor((128,), dtype="int64"),
+ indices_1: R.Tensor((128,), dtype="int64"),
+ indices_2: R.Tensor((128,), dtype="int64"),
+ indices_3: R.Tensor((128,), dtype="int64"),
+ values: R.Tensor((128,), dtype="float32"),
+ ) -> R.Tensor((8, 16, 32, 64), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((8, 16, 32, 64), dtype="float32") = R.index_put(
+ data,
+ R.tuple(indices_0, indices_1, indices_2, indices_3),
+ values,
+ accumulate=False,
+ )
+ gv: R.Tensor((8, 16, 32, 64), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ # Test case 5: 5D input
+ class IndexPut5D(Module):
+ def forward(self, data, indices_0, indices_1, indices_2, indices_3,
indices_4, values):
+ indices_tuple = (indices_0, indices_1, indices_2, indices_3,
indices_4)
+ return data.index_put_(indices_tuple, values, accumulate=False)
+
+ input_info_5d = [
+ ((4, 8, 16, 32, 64), "float32"),
+ ((128,), "int64"),
+ ((128,), "int64"),
+ ((128,), "int64"),
+ ((128,), "int64"),
+ ((128,), "int64"),
+ ((128,), "float32"),
+ ]
+
+ @I.ir_module
+ class Expected5D:
+ @R.function
+ def main(
+ data: R.Tensor((4, 8, 16, 32, 64), dtype="float32"),
+ indices_0: R.Tensor((128,), dtype="int64"),
+ indices_1: R.Tensor((128,), dtype="int64"),
+ indices_2: R.Tensor((128,), dtype="int64"),
+ indices_3: R.Tensor((128,), dtype="int64"),
+ indices_4: R.Tensor((128,), dtype="int64"),
+ values: R.Tensor((128,), dtype="float32"),
+ ) -> R.Tensor((4, 8, 16, 32, 64), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") =
R.index_put(
+ data,
+ R.tuple(indices_0, indices_1, indices_2, indices_3,
indices_4),
+ values,
+ accumulate=False,
+ )
+ gv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ # Run verification for each case
+ verify_model(IndexPut1D(), input_info_1d, {}, Expected1D)
+ verify_model(IndexPut2D(), input_info_2d, {}, Expected2D)
+ verify_model(IndexPut3D(), input_info_3d, {}, Expected3D)
+ verify_model(IndexPut4D(), input_info_4d, {}, Expected4D)
+ verify_model(IndexPut5D(), input_info_5d, {}, Expected5D)
+
+
def test_flip():
class Flip0(Module):
def forward(self, data):