This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b5b0337568 [Relax][PyTorch] support for index.Tensor (#17836)
b5b0337568 is described below

commit b5b0337568f4aa6912c9c48f9e03027fc140e9af
Author: Hugo Latendresse <[email protected]>
AuthorDate: Mon Apr 21 10:59:51 2025 -0400

    [Relax][PyTorch] support for index.Tensor (#17836)
    
    New op for advanced indexing + unit tests
---
 .../frontend/torch/base_fx_graph_translator.py     |   5 +
 .../frontend/torch/exported_program_translator.py  |   1 +
 python/tvm/relax/op/__init__.py                    |   1 +
 python/tvm/relax/op/manipulate.py                  |  63 +++++++++
 .../tvm/relax/transform/legalize_ops/manipulate.py |   9 ++
 python/tvm/script/ir_builder/relax/ir.py           |   2 +
 python/tvm/topi/transform.py                       |  51 ++++++++
 src/relax/op/tensor/manipulate.cc                  | 145 +++++++++++++++++++++
 src/relax/op/tensor/manipulate.h                   |  12 ++
 tests/python/relax/test_from_exported_to_cuda.py   | 141 +++++++++++++++++++-
 10 files changed, 424 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 733a5d6b1a..13d13ff24c 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1148,6 +1148,11 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         index = self.env[node.args[2]]
         return self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim))
 
+    def _index_tensor(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        indices = args[1]
+        return self.block_builder.emit(relax.op.index_tensor(args[0], indices))
+
     def _permute(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index af1393329e..ab55ded36c 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -420,6 +420,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "flatten.using_ints": self._flatten,
             "flip.default": self._flip,
             "gather.default": self._gather,
+            "index.Tensor": self._index_tensor,
             "narrow.default": self._narrow,
             "permute.default": self._permute,
             "repeat.default": self._repeat,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 3145a7c292..097313a33d 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -95,6 +95,7 @@ from .manipulate import (
     flip,
     gather_elements,
     gather_nd,
+    index_tensor,
     layout_transform,
     one_hot,
     permute_dims,
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 725e58bd01..a693adf432 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -532,6 +532,69 @@ def gather_nd(data: Expr, indices: Expr, batch_dims: int = 0) -> Expr:
     return _ffi_api.gather_nd(data, indices, batch_dims)  # type: ignore
 
 
+def index_tensor(data: Expr, indices: Union[Expr, List[Expr]]) -> Expr:
+    """Advanced tensor indexing (NumPy/PyTorch-style).
+
+    Given k index tensors ``indices = (I0, I1, ..., I{k-1})`` this
+    operator selects elements from ``data`` as if one had written
+    ``data[I0, I1, ..., I{k-1}]`` in NumPy/PyTorch:
+
+    All index tensors must have an integer dtype.
+
+    Their shapes are broadcast together to a common shape ``B`` in
+    the usual NumPy way.
+
+    The result shape is ``B + data.shape[k:]`` (i.e. the broadcast
+    shape followed by the remaining axes of ``data`` that are *not*
+    indexed).
+
+    At compile time Relax checks that the number of index tensors
+    ``k`` does not exceed ``data.ndim``, that the dtypes are integer,
+    and that statically known index shapes are broadcast-compatible.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input tensor to be indexed.
+
+    indices : Union[relax.Expr, List[relax.Expr]]
+        A Tuple expression containing the index tensors,
+        or a Python ``list`` / ``tuple`` that will be promoted to a
+        tuple expression automatically. Each tensor must have an
+        integer dtype.
+
+    Returns
+    -------
+    result : relax.Expr
+        The tensor obtained after advanced indexing.  Its dtype equals
+        ``data.dtype``
+
+    Examples
+    --------
+    .. code-block:: python
+
+        import numpy as np
+        import tvm.relax as R
+
+        x   = R.const(np.arange(9).reshape(3, 3).astype("float32"))
+        row = R.const(np.array([0, 2]))        # shape (2,)
+        col = R.const(np.array([1, 0]))        # shape (2,)
+
+        y = R.index_tensor(x, [row, col])
+        # y.shape == (2,) ;  y == [1., 6.]
+
+        # Broadcasting: row : (2,1), col : (1,3)  →  B = (2,3)
+        row = R.const(np.array([[0],[1]]))
+        col = R.const(np.array([[0,1,2]]))
+        z = R.index_tensor(x, [row, col])
+        # z.shape == (2,3)
+
+    """
+    if isinstance(indices, (list, tuple)):
+        indices = RxTuple(indices)
+    return _ffi_api.index_tensor(data, indices)  # type: ignore
+
+
 def scatter_elements(
     data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update"
 ):
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index a481d7af95..84baa887d9 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -49,6 +49,7 @@ register_legalize(
     "relax.collapse_sum_like",
     _reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True),
 )
+
 register_legalize("relax.collapse_sum_to", _reshape(topi.collapse_sum, "collapse_sum"))
 
 
@@ -184,6 +185,14 @@ def _gather_nd(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(te_gather_nd, call.args[0], call.args[1], int(call.attrs.batch_dims))
 
 
+@register_legalize("relax.index_tensor")
+def _index_tensor(bb: BlockBuilder, call: Call) -> Expr:
+    t = call.args[1]
+    n_field = len(t.struct_info.fields)
+    fields = [bb.emit(TupleGetItem(t, i)) for i in range(n_field)]
+    return bb.call_te(topi.index_tensor, call.args[0], fields)
+
+
 @register_legalize("relax.scatter_elements")
 def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 79b1884aac..22b00cd704 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -101,6 +101,7 @@ from tvm.relax.op import (
     greater_equal,
     hint_on_device,
     image,
+    index_tensor,
     invoke_closure,
     invoke_pure_closure,
     isfinite,
@@ -785,6 +786,7 @@ __all__ = [
     "hexagon",
     "hint_on_device",
     "image",
+    "index_tensor",
     "invoke_closure",
     "invoke_pure_closure",
     "isfinite",
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index 37743e97a3..1ef6523059 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -1054,3 +1054,54 @@ def trilu(data, k, upper):
         return tvm.tir.Select(check_position, value, tvm.tir.const(0, data.dtype))
 
     return te.compute(data.shape, _apply_trilu, name="trilu", tag=topi.tag.ELEMWISE)
+
+
+def index_tensor(data, indices):
+    """Advanced tensor indexing (NumPy/PyTorch-style).
+
+    Given k index tensors ``indices = (I0, I1, ..., I{k-1})`` this
+    operator selects elements from ``data`` as if one had written
+    ``data[I0, I1, ..., I{k-1}]`` in NumPy/PyTorch:
+
+    * All index tensors must have an integer dtype.
+    * Their shapes are broadcast together to a common shape ``B`` in
+      the usual NumPy way.
+    * The result shape is ``B + data.shape[k:]`` (i.e. the broadcast
+      shape followed by the remaining axes of ``data`` that are *not*
+      indexed).
+    * ``k`` must not exceed ``data.ndim``; otherwise a compile-time
+      error is raised.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The tensor to be indexed.
+
+    indices : Sequence[tvm.te.Tensor]
+        A Python ``list`` / ``tuple`` of **k** index tensors,
+        or a `tvm.te.Tensor` tuple expression. Each tensor must have an
+        integer dtype.
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        The tensor obtained after advanced indexing.  Its dtype equals
+        ``data.dtype``
+
+    Examples
+    --------
+    .. code-block:: python
+
+        x     = te.placeholder((3, 3),  name="x")        # shape (3,3)
+        row   = te.placeholder((2,),    name="row", dtype="int32")
+        col   = te.placeholder((2,),    name="col", dtype="int32")
+
+        # Equivalent to x[row, col] in NumPy / PyTorch
+        y = topi.index_tensor(x, [row, col])             # shape (2,)
+
+        # Broadcasting example:
+        row = te.placeholder((2, 1), name="row", dtype="int32")
+        col = te.placeholder((1, 3), name="col", dtype="int32")
+        z = topi.index_tensor(x, [row, col])             # shape (2, 3)
+    """
+    return topi.adv_index(data, indices)
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 4abfe01387..f56135a35b 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -474,6 +474,151 @@ TVM_REGISTER_OP("relax.flatten")
     .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.index_tensor */
+
+Expr index_tensor(Expr first, Expr tensors) {
+  static const Op& op = Op::Get("relax.index_tensor");
+  return Call(op, {std::move(first), std::move(tensors)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.index_tensor").set_body_typed(index_tensor);
+
+StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "Index.Tensor op should have 2 arguments");
+  }
+
+  TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx);
+  Array<TensorStructInfo> indices_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]);
+
+  if (indices_sinfo.empty()) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "index_tensor expects a non‑empty tuple of index tensors");
+  }
+
+  DataType output_dtype = data_sinfo->dtype;
+  int n_indices = static_cast<int>(indices_sinfo.size());
+  Optional<VDevice> vdev = data_sinfo->vdevice;
+
+  // Indices must be integers
+  for (int i = 0; i < n_indices; ++i) {
+    const auto& s = indices_sinfo[i];
+    if (!s->IsUnknownDtype() && !s->dtype.is_int()) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "index_tensor requires every index tensor to have an integer dtype; "
+                       << "index " << i << " has dtype " << s->dtype);
+    }
+  }
+
+  // Count of indices must be less than or equal to data.ndim
+  if (!data_sinfo->IsUnknownNdim() && n_indices > data_sinfo->ndim) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "index_tensor received " << n_indices
+                     << " index tensors, but data has only " << data_sinfo->ndim << " dimensions");
+  }
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  bool all_index_have_shape_value = true;
+  std::vector<Array<PrimExpr>> index_shapes;
+  int max_index_ndim = 0;
+
+  for (const auto& s : indices_sinfo) {
+    const auto* shp = s->shape.as<ShapeExprNode>();
+    if (!shp) {
+      all_index_have_shape_value = false;
+    } else {
+      index_shapes.push_back(shp->values);
+      max_index_ndim = std::max(max_index_ndim, static_cast<int>(shp->values.size()));
+    }
+    if (!s->IsUnknownNdim()) {
+      max_index_ndim = std::max(max_index_ndim, s->ndim);
+    }
+  }
+
+  Optional<Array<PrimExpr>> broadcast_shape;
+  bool shape_unknown = !all_index_have_shape_value;
+
+  if (all_index_have_shape_value) {
+    // initialise broadcast result with ones
+    Array<PrimExpr> out_shape;
+    for (int i = 0; i < max_index_ndim; ++i) {
+      out_shape.push_back(IntImm(DataType::Int(64), 1));
+    }
+
+    for (const auto& ishape : index_shapes) {
+      int cur_ndim = ishape.size();
+      for (int axis = 0; axis < max_index_ndim; ++axis) {
+        int lhs_axis = max_index_ndim - 1 - axis;  // aligned from right
+        int rhs_axis = cur_ndim - 1 - axis;
+        if (rhs_axis < 0) break;  // shorter rank – done
+
+        PrimExpr lhs_dim = out_shape[lhs_axis];
+        PrimExpr rhs_dim = ishape[rhs_axis];
+
+        const auto* lhs_int = lhs_dim.as<IntImmNode>();
+        const auto* rhs_int = rhs_dim.as<IntImmNode>();
+
+        // Case 1: current broadcast slot is 1 -> always replace
+        if (lhs_int && lhs_int->value == 1) {
+          out_shape.Set(lhs_axis, rhs_dim);
+          continue;
+        }
+        // Case 2: rhs is 1 -> keep lhs_dim unchanged
+        if (rhs_int && rhs_int->value == 1) {
+          continue;
+        }
+        // Both are non-one constants: they must be equal
+        if (lhs_int && rhs_int && lhs_int->value != rhs_int->value) {
+          ctx->ReportFatal(Diagnostic::Error(call)
+                           << "index_tensor: cannot broadcast index shapes. Mismatch at axis "
+                           << lhs_axis << ": " << lhs_dim << " vs " << rhs_dim);
+        }
+        // Give up if not provably equal
+        if (!analyzer->CanProveEqual(lhs_dim, rhs_dim)) {
+          shape_unknown = true;
+          break;
+        }
+      }
+      if (shape_unknown) break;
+    }
+
+    if (!shape_unknown) broadcast_shape = out_shape;
+  }
+
+  // Count of dimensions in output
+  int out_ndim = kUnknownNDim;
+  if (!data_sinfo->IsUnknownNdim()) {
+    int tail_ndim = data_sinfo->ndim - n_indices;
+    if (broadcast_shape.defined()) {
+      out_ndim = static_cast<int>(broadcast_shape.value().size()) + tail_ndim;
+    } else if (!shape_unknown) {
+      out_ndim = max_index_ndim + tail_ndim;
+    }
+  }
+
+  // Derive output shape
+  if (broadcast_shape.defined()) {
+    const auto* data_shape_expr = data_sinfo->shape.as<ShapeExprNode>();
+    if (data_shape_expr) {
+      Array<PrimExpr> result_shape = broadcast_shape.value();
+      for (int i = n_indices; i < data_sinfo->ndim; ++i) {
+        result_shape.push_back(data_shape_expr->values[i]);
+      }
+      return TensorStructInfo(ShapeExpr(result_shape), output_dtype, vdev);
+    }
+  }
+
+  // Unknown output shape
+  return TensorStructInfo(output_dtype, out_ndim, vdev);
+}
+
+TVM_REGISTER_OP("relax.index_tensor")
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input data.")
+    .add_argument("indices", "List of Tensors", "The indices used to index.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoIndexTensor)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.layout_transform */
 TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
 
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 7e5de217bc..4580f9191b 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -206,6 +206,18 @@ Expr gather_elements(Expr data, Expr indices, int axis = 0);
  */
 Expr gather_nd(Expr data, Expr indices, int batch_dims = 0);
 
+/*!
+ * \brief NumPy/PyTorch-style advanced indexing with tensors.
+ * \param data The input tensor.
+ * \param indices A Tuple expression containing the index tensors.
+ * \return The indexed tensor.
+ *
+ * \note When all shapes are static, Relax checks that the index shapes are
+ *       broadcast-compatible. Bounds checking of the values in indices is
+ *       deferred to runtime.
+ */
+Expr index_tensor(Expr data, Expr indices);
+
 /*!
  * \brief Scatter updates into an array according to indices.
  * \param data The input tensor.
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index e92855885e..76a4bb2039 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -63,6 +63,108 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
         np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_index_tensor(target, dev):
+    class IndexModel0(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return x[torch.tensor([0])]
+
+    torch_module = IndexModel0().eval()
+    raw_data = np.random.rand(3, 3).astype("float32")
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+    class IndexModel1(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return x[torch.tensor([[0]])]
+
+    torch_module = IndexModel1().eval()
+    raw_data = np.random.rand(2, 3).astype("float32")
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+    class IndexTensorModel2(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return x[torch.tensor([0, 2])]
+
+    torch_module = IndexTensorModel2().eval()
+    raw_data = np.random.rand(3, 4).astype("float32")
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+    class IndexTensorModel3(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return x[[[[0, 2], [1, 3]]]]
+
+    torch_module = IndexTensorModel3().eval()
+    raw_data = np.random.rand(5, 5, 5).astype("float32")
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+    class IndexTensorModel4(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return x[[[1, 4]]]
+
+    torch_module = IndexTensorModel4().eval()
+    raw_data = np.random.rand(5, 5, 5).astype("float32")
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+    class IndexTensorModel5(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return x[[[[1, 2, 4]]]]
+
+    torch_module = IndexTensorModel5().eval()
+    raw_data = np.random.rand(5, 5, 5).astype("float32")
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+    class IndexTensorModel6(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return x[[[0, 1], [0, 1]]]
+
+    torch_module = IndexTensorModel6().eval()
+    raw_data = np.random.rand(5, 5, 5, 5).astype("float32")
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+    class IndexTensorModel7(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return x[[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 0]]]
+
+    torch_module = IndexTensorModel7().eval()
+    raw_data = np.random.rand(5, 5, 5, 5).astype("float32")
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+    class IndexTensorModel8(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return x[[[[0, 1], [2, 3]], [[2, 3], [3, 4]], [[2, 4], [1, 2]], [[0, 4], [0, 3]]]]
+
+    torch_module = IndexTensorModel8().eval()
+    raw_data = np.random.rand(5, 5, 5, 5).astype("float32")
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_full(target, dev):
     class FullModel(nn.Module):
@@ -73,9 +175,7 @@ def test_full(target, dev):
             return torch.full((2, 3), 3.141592)
 
     torch_module = FullModel().eval()
-
     raw_data = np.random.rand(3, 3).astype("float32")
-
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 
@@ -91,7 +191,6 @@ def test_full_like(target, dev):
 
     torch_module = FullLike().eval()
     raw_data = np.random.rand(2, 3).astype("float32")
-
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 
@@ -105,9 +204,7 @@ def test_ones(target, dev):
             return torch.ones((2, 3))
 
     torch_module = FullModel().eval()
-
     raw_data = np.random.rand(1, 1).astype("float32")
-
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 
@@ -583,10 +680,42 @@ def test_sum(target, dev):
             return new_vec.sum()
 
     torch_module = SumModel().eval()
-
     raw_data = np.random.rand(10, 10, 10).astype("float32")
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_mul(target, dev):
+    class MulModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.y = torch.tensor(np.random.rand(2, 3).astype("float32"))
+
+        def forward(self, x):
+            return x.mul(self.y)
+
+    torch_module = MulModule().eval()
+    raw_data = np.random.rand(2, 3).astype("float32")
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_concat(target, dev):
+    class ConcatFour(nn.Module):
+        def __init__(self, dim=0):
+            super(ConcatFour, self).__init__()
+            self.dim = dim
+            self.x2 = torch.randn(2, 3)
+            self.x3 = torch.randn(2, 3)
+            self.x4 = torch.randn(2, 3)
+
+        def forward(self, x):
+            return torch.cat((x, self.x2, self.x3, self.x4), dim=self.dim)
+
+    torch_module = ConcatFour().eval()
+    raw_data = np.random.rand(2, 3).astype("float32")
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to