This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0149fcd3c6 Add support for index_put_ op (#17865)
0149fcd3c6 is described below

commit 0149fcd3c6e10226af03359fdfa4cc13ec24b73e
Author: Pratheesh-04-MCW <[email protected]>
AuthorDate: Thu Apr 24 21:37:23 2025 +0530

    Add support for index_put_ op (#17865)
    
    * add support for index_put_ op
    
    * add test cases from 1D to 5D in both exported program and fx graph
    
    * lint issues
    
    * lint check
    
    * lint issue
    
    * whitespace issue
    
    * whitespace issue
    
    * whitespace issue
    
    * line length error
    
    * trailing space issue
    
    * modified conditions for parameters
    
    * modified base_fx
    
    * resolved conflicts
    
    * removed trailing whitespace
    
    * lint issue
    
    * unity issue
---
 include/tvm/relax/attrs/manipulate.h               |  14 ++
 .../frontend/torch/base_fx_graph_translator.py     |  17 ++
 .../frontend/torch/exported_program_translator.py  |   1 +
 python/tvm/relax/frontend/torch/fx_translator.py   |   1 +
 python/tvm/relax/op/__init__.py                    |   1 +
 python/tvm/relax/op/manipulate.py                  |  51 ++++++
 python/tvm/relax/op/op_attrs.py                    |   5 +
 .../tvm/relax/transform/legalize_ops/manipulate.py |  22 +++
 python/tvm/script/ir_builder/relax/ir.py           |   2 +
 python/tvm/topi/__init__.py                        |   1 +
 python/tvm/topi/index_put.py                       | 117 ++++++++++++++
 src/relax/op/tensor/manipulate.cc                  | 123 +++++++++++++++
 src/relax/op/tensor/manipulate.h                   |  13 ++
 .../relax/test_frontend_from_exported_program.py   | 175 +++++++++++++++++++++
 tests/python/relax/test_frontend_from_fx.py        | 171 ++++++++++++++++++++
 15 files changed, 714 insertions(+)

diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index 67f99d9b41..943d2f4d0d 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -182,6 +182,20 @@ struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
   }
 };  // struct GatherNDAttrs
 
+/*! \brief Attributes used in index_put operator */
+struct IndexPutAttrs : public tvm::AttrsNode<IndexPutAttrs> {
+  bool accumulate;
+
+  TVM_DECLARE_ATTRS(IndexPutAttrs, "relax.attrs.IndexPutAttrs") {
+    TVM_ATTR_FIELD(accumulate)
+        .set_default(false)
+        .describe(
+            "Whether to accumulate (add) values rather than replace. "
+            "If true, performs tensor[indices] += values, "
+            "otherwise performs tensor[indices] = values.");
+  }
+};  // struct IndexPutAttrs
+
 /*! \brief Attributes used in scatter_elements operators */
 struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
   Integer axis;
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 33f6ffc313..5dd78be483 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1118,6 +1118,23 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         index = self.env[node.args[2]]
         return self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim))
 
+    def _index_put(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        tensor = args[0]
+        indices = args[1] if len(args) > 1 else node.kwargs.get("indices")
+        values = args[2] if len(args) > 2 else node.kwargs.get("values")
+        accumulate = args[3] if len(args) > 3 else node.kwargs.get("accumulate", False)
+
+        if indices is None or values is None:
+            raise ValueError("'indices and values' arguments are required for index_put operation")
+
+        if not isinstance(accumulate, bool):
+            raise TypeError("'accumulate' must be a boolean value, got {}".format(type(accumulate)))
+
+        if isinstance(indices, (list, tuple)):
+            indices = relax.Tuple(indices)
+        return self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate))
+
     def _index_tensor(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         indices = args[1]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index aa4984c0ba..db5ca01399 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -433,6 +433,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "flip.default": self._flip,
             "gather.default": self._gather,
             "index.Tensor": self._index_tensor,
+            "index_put_.default": self._index_put,
             "narrow.default": self._narrow,
             "permute.default": self._permute,
             "repeat.default": self._repeat,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index bab272bd78..18dba2d988 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -801,6 +801,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "flatten": self._flatten,
             "flip": self._flip,
             "gather": self._gather,
+            "index_put_": self._index_put,
             "narrow": self._narrow,
             "numel": self._numel,
             "permute": self._permute,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 097313a33d..7b8c34b641 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -95,6 +95,7 @@ from .manipulate import (
     flip,
     gather_elements,
     gather_nd,
+    index_put,
     index_tensor,
     layout_transform,
     one_hot,
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index a693adf432..13334d1479 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -595,6 +595,57 @@ def index_tensor(data: Expr, indices: Union[Expr, List[Expr]]) -> Expr:
     return _ffi_api.index_tensor(data, indices)  # type: ignore
 
 
+def index_put(
+    data: Expr,
+    indices: Union[Expr, Tuple[Expr]],
+    values: Expr,
+    accumulate: bool = False,
+) -> Expr:
+    """Update `data` at the positions given by `indices` with the entries of `values`.
+
+    `indices` holds one index tensor per dimension of `data`. When `accumulate` is True,
+    entries of `values` are added to the existing elements instead of replacing them.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input tensor to be modified
+    indices : Union[Expr, Tuple[Expr]]
+        Tuple of index tensors (one for each dimension) specifying positions to update
+    values : relax.Expr
+        Values to place at the specified indices
+    accumulate : bool
+        Whether to accumulate (add) values rather than replace (default: False)
+
+    Returns
+    -------
+    result : relax.Expr
+        A new tensor with the same shape as data but with specified positions updated
+
+    Examples
+    --------
+    .. code-block:: python
+
+        # inputs
+        data = torch.zeros(3, 3)
+        indices = (torch.tensor([0, 2]), torch.tensor([1, 1]))
+        values = torch.tensor([1.0, 2.0])
+        # output
+        output = [[0.0, 1.0, 0.0],
+                  [0.0, 0.0, 0.0],
+                  [0.0, 2.0, 0.0]]
+        # with accumulate=True, contributions to a repeated position are summed
+        indices = (torch.tensor([0, 2, 2]), torch.tensor([1, 1, 1]))
+        values = torch.tensor([1.0, 2.0, 1.0])
+        output = [[0.0, 1.0, 0.0],
+                  [0.0, 0.0, 0.0],
+                  [0.0, 3.0, 0.0]]
+    """
+    if isinstance(indices, (list, tuple)):
+        indices = RxTuple(indices)
+    return _ffi_api.index_put(data, indices, values, accumulate)  # type: ignore
+
+
 def scatter_elements(
     data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update"
 ):
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index fda4258a09..fe527e38e8 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -144,6 +144,11 @@ class StackAttrs(Attrs):
     """Attributes for concat operator"""
 
 
+@tvm._ffi.register_object("relax.attrs.IndexPutAttrs")
+class IndexPutAttrs(Attrs):
+    """Attributes used in index_put operator (``accumulate``: add instead of overwrite)"""
+
+
 @tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs")
 class LayoutTransformAttrs(Attrs):
     """Attributes used in layout_transform operator"""
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 84baa887d9..a66b60c013 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -193,6 +193,28 @@ def _index_tensor(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(topi.index_tensor, call.args[0], fields)
 
 
+@register_legalize("relax.index_put")
+def _index_put(bb: BlockBuilder, call: Call) -> Expr:
+    data = call.args[0]
+    indices = call.args[1]
+    values = call.args[2]
+    accumulate = call.attrs.accumulate
+
+    # If indices is a Tuple, unpack it into individual tensors
+    if isinstance(indices, relax.Tuple):
+        indices_list = list(indices.fields)
+    else:
+        indices_list = [indices]
+
+    return bb.call_te(
+        topi.index_put,
+        data,
+        indices_list,
+        values,
+        accumulate=accumulate,
+    )
+
+
 @register_legalize("relax.scatter_elements")
 def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 22b00cd704..d2952ed8e0 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -100,6 +100,7 @@ from tvm.relax.op import (
     greater,
     greater_equal,
     hint_on_device,
     image,
+    index_put,
     index_tensor,
     invoke_closure,
@@ -785,6 +786,7 @@ __all__ = [
     "greater_equal",
     "hexagon",
     "hint_on_device",
     "image",
+    "index_put",
     "index_tensor",
     "invoke_closure",
diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py
index 1de6941c99..fa4e98a89a 100644
--- a/python/tvm/topi/__init__.py
+++ b/python/tvm/topi/__init__.py
@@ -33,6 +33,7 @@ from . import cpp
 from .math import *
 from .tensor import *
 from .generic_op_impl import *
+from .index_put import *
 from .reduction import *
 from .transform import *
 from .broadcast import *
diff --git a/python/tvm/topi/index_put.py b/python/tvm/topi/index_put.py
new file mode 100644
index 0000000000..f51c6718ab
--- /dev/null
+++ b/python/tvm/topi/index_put.py
@@ -0,0 +1,117 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""IndexPut operator"""
+from tvm import te
+from tvm import tir
+from . import utils
+
+
+def index_put(data, indices, values, accumulate=False):
+    """Put values into an array according to indices.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The source array to be modified.
+
+    indices : Tuple[tvm.te.Tensor]
+        Tuple of 1D index tensors (one for each dimension) specifying positions.
+
+    values : tvm.te.Tensor
+        The values to place at the specified indices.
+
+    accumulate : bool, optional
+        Whether to accumulate (add) values rather than replace. Defaults to False.
+        If True, performs tensor[indices] += values, otherwise tensor[indices] = values.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        A copy of `data` in which the indexed positions have been updated.
+    """
+    if not isinstance(indices, (list, tuple)):
+        indices = [indices]
+
+    # Check indices match data dimensions
+    if len(indices) != len(data.shape):
+        raise ValueError(
+            f"Number of index tensors ({len(indices)}) must match "
+            f"data dimensions ({len(data.shape)})"
+        )
+
+    # Total number of elements, used to copy `data` into the output buffer
+    shape = data.shape
+    full_range = 1
+    for dim in shape:
+        full_range *= dim
+
+    # All index tensors must have the same (possibly symbolic) length
+    index_len = indices[0].shape[0]
+    for idx in indices[1:]:
+        if not utils.equal_const_int(idx.shape[0] - index_len, 0):
+            raise ValueError("All index tensors must have same length")
+
+    # Parallel "+=" on repeated positions would race, so accumulation is scattered serially
+    scatter_kind = "serial" if accumulate else "parallel"
+
+    def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func):
+        ir_builder = tir.ir_builder.create()
+
+        data = ir_builder.buffer_ptr(data_ptr)
+        indices = [ir_builder.buffer_ptr(idx) for idx in index_ptrs]
+        values = ir_builder.buffer_ptr(values_ptr)
+        out = ir_builder.buffer_ptr(out_ptr)
+
+        with ir_builder.for_range(0, full_range, "i", kind="parallel") as i:
+            out[i] = data[i]
+
+        with ir_builder.for_range(0, index_len, "k", kind=scatter_kind) as k:
+            # Calculate multi-dimensional index
+            flat_index = 0
+            stride = 1
+            for dim in range(len(shape) - 1, -1, -1):
+                # Get index and shift to positive if needed
+                idx_val = indices[dim][k]
+                shifted_idx = idx_val + (idx_val < 0) * shape[dim]
+                flat_index += shifted_idx * stride
+                stride *= shape[dim]
+
+            reduce_func(out, flat_index, values[k])
+
+        return ir_builder.get()
+
+    def update_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] = update
+
+    def add_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] += update
+
+    reduce_func = add_func if accumulate else update_func
+
+    # Inputs are ordered (data, *indices, values); gen_ir unpacks them in that order
+    in_buffers = [data, *indices, values]
+
+    out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf")
+    return te.extern(
+        [data.shape],
+        in_buffers,
+        lambda ins, outs: gen_ir(ins[0], ins[1:-1], ins[-1], outs[0], reduce_func),
+        dtype=data.dtype,
+        out_buffers=[out_buf],
+        name="index_put.generic",
+        tag="index_put.generic",
+    )
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index f56135a35b..482ebe5cac 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -1972,6 +1972,129 @@ TVM_REGISTER_OP("relax.gather_nd")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGatherND)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.index_put */
+TVM_REGISTER_NODE_TYPE(IndexPutAttrs);
+
+Expr index_put(Expr data, Expr indices, Expr values, bool accumulate) {
+  auto attrs = make_object<IndexPutAttrs>();
+  attrs->accumulate = accumulate;
+  static const Op& op = Op::Get("relax.index_put");
+  return Call(op, {data, indices, values}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.index_put").set_body_typed(index_put);
+
+StructInfo InferStructInfoIndexPut(const Call& call, const BlockBuilder& ctx) {
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* values_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
+
+  auto diag_def = [&](const TensorStructInfoNode* sinfo, String name, String type_key) {
+    if (sinfo == nullptr) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "IndexPut requires the input " << name
+                       << " to be a Tensor. However, the given one is " << type_key);
+    }
+  };
+
+  diag_def(data_sinfo, "data", call->args[0]->struct_info_->GetTypeKey());
+  diag_def(values_sinfo, "values", call->args[2]->struct_info_->GetTypeKey());
+
+  // Handle indices: either a single tensor or a tuple of tensors
+  Array<TensorStructInfo> indices_tensors;
+
+  if (const auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(call->args[1])) {
+    // Indices is a tuple of tensors
+    for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) {
+      const auto* tensor_sinfo = tuple_sinfo->fields[i].as<TensorStructInfoNode>();
+      if (tensor_sinfo == nullptr) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << "IndexPut requires each index in the indices tuple to be a Tensor. "
+                         << "However, element " << i << " is "
+                         << tuple_sinfo->fields[i]->GetTypeKey());
+      }
+      indices_tensors.push_back(GetRef<TensorStructInfo>(tensor_sinfo));
+    }
+  } else if (const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1])) {
+    // Indices is a single tensor
+    indices_tensors.push_back(GetRef<TensorStructInfo>(tensor_sinfo));
+  } else {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "IndexPut requires indices to be a Tensor or a tuple of Tensors. "
+                     << "However, the given one is " << call->args[1]->struct_info_->GetTypeKey());
+  }
+
+  if (data_sinfo->IsUnknownNdim()) {
+    return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice);
+  }
+
+  // Validate each index tensor
+  for (size_t i = 0; i < indices_tensors.size(); ++i) {
+    const auto& tensor_sinfo = indices_tensors[i];
+    if (!tensor_sinfo->IsUnknownNdim() && tensor_sinfo->ndim != 1) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "IndexPut requires each index tensor to be 1D. "
+                       << "However, index tensor " << i << " has ndim=" << tensor_sinfo->ndim);
+    }
+    if (tensor_sinfo->IsUnknownDtype()) {
+      LOG(WARNING) << "Data type of index tensor " << i
+                   << " has not been specified. Assume it has an integer type.";
+    } else if (!(tensor_sinfo->dtype.is_int() || tensor_sinfo->dtype.is_uint())) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "IndexPut requires each index tensor to have integer dtype. "
+                       << "However, index tensor " << i << " has dtype=" << tensor_sinfo->dtype);
+    }
+  }
+
+  // Check that the number of index tensors matches data dimensions
+  if (!data_sinfo->IsUnknownNdim() &&
+      indices_tensors.size() != static_cast<size_t>(data_sinfo->ndim)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "IndexPut requires the number of index tensors (" << indices_tensors.size()
+                     << ") to match the data tensor dimensions (" << data_sinfo->ndim << ")");
+  }
+
+  // Check data and values dtype compatibility (skipped when either dtype is unknown)
+  if (data_sinfo->IsUnknownDtype() || values_sinfo->IsUnknownDtype()) {
+    auto diag_dtype = [&](const TensorStructInfoNode* sinfo, String name) {
+      if (sinfo->IsUnknownDtype()) {
+        LOG(WARNING) << "Data type of " << name
+                     << " has not been specified. Skipping the data/values dtype check.";
+      }
+    };
+    diag_dtype(data_sinfo, "data");
+    diag_dtype(values_sinfo, "values");
+  } else if (data_sinfo->dtype != values_sinfo->dtype) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "IndexPut requires the input data to have the same type as values. "
+                     << "However, the given types are data: " << data_sinfo->dtype
+                     << ", values: " << values_sinfo->dtype);
+  }
+
+  // Check values shape compatibility
+  const auto* values_shape = values_sinfo->shape.as<ShapeExprNode>();
+  if (values_shape) {
+    if (values_sinfo->ndim != 1) {
+      LOG(WARNING) << "IndexPut typically expects values to be 1D, but got ndim="
+                   << values_sinfo->ndim;
+    }
+  }
+
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  if (data_shape) {
+    return TensorStructInfo(ShapeExpr(data_shape->values), data_sinfo->dtype, data_sinfo->vdevice);
+  }
+  return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.index_put")
+    .set_attrs_type<IndexPutAttrs>()
+    .set_num_inputs(3)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("indices", "Tensor", "The indices tensor(s).")
+    .add_argument("values", "Tensor", "The values to put.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoIndexPut)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.scatter_elements */
 TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs);
 
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 4580f9191b..2e4c92c150 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -218,6 +218,19 @@ Expr gather_nd(Expr data, Expr indices, int batch_dims = 0);
  */
 Expr index_tensor(Expr data, Expr indices);
 
+/*!
+ * \brief Put values into an array according to indices.
+ * \param data The input tensor to be modified.
+ * \param indices The index positions where values should be placed.
+ *                This should be a tuple of 1D tensors (one for each dimension).
+ * \param values The values to place at the specified indices.
+ * \param accumulate Whether to accumulate (add) values rather than replace.
+ *                  If true, equivalent to tensor[indices] += values.
+ *                  If false, equivalent to tensor[indices] = values.
+ * \return The computed result with values placed at specified indices.
+ */
+Expr index_put(Expr data, Expr indices, Expr values, bool accumulate = false);
+
 /*!
  * \brief Scatter updates into an array according to indices.
  * \param data The input tensor.
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 5ef2c27e91..589d3f5bae 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4169,6 +4169,181 @@ def test_gather():
     verify_model(Gather3(), example_args, {}, Expected3)
 
 
+def test_index_put():
+    # Test case 1: 1D input
+    class IndexPut1D(Module):
+        def forward(self, data, indices_0, values):
+            indices_tuple = (indices_0,)
+            return data.index_put_(indices_tuple, values, accumulate=False)
+
+    example_args_1d = (
+        torch.randn(64, dtype=torch.float32),
+        torch.randint(0, 64, (128,), dtype=torch.int64),
+        torch.randn(128, dtype=torch.float32),
+    )
+
+    @I.ir_module
+    class Expected1D:
+        @R.function
+        def main(
+            data: R.Tensor((64,), dtype="float32"),
+            indices_0: R.Tensor((128,), dtype="int64"),
+            values: R.Tensor((128,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((64,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((64,), dtype="float32") = R.index_put(
+                    data, R.tuple(indices_0), values, accumulate=False
+                )
+                gv: R.Tuple(R.Tensor((64,), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    # Test case 2: 2D input
+    class IndexPut2D(Module):
+        def forward(self, data, indices_0, indices_1, values):
+            indices_tuple = (indices_0, indices_1)
+            return data.index_put_(indices_tuple, values, accumulate=False)
+
+    example_args_2d = (
+        torch.randn(32, 64, dtype=torch.float32),
+        torch.randint(0, 32, (128,), dtype=torch.int64),
+        torch.randint(0, 64, (128,), dtype=torch.int64),
+        torch.randn(128, dtype=torch.float32),
+    )
+
+    @I.ir_module
+    class Expected2D:
+        @R.function
+        def main(
+            data: R.Tensor((32, 64), dtype="float32"),
+            indices_0: R.Tensor((128,), dtype="int64"),
+            indices_1: R.Tensor((128,), dtype="int64"),
+            values: R.Tensor((128,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((32, 64), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((32, 64), dtype="float32") = R.index_put(
+                    data, R.tuple(indices_0, indices_1), values, accumulate=False
+                )
+                gv: R.Tuple(R.Tensor((32, 64), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    # Test case 3: 3D input
+    class IndexPut3D(Module):
+        def forward(self, data, indices_0, indices_1, indices_2, values):
+            indices_tuple = (indices_0, indices_1, indices_2)
+            return data.index_put_(indices_tuple, values, accumulate=False)
+
+    example_args_3d = (
+        torch.randn(16, 32, 64, dtype=torch.float32),
+        torch.randint(0, 16, (128,), dtype=torch.int64),
+        torch.randint(0, 32, (128,), dtype=torch.int64),
+        torch.randint(0, 64, (128,), dtype=torch.int64),
+        torch.randn(128, dtype=torch.float32),
+    )
+
+    @I.ir_module
+    class Expected3D:
+        @R.function
+        def main(
+            data: R.Tensor((16, 32, 64), dtype="float32"),
+            indices_0: R.Tensor((128,), dtype="int64"),
+            indices_1: R.Tensor((128,), dtype="int64"),
+            indices_2: R.Tensor((128,), dtype="int64"),
+            values: R.Tensor((128,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((16, 32, 64), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((16, 32, 64), dtype="float32") = R.index_put(
+                    data, R.tuple(indices_0, indices_1, indices_2), values, accumulate=False
+                )
+                gv: R.Tuple(R.Tensor((16, 32, 64), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    # Test case 4: 4D input
+    class IndexPut4D(Module):
+        def forward(self, data, indices_0, indices_1, indices_2, indices_3, values):
+            indices_tuple = (indices_0, indices_1, indices_2, indices_3)
+            return data.index_put_(indices_tuple, values, accumulate=False)
+
+    example_args_4d = (
+        torch.randn(8, 16, 32, 64, dtype=torch.float32),
+        torch.randint(0, 8, (128,), dtype=torch.int64),
+        torch.randint(0, 16, (128,), dtype=torch.int64),
+        torch.randint(0, 32, (128,), dtype=torch.int64),
+        torch.randint(0, 64, (128,), dtype=torch.int64),
+        torch.randn(128, dtype=torch.float32),
+    )
+
+    @I.ir_module
+    class Expected4D:
+        @R.function
+        def main(
+            data: R.Tensor((8, 16, 32, 64), dtype="float32"),
+            indices_0: R.Tensor((128,), dtype="int64"),
+            indices_1: R.Tensor((128,), dtype="int64"),
+            indices_2: R.Tensor((128,), dtype="int64"),
+            indices_3: R.Tensor((128,), dtype="int64"),
+            values: R.Tensor((128,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((8, 16, 32, 64), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((8, 16, 32, 64), dtype="float32") = R.index_put(
+                    data,
+                    R.tuple(indices_0, indices_1, indices_2, indices_3),
+                    values,
+                    accumulate=False,
+                )
+                gv: R.Tuple(R.Tensor((8, 16, 32, 64), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    # Test case 5: 5D input
+    class IndexPut5D(Module):
+        def forward(self, data, indices_0, indices_1, indices_2, indices_3, indices_4, values):
+            indices_tuple = (indices_0, indices_1, indices_2, indices_3, indices_4)
+            return data.index_put_(indices_tuple, values, accumulate=False)
+
+    example_args_5d = (
+        torch.randn(4, 8, 16, 32, 64, dtype=torch.float32),
+        torch.randint(0, 4, (128,), dtype=torch.int64),
+        torch.randint(0, 8, (128,), dtype=torch.int64),
+        torch.randint(0, 16, (128,), dtype=torch.int64),
+        torch.randint(0, 32, (128,), dtype=torch.int64),
+        torch.randint(0, 64, (128,), dtype=torch.int64),
+        torch.randn(128, dtype=torch.float32),
+    )
+
+    @I.ir_module
+    class Expected5D:
+        @R.function
+        def main(
+            data: R.Tensor((4, 8, 16, 32, 64), dtype="float32"),
+            indices_0: R.Tensor((128,), dtype="int64"),
+            indices_1: R.Tensor((128,), dtype="int64"),
+            indices_2: R.Tensor((128,), dtype="int64"),
+            indices_3: R.Tensor((128,), dtype="int64"),
+            indices_4: R.Tensor((128,), dtype="int64"),
+            values: R.Tensor((128,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((4, 8, 16, 32, 64), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") = R.index_put(
+                    data,
+                    R.tuple(indices_0, indices_1, indices_2, indices_3, indices_4),
+                    values,
+                    accumulate=False,
+                )
+                gv: R.Tuple(R.Tensor((4, 8, 16, 32, 64), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    # Run verification for each case
+    verify_model(IndexPut1D(), example_args_1d, {}, Expected1D)
+    verify_model(IndexPut2D(), example_args_2d, {}, Expected2D)
+    verify_model(IndexPut3D(), example_args_3d, {}, Expected3D)
+    verify_model(IndexPut4D(), example_args_4d, {}, Expected4D)
+    verify_model(IndexPut5D(), example_args_5d, {}, Expected5D)
+
+
 def test_flip():
     class Flip0(Module):
         def forward(self, data):
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index c5b95f6c39..4003202d4f 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4404,6 +4404,177 @@ def test_gather():
     verify_model(Gather3(), [([2, 3], "float32"), ([2, 3], "int32")], {}, 
Expected3)
 
 
+def test_index_put():
+    # Test case 1: 1D input
+    class IndexPut1D(Module):
+        def forward(self, data, indices_0, values):
+            indices_tuple = (indices_0,)
+            return data.index_put_(indices_tuple, values, accumulate=False)
+
+    input_info_1d = [((64,), "float32"), ((128,), "int64"), ((128,), "float32")]
+
+    @I.ir_module
+    class Expected1D:
+        @R.function
+        def main(
+            data: R.Tensor((64,), dtype="float32"),
+            indices_0: R.Tensor((128,), dtype="int64"),
+            values: R.Tensor((128,), dtype="float32"),
+        ) -> R.Tensor((64,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((64,), dtype="float32") = R.index_put(
+                    data, R.tuple(indices_0), values, accumulate=False
+                )
+                gv: R.Tensor((64,), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    # Test case 2: 2D input
+    class IndexPut2D(Module):
+        def forward(self, data, indices_0, indices_1, values):
+            indices_tuple = (indices_0, indices_1)
+            return data.index_put_(indices_tuple, values, accumulate=False)
+
+    input_info_2d = [
+        ((32, 64), "float32"),
+        ((128,), "int64"),
+        ((128,), "int64"),
+        ((128,), "float32"),
+    ]
+
+    @I.ir_module
+    class Expected2D:
+        @R.function
+        def main(
+            data: R.Tensor((32, 64), dtype="float32"),
+            indices_0: R.Tensor((128,), dtype="int64"),
+            indices_1: R.Tensor((128,), dtype="int64"),
+            values: R.Tensor((128,), dtype="float32"),
+        ) -> R.Tensor((32, 64), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((32, 64), dtype="float32") = R.index_put(
+                    data, R.tuple(indices_0, indices_1), values, accumulate=False
+                )
+                gv: R.Tensor((32, 64), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    # Test case 3: 3D input
+    class IndexPut3D(Module):
+        def forward(self, data, indices_0, indices_1, indices_2, values):
+            indices_tuple = (indices_0, indices_1, indices_2)
+            return data.index_put_(indices_tuple, values, accumulate=False)
+
+    input_info_3d = [
+        ((16, 32, 64), "float32"),
+        ((128,), "int64"),
+        ((128,), "int64"),
+        ((128,), "int64"),
+        ((128,), "float32"),
+    ]
+
+    @I.ir_module
+    class Expected3D:
+        @R.function
+        def main(
+            data: R.Tensor((16, 32, 64), dtype="float32"),
+            indices_0: R.Tensor((128,), dtype="int64"),
+            indices_1: R.Tensor((128,), dtype="int64"),
+            indices_2: R.Tensor((128,), dtype="int64"),
+            values: R.Tensor((128,), dtype="float32"),
+        ) -> R.Tensor((16, 32, 64), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((16, 32, 64), dtype="float32") = R.index_put(
+                    data, R.tuple(indices_0, indices_1, indices_2), values, accumulate=False
+                )
+                gv: R.Tensor((16, 32, 64), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    # Test case 4: 4D input
+    class IndexPut4D(Module):
+        def forward(self, data, indices_0, indices_1, indices_2, indices_3, values):
+            indices_tuple = (indices_0, indices_1, indices_2, indices_3)
+            return data.index_put_(indices_tuple, values, accumulate=False)
+
+    input_info_4d = [
+        ((8, 16, 32, 64), "float32"),
+        ((128,), "int64"),
+        ((128,), "int64"),
+        ((128,), "int64"),
+        ((128,), "int64"),
+        ((128,), "float32"),
+    ]
+
+    @I.ir_module
+    class Expected4D:
+        @R.function
+        def main(
+            data: R.Tensor((8, 16, 32, 64), dtype="float32"),
+            indices_0: R.Tensor((128,), dtype="int64"),
+            indices_1: R.Tensor((128,), dtype="int64"),
+            indices_2: R.Tensor((128,), dtype="int64"),
+            indices_3: R.Tensor((128,), dtype="int64"),
+            values: R.Tensor((128,), dtype="float32"),
+        ) -> R.Tensor((8, 16, 32, 64), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((8, 16, 32, 64), dtype="float32") = R.index_put(
+                    data,
+                    R.tuple(indices_0, indices_1, indices_2, indices_3),
+                    values,
+                    accumulate=False,
+                )
+                gv: R.Tensor((8, 16, 32, 64), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    # Test case 5: 5D input
+    class IndexPut5D(Module):
+        def forward(self, data, indices_0, indices_1, indices_2, indices_3, indices_4, values):
+            indices_tuple = (indices_0, indices_1, indices_2, indices_3, indices_4)
+            return data.index_put_(indices_tuple, values, accumulate=False)
+
+    input_info_5d = [
+        ((4, 8, 16, 32, 64), "float32"),
+        ((128,), "int64"),
+        ((128,), "int64"),
+        ((128,), "int64"),
+        ((128,), "int64"),
+        ((128,), "int64"),
+        ((128,), "float32"),
+    ]
+
+    @I.ir_module
+    class Expected5D:
+        @R.function
+        def main(
+            data: R.Tensor((4, 8, 16, 32, 64), dtype="float32"),
+            indices_0: R.Tensor((128,), dtype="int64"),
+            indices_1: R.Tensor((128,), dtype="int64"),
+            indices_2: R.Tensor((128,), dtype="int64"),
+            indices_3: R.Tensor((128,), dtype="int64"),
+            indices_4: R.Tensor((128,), dtype="int64"),
+            values: R.Tensor((128,), dtype="float32"),
+        ) -> R.Tensor((4, 8, 16, 32, 64), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") = R.index_put(
+                    data,
+                    R.tuple(indices_0, indices_1, indices_2, indices_3, indices_4),
+                    values,
+                    accumulate=False,
+                )
+                gv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    # Run verification for each case
+    verify_model(IndexPut1D(), input_info_1d, {}, Expected1D)
+    verify_model(IndexPut2D(), input_info_2d, {}, Expected2D)
+    verify_model(IndexPut3D(), input_info_3d, {}, Expected3D)
+    verify_model(IndexPut4D(), input_info_4d, {}, Expected4D)
+    verify_model(IndexPut5D(), input_info_5d, {}, Expected5D)
+
+
 def test_flip():
     class Flip0(Module):
         def forward(self, data):


Reply via email to