This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 7db0b984de [Unity][Op] Dynamic Strided Slice (#14548)
7db0b984de is described below

commit 7db0b984de6dd7d9ee7db79681476bc2bb7f26d5
Author: Sunghyun Park <49998730+sun...@users.noreply.github.com>
AuthorDate: Wed Apr 12 13:04:04 2023 -0700

    [Unity][Op] Dynamic Strided Slice (#14548)
    
    * feat: dyn_strided_slice op
    
    * feat: shape computation
    
    * feat: legalizer for dynamic strided slice
    
    * remove whitespace
    
    * reflect feedback
    
    * fix
    
    * fix
    
    * remove whitespace
---
 include/tvm/topi/transform.h                       |  37 +-
 python/tvm/relax/op/index.py                       |  37 ++
 python/tvm/relax/transform/legalize_ops/index.py   |  68 ++-
 python/tvm/script/ir_builder/relax/ir.py           |   3 +-
 python/tvm/topi/transform.py                       |  35 ++
 src/relax/op/tensor/index.cc                       |  78 +++-
 src/relax/transform/legalize_ops.cc                |  10 +-
 src/topi/transform.cc                              |   8 +
 tests/python/relax/test_e2e_op_dynamic.py          | 104 +++++
 tests/python/relax/test_op_index.py                | 197 +++++++++
 ..._transform_legalize_ops_index_linear_algebra.py | 489 +++++++++++++++++++++
 tests/python/topi/python/test_topi_transform.py    |  54 +++
 12 files changed, 1111 insertions(+), 9 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 0490ae7f1e..579dbb5833 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -2021,7 +2021,6 @@ inline Tensor adv_index(const Tensor& data, const 
Array<Tensor>& indices,
         for (size_t i = 0; i < broadcast_shape.size(); ++i) {
           tensor_indices.push_back(iter_var[i]);
         }
-
         Array<PrimExpr> real_indices;
         for (size_t i = 0; i < bindices.size(); ++i) {
           real_indices.push_back(bindices[i](tensor_indices));
@@ -2035,6 +2034,42 @@ inline Tensor adv_index(const Tensor& data, const 
Array<Tensor>& indices,
       name, tag);
 }
 
+namespace relax {
+/*!
+ * \brief Strided slice whose begin/strides are runtime tensors and whose output shape is
+ *        supplied by the caller.
+ *
+ * \param x The input tensor.
+ * \param begin 1-D tensor with one start index per axis of `x`. Negative values are
+ *        wrapped once by adding the axis extent, then clamped into the axis.
+ * \param end 1-D tensor with one end index per axis of `x`. Only its length is checked
+ *        here; the slice extent is expected to be encoded in `output_shape`.
+ * \param strides 1-D tensor with one stride per axis of `x` (may be negative).
+ * \param output_shape Shape of the result; must be consistent with begin/end/strides.
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ * \return The sliced tensor.
+ */
+inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin,
+                                        const te::Tensor& end, const te::Tensor& strides,
+                                        Array<PrimExpr> output_shape,
+                                        std::string name = "T_strided_slice_dynamic",
+                                        std::string tag = kInjective) {
+  const int64_t ndim = static_cast<int64_t>(x.ndim());
+  ICHECK_EQ(begin.ndim(), 1U);
+  ICHECK_EQ(end.ndim(), 1U);
+  ICHECK_EQ(strides.ndim(), 1U);
+  const auto* len_begin = begin->shape[0].as<IntImmNode>();
+  const auto* len_end = end->shape[0].as<IntImmNode>();
+  const auto* len_strides = strides->shape[0].as<IntImmNode>();
+  ICHECK(len_begin);
+  ICHECK(len_end);
+  ICHECK(len_strides);
+  // Compare as signed 64-bit values (IntImm::value is int64_t).
+  ICHECK_EQ(len_begin->value, ndim);
+  ICHECK_EQ(len_end->value, ndim);
+  ICHECK_EQ(len_strides->value, ndim);
+
+  return te::compute(
+      output_shape,
+      [&](const Array<tvm::tir::Var>& indices) {
+        Array<PrimExpr> real_indices;
+        for (int64_t i = 0; i < ndim; ++i) {
+          PrimExpr ind = make_const(DataType::Int(64), i);
+          PrimExpr start = begin(ind);
+          // Bring the axis extent to the dtype of `begin` so both branches below agree.
+          PrimExpr extent = tvm::cast(start.dtype(), x->shape[i]);
+          // Wrap a negative begin once, exactly like the Relax shape function does;
+          // without this a negative begin produces a negative (out-of-bounds) index.
+          start = tvm::if_then_else(start < 0, start + extent, start);
+          // Clamp into [0, extent - 1]. For every non-empty slice the canonical begin is
+          // already in this range; the clamp only matters for begins that yield an empty
+          // output along this axis.
+          start = tvm::min(tvm::max(start, make_zero(start.dtype())), extent - 1);
+          real_indices.push_back(indices[i] * strides(ind) + start);
+        }
+        return x(real_indices);
+      },
+      name, tag);
+}
+
+}  // namespace relax
+
 }  // namespace topi
 }  // namespace tvm
 #endif  // TVM_TOPI_TRANSFORM_H_
diff --git a/python/tvm/relax/op/index.py b/python/tvm/relax/op/index.py
index b9acf11b9f..835c9350b0 100644
--- a/python/tvm/relax/op/index.py
+++ b/python/tvm/relax/op/index.py
@@ -91,3 +91,40 @@ def strided_slice(
     same length as `axes`.
     """
     return _ffi_api.strided_slice(x, axes, begin, end, strides)  # type: ignore
+
+
+def dynamic_strided_slice(
+    x: Expr,
+    begin: Expr,
+    end: Expr,
+    strides: Expr,
+) -> Expr:
+    """Dynamic strided slice of a tensor. `begin`, `end`, `strides` can be computed at runtime.
+
+    Parameters
+    ----------
+    x : Expr
+        The source tensor to be sliced.
+
+    begin : Expr
+        The indices to begin with in the slicing, inclusive.
+
+    end : Expr
+        The indices indicating end of the slice, exclusive.
+
+    strides : Expr
+        Specifies the stride values. A stride can be negative, in which case
+        the input tensor is traversed in reverse along that axis.
+        This argument has no default and must always be provided.
+
+    Returns
+    -------
+    ret : relax.Expr
+        The sliced result.
+
+    Note
+    ----
+    dynamic_strided_slice requires the inputs `begin`, `end` and `strides` to be 1-D
+    tensors whose length equals the rank of the `x` tensor. Struct-info inference
+    rejects them if their dtype is known and is not int64.
+    """
+    return _ffi_api.dynamic_strided_slice(x, begin, end, strides)  # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/index.py 
b/python/tvm/relax/transform/legalize_ops/index.py
index eccccc7c6d..8ee1bed9b9 100644
--- a/python/tvm/relax/transform/legalize_ops/index.py
+++ b/python/tvm/relax/transform/legalize_ops/index.py
@@ -18,9 +18,10 @@
 """Default legalization function for index operators."""
 import logging
 
-from tvm import topi, tir
+from tvm import topi, tir, te
 from ...block_builder import BlockBuilder
-from ...expr import Call, Expr
+from ...expr import Call, Expr, ExternFunc
+from ...struct_info import ShapeStructInfo
 from .common import register_legalize
 
 
@@ -59,3 +60,66 @@ def _strided_slice(bb: BlockBuilder, call: Call) -> Expr:
         call.attrs.axes,
         slice_mode="end",
     )
+
+
+@register_legalize("relax.dynamic_strided_slice")
+def _dynamic_strided_slice(bb: BlockBuilder, call: Call) -> Expr:
+    assert len(call.args) == 4
+    data, begin, end, strides = call.args
+
+    # 1. Insert shape function
+    def shape_func(data, begin, end, strides):
+        def _compute(i):
+            def canonicalize_index(index, extent, strides):
+                begin_range = tir.Select(strides < 0, tir.const(-1, "int64"), 
tir.const(0, "int64"))
+                end_range = tir.Select(strides < 0, extent - 1, extent)
+                index = tir.Select(index < 0, index + extent, index)
+                return tir.Min(tir.Max(index, begin_range), end_range)
+
+            def get_length(begin, end, strides, length):
+                begin = canonicalize_index(begin, length, strides)
+                end = canonicalize_index(end, length, strides)
+                len1 = tir.ceildiv(begin - end, -strides)
+                len2 = tir.ceildiv(end - begin, strides)
+                return tir.Select(strides < 0, len1, len2)
+
+            length = tir.const(-1, "int64")
+            for idx in range(data.ndim):
+                length = tir.Select(i == tir.const(idx, "int64"), 
data.shape[idx], length)
+
+            return get_length(begin[i], end[i], strides[i], length)
+
+        return te.compute((begin.shape[0],), _compute, 
name="T_shape_func_strided_slice_dynamic")
+
+    output_shape = bb.normalize(
+        bb.call_te(
+            shape_func,
+            data,
+            begin,
+            end,
+            strides,
+        )
+    )
+
+    # 2. Convert tensor to shape and match cast with new symbolic vars
+    # Get shape length
+    ndim = int(output_shape.struct_info.shape[0])
+    output_shape = bb.emit(
+        Call(
+            ExternFunc("vm.builtin.tensor_to_shape"),
+            [output_shape],
+            sinfo_args=[ShapeStructInfo(ndim=ndim)],
+        )
+    )
+    output_shape_vars = [tir.Var("s", "int64") for i in range(ndim)]
+    bb.match_cast(output_shape, ShapeStructInfo(output_shape_vars))
+
+    # 3. Pass the output shape vars to TOPI
+    return bb.call_te(
+        topi.dynamic_strided_slice,
+        call.args[0],
+        call.args[1],
+        call.args[2],
+        call.args[3],
+        output_shape=output_shape_vars,
+    )
diff --git a/python/tvm/script/ir_builder/relax/ir.py 
b/python/tvm/script/ir_builder/relax/ir.py
index e390658c8f..39327c4b4a 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -111,6 +111,7 @@ from tvm.relax.op import (
     shape_of,
     std,
     strided_slice,
+    dynamic_strided_slice,
     sum,
     take,
     variance,
@@ -639,7 +640,6 @@ __all__ = [
     "ShapeExpr",
     "std",
     "str",
-    "strided_slice",
     "sum",
     "sigmoid",
     "sign",
@@ -652,6 +652,7 @@ __all__ = [
     "stop_lift_params",
     "str",
     "strided_slice",
+    "dynamic_strided_slice",
     "subtract",
     "take",
     "tan",
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index e4fe3c5839..7807351e90 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -227,6 +227,41 @@ def strided_slice(a, begin, end, strides=None, axes=None, 
slice_mode="end"):
     return cpp.strided_slice(a, begin, end, strides, axes, slice_mode)
 
 
+def dynamic_strided_slice(a, begin, end, strides, output_shape):
+    """Slice of an array.
+
+    Parameters
+    ----------
+    a : tvm.te.Tensor
+        The tensor to be sliced.
+
+    begin : tvm.te.Tensor
+        The indices to begin with in the slicing.
+
+    end : tvm.te.Tensor
+        Indices indicating end of the slice.
+
+    strides : tvm.te.Tensor
+        Specifies the stride values, it can be negative
+        in that case, the input tensor will be reversed
+        in that particular axis.
+
+    output_shape: list of PrimExpr
+        Specifies the output shape
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+    """
+    if not isinstance(begin, tvm.te.Tensor):
+        begin = const_vector(begin)
+    if not isinstance(end, tvm.te.Tensor):
+        end = const_vector(end)
+    if not isinstance(strides, tvm.te.Tensor):
+        strides = const_vector(strides)
+    return cpp.relax_dynamic_strided_slice(a, begin, end, strides, 
output_shape)
+
+
 @tvm.te.tag_scope(tag=tag.INJECTIVE + ",strided_set")
 def strided_set(a, v, begin, end, strides=None):
     """Set slice of an array.
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index b627e1425b..c3d38db4e1 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -190,7 +190,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, 
const BlockBuilder& ctx
                                 : Array<PrimExpr>(n_axis, 
IntImm(DataType::Int(64), 1));
   std::vector<int64_t> int_strides;
   int_strides.reserve(n_axis);
-  // Only do output shape inference when all the begin/end/stride values are 
integers.
+  // Only do output shape inference when all the begin/end/strides values are 
integers.
   for (int i = 0; i < n_axis; ++i) {
     const auto* int_begin = attrs->begin[i].as<IntImmNode>();
     const auto* int_end = attrs->end[i].as<IntImmNode>();
@@ -204,7 +204,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, 
const BlockBuilder& ctx
   Array<PrimExpr> output_shape = data_shape->values;
   for (int i = 0; i < n_axis; ++i) {
     ICHECK_NE(int_strides[i], 0)
-        << "Strided slice requires stride to be non-zero but got 0 for axis " 
<< axes[i] << ".";
+        << "Strided slice requires strides to be non-zero but got 0 for axis " 
<< axes[i] << ".";
     output_shape.Set(axes[i], GetLength(attrs->begin[i], attrs->end[i], 
int_strides[i],
                                         data_shape->values[axes[i]]));
   }
@@ -239,5 +239,79 @@ TVM_REGISTER_OP("relax.strided_slice")
     .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStridedSlice)
     .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", 
MixedPrecisionPolicyKind::kFollow);
 
+/* relax.dynamic_strided_slice */
+/*!
+ * \brief Build a call to relax.dynamic_strided_slice (the op has no attributes).
+ * \param x The tensor to slice.
+ * \param begin Tensor of per-axis start indices.
+ * \param end Tensor of per-axis end indices.
+ * \param strides Tensor of per-axis strides.
+ * \return The call expression.
+ */
+Expr dynamic_strided_slice(Expr x, Expr begin, Expr end, Expr strides) {
+  static const Op& dyn_slice_op = Op::Get("relax.dynamic_strided_slice");
+  Array<Expr> call_args{std::move(x), std::move(begin), std::move(end), std::move(strides)};
+  return Call(dyn_slice_op, std::move(call_args), Attrs());
+}
+
+TVM_REGISTER_GLOBAL("relax.op.dynamic_strided_slice").set_body_typed(dynamic_strided_slice);
+
+/*!
+ * \brief Struct-info inference for relax.dynamic_strided_slice.
+ *
+ * The result keeps the dtype and rank of `x`, but its shape is left unknown because it
+ * depends on the runtime contents of begin/end/strides. Each of begin/end/strides must be
+ * a 1-D tensor with a statically known length equal to the rank of `x`, and int64 when
+ * its dtype is known. If the rank of `x` is unknown, no validation is performed.
+ */
+StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder& ctx) {
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* begin_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+  const auto* end_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
+  const auto* strides_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[3]);
+
+  ICHECK(data_sinfo) << "Dynamic strided slice requires the input x to have tensor struct info.";
+  if (data_sinfo->IsUnknownNdim()) {
+    LOG(WARNING) << "When data rank is unknown, dynamic strided slice assumes begin/end/strides "
+                    "tensors are well-formed. It could produce runtime error when this assumption "
+                    "turns out to be wrong.";
+    return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
+  }
+  if (data_sinfo->IsUnknownDtype()) {
+    LOG(WARNING) << "When data type is unknown, dynamic strided slice assumes to have a valid "
+                    "dtype. It could produce runtime error when this assumption "
+                    "turns out to be wrong.";
+  }
+
+  const int n_axis = data_sinfo->ndim;
+  // Validates one index tensor (begin, end or strides); `name` is only used in messages.
+  auto diag_def = [&](const TensorStructInfoNode* sinfo, const String& name) {
+    ICHECK(sinfo) << "Dynamic strided slice requires the input " << name
+                  << " to have struct info. Please try normalizing the inputs.";
+    CHECK_EQ(sinfo->ndim, 1) << "Dynamic strided slice requires " << name
+                             << " to be 1d tensor (list of values).";
+    const auto* shape = sinfo->shape.as<ShapeExprNode>();
+    ICHECK(shape) << "Dynamic strided slice requires the input " << name
+                  << " to have well-defined shape.";
+    // NOTE(tvm-team): This strong restriction seems necessary for now until we have a generic
+    // solution in converting 1d Tensor with unknown num_elem to Array<PrimExpr>.
+    const auto* num_elem = shape->values[0].as<IntImmNode>();
+    ICHECK(num_elem) << "Dynamic strided slice requires the input " << name
+                     << " to have a known integer shape value.";
+    CHECK_EQ(num_elem->value, n_axis) << "Dynamic strided slice requires the number of indices in "
+                                      << name << " to equal the number of axes.";
+    if (sinfo->IsUnknownDtype()) {
+      LOG(WARNING) << "Dynamic strided slice assumes " << name
+                   << " to be int64 when it is not specified.";
+    } else {
+      CHECK(sinfo->dtype == DataType::Int(64))
+          << "Dynamic strided_slice expects the input " << name
+          << " values to be all int64. However, " << name << " has dtype " << sinfo->dtype
+          << ".";
+    }
+  };
+  diag_def(begin_sinfo, "begin");
+  diag_def(end_sinfo, "end");
+  diag_def(strides_sinfo, "strides");
+
+  // The output shape will depend on the runtime value in begin/end/strides tensors.
+  // TODO(tvm-team): Currently, it is unable to express partially-static shape. Revisit when
+  // PrimValue lands.
+  return TensorStructInfo(data_sinfo->dtype, n_axis);
+}
+
+// TODO(tvm-team): Register FRelaxInferLayout, TMixedPrecisionPolicy
+// Operator registration. The argument order below must stay in sync with the Call
+// built by dynamic_strided_slice(): (x, begin, end, strides).
+TVM_REGISTER_OP("relax.dynamic_strided_slice")
+    .set_num_inputs(4)
+    .add_argument("x", "Tensor", "The source tensor to be sliced.")
+    .add_argument("begin", "Tensor", "The indices to begin with in the slicing.")
+    .add_argument("end", "Tensor", "Indices indicating end of the slice.")
+    .add_argument("strides", "Tensor", "The stride values.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoDynStridedSlice);
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/transform/legalize_ops.cc 
b/src/relax/transform/legalize_ops.cc
index 350a40c37b..7c5393c6ca 100644
--- a/src/relax/transform/legalize_ops.cc
+++ b/src/relax/transform/legalize_ops.cc
@@ -83,15 +83,19 @@ class LegalizeMutator : public ExprMutator {
       return visited_call;
     }
 
+    auto op = GetRef<Op>(op_node);
+    std::string op_name(op->name);
+    bool is_data_dependent_op = (op_name.find("dynamic") != std::string::npos);
     // Not all shape values are known
+    // Data-dependent ops are exception since their output shape will be 
identified at runtime.
+    // Legalizer will insert their shape functions, which are manually 
registered, and match cast
+    // to define symbolic output shape at compile time.
     if (!std::all_of(visited_call->args.begin(), visited_call->args.end(),
                      [](Expr arg) { return 
KnowAllShapeValues(GetStructInfo(arg)); }) ||
-        !KnowAllShapeValues(GetStructInfo(visited_call))) {
+        (!is_data_dependent_op && 
!KnowAllShapeValues(GetStructInfo(visited_call)))) {
       return visited_call;
     }
 
-    auto op = GetRef<Op>(op_node);
-
     // Priority: customize > default.
     // Check if it has customize legalization registered.
     if (cmap_.defined() && cmap_.value().count(op->name)) {
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index bbefa19c20..906f2b9115 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -201,6 +201,14 @@ 
TVM_REGISTER_GLOBAL("topi.dynamic_strided_slice").set_body([](TVMArgs args, TVMR
   *rv = dynamic_strided_slice(args[0], begin, end, strides);
 });
 
+// Packed-function entry point for topi::relax::dynamic_strided_slice.
+// Expected arguments: (data, begin, end, strides, output_shape).
+TVM_REGISTER_GLOBAL("topi.relax_dynamic_strided_slice")
+    .set_body([](TVMArgs args, TVMRetValue* rv) {
+      const te::Tensor data = args[0];
+      const te::Tensor begin_indices = args[1];
+      const te::Tensor end_indices = args[2];
+      const te::Tensor stride_values = args[3];
+      const Array<PrimExpr> out_shape = args[4];
+      *rv = relax::dynamic_strided_slice(data, begin_indices, end_indices, stride_values,
+                                         out_shape);
+    });
+
 TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) 
{
   int depth = args[3];
   int axis = args[4];
diff --git a/tests/python/relax/test_e2e_op_dynamic.py 
b/tests/python/relax/test_e2e_op_dynamic.py
new file mode 100644
index 0000000000..1e9414c15d
--- /dev/null
+++ b/tests/python/relax/test_e2e_op_dynamic.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+import tvm
+from tvm import relax
+import tvm.topi.testing
+from tvm.relax.transform import LegalizeOps
+from tvm.script import relax as R, tir as T
+import tvm.testing
+
+# TODO(tvm-team): `tir.transform.DefaultGPUSchedule` does not work.
+target, dev = "llvm", tvm.cpu()
+
+
+def build(mod):
+    exe = relax.build(mod, target=target)
+    return relax.VirtualMachine(exe, dev)
+
+
+@pytest.mark.parametrize(
+    "begin, end, strides",
+    [
+        ([0, 2, 4, 4], [5, 5, 7, 8], [1, 1, 2, 3]),
+        ([0, 2, 4, 4], [5, 5, 11, 10], [1, 1, 1, 1]),
+        ([0, 2, 10, 14], [0, 5, 1, 1], [1, 1, -1, -2]),
+    ],
+)
+def test_dynamic_strided_slice(begin, end, strides):
+    # fmt: off
+    @tvm.script.ir_module
+    class DynamicStridedSlice:
+        @R.function
+        def main(x: R.Tensor((8, 9, 10, 10), "float32"), begin: 
R.Tensor((4,),"int64"), end: R.Tensor((4,),"int64"), strides: 
R.Tensor((4,),"int64")) -> R.Tensor("float32", ndim=4):
+            gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x, 
begin, end, strides)
+            return gv
+    # fmt: on
+    mod = LegalizeOps()(DynamicStridedSlice)
+    with tvm.target.Target(target):
+        mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
+    vm = build(mod)
+
+    x_np = np.random.rand(8, 9, 10, 10).astype(np.float32)
+    data_nd = tvm.nd.array(x_np, dev)
+    begin_nd = tvm.nd.array(np.array(begin).astype("int64"), dev)
+    end_nd = tvm.nd.array(np.array(end).astype("int64"), dev)
+    strides_nd = tvm.nd.array(np.array(strides).astype("int64"), dev)
+
+    # Reference implementation
+    out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides)
+    out_nd = vm["main"](data_nd, begin_nd, end_nd, strides_nd)
+    tvm.testing.assert_allclose(out_nd.numpy(), out_npy)
+
+
+@pytest.mark.parametrize(
+    "begin, end, strides",
+    [
+        ([0, 2, 4, 4], [5, 5, 7, 8], [1, 1, 2, 3]),
+        ([0, 2, 4, 4], [5, 5, 11, 10], [1, 1, 1, 1]),
+        ([0, 2, 10, 14], [0, 5, 1, 1], [1, 1, -1, -2]),
+    ],
+)
+def test_dynamic_strided_slice_symbolic(begin, end, strides):
+    # fmt: off
+    @tvm.script.ir_module
+    class DynamicStridedSlice:
+        @R.function
+        def main(x: R.Tensor(("m", "n", 10, 10), "float32"), begin: 
R.Tensor((4,),"int64"), end: R.Tensor((4,),"int64"), strides: 
R.Tensor((4,),"int64")) -> R.Tensor("float32", ndim=4):
+            m = T.int64()
+            n = T.int64()
+            gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x, 
begin, end, strides)
+            return gv
+    # fmt: on
+    mod = LegalizeOps()(DynamicStridedSlice)
+    vm = build(mod)
+
+    x_np = np.random.rand(8, 9, 10, 10).astype(np.float32)
+    data_nd = tvm.nd.array(x_np, dev)
+    begin_nd = tvm.nd.array(np.array(begin).astype("int64"), dev)
+    end_nd = tvm.nd.array(np.array(end).astype("int64"), dev)
+    strides_nd = tvm.nd.array(np.array(strides).astype("int64"), dev)
+
+    # Reference implementation
+    out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides)
+    out_nd = vm["main"](data_nd, begin_nd, end_nd, strides_nd)
+    tvm.testing.assert_allclose(out_nd.numpy(), out_npy)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_op_index.py 
b/tests/python/relax/test_op_index.py
index 9390a2c9b0..8b2f8c0b29 100644
--- a/tests/python/relax/test_op_index.py
+++ b/tests/python/relax/test_op_index.py
@@ -30,6 +30,7 @@ def test_op_correctness():
     assert relax.op.strided_slice(x, axes=[0], begin=[0], end=[2]).op == 
Op.get(
         "relax.strided_slice"
     )
+    assert relax.op.dynamic_strided_slice(x, x, x, x).op == 
Op.get("relax.dynamic_strided_slice")
 
 
 def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: 
relax.StructInfo):
@@ -673,5 +674,201 @@ def 
test_strided_slice_infer_struct_info_wrong_input_type():
         bb.normalize(relax.op.strided_slice(x1, axes=[0], begin=[0], end=[8]))
 
 
+def test_dynamic_strided_slice_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=4))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor((8, 9, 10, 10)))
+    x4 = relax.Var("x", R.Tensor(ndim=4))
+    x5 = relax.Var("x", R.Tensor())
+
+    b0 = relax.Var("begin", R.Tensor((4,), "int64"))
+    e0 = relax.Var("end", R.Tensor((4,), "int64"))
+    s0 = relax.Var("strides", R.Tensor((4,), "int64"))
+    b1 = relax.Var("begin", R.Tensor((4,)))
+    e1 = relax.Var("end", R.Tensor((4,)))
+    s1 = relax.Var("stride", R.Tensor((4,)))
+
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x0, b0, e0, s0),
+        R.Tensor("float32", ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x1, b0, e0, s0),
+        R.Tensor("float32", ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x2, b0, e0, s0),
+        R.Tensor("float32", ndim=-1),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x3, b0, e0, s0),
+        R.Tensor(ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x4, b0, e0, s0),
+        R.Tensor(ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x5, b0, e0, s0),
+        R.Tensor(ndim=-1),
+    )
+
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x0, b1, e1, s1),
+        R.Tensor("float32", ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x1, b1, e1, s1),
+        R.Tensor("float32", ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x2, b1, e1, s1),
+        R.Tensor("float32", ndim=-1),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x3, b1, e1, s1),
+        R.Tensor(ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x4, b1, e1, s1),
+        R.Tensor(ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x5, b1, e1, s1),
+        R.Tensor(ndim=-1),
+    )
+
+
+def test_dynamic_strided_slice_infer_struct_info_symbolic():
+    bb = relax.BlockBuilder()
+    i = tir.Var("i", "int64")
+    j = tir.Var("j", "int64")
+    k = tir.Var("k", "int64")
+    l = tir.Var("l", "int64")
+    x0 = relax.Var("x", R.Tensor((i, j, k, l), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=4))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor((i, j, k, l)))
+    x4 = relax.Var("x", R.Tensor(ndim=4))
+    x5 = relax.Var("x", R.Tensor())
+
+    b0 = relax.Var("begin", R.Tensor((4,), "int64"))
+    e0 = relax.Var("end", R.Tensor((4,), "int64"))
+    s0 = relax.Var("stride", R.Tensor((4,), "int64"))
+    b1 = relax.Var("begin", R.Tensor((4,)))
+    e1 = relax.Var("end", R.Tensor((4,)))
+    s1 = relax.Var("stride", R.Tensor((4,)))
+
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x0, b0, e0, s0),
+        R.Tensor("float32", ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x1, b0, e0, s0),
+        R.Tensor("float32", ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x2, b0, e0, s0),
+        R.Tensor("float32", ndim=-1),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x3, b0, e0, s0),
+        R.Tensor(ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x4, b0, e0, s0),
+        R.Tensor(ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x5, b0, e0, s0),
+        R.Tensor(ndim=-1),
+    )
+
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x0, b1, e1, s1),
+        R.Tensor("float32", ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x1, b1, e1, s1),
+        R.Tensor("float32", ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x2, b1, e1, s1),
+        R.Tensor("float32", ndim=-1),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x3, b1, e1, s1),
+        R.Tensor(ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x4, b1, e1, s1),
+        R.Tensor(ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.dynamic_strided_slice(x5, b1, e1, s1),
+        R.Tensor(ndim=-1),
+    )
+
+
+def test_dynamic_strided_slice_infer_struct_info_arg_wrong_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32"))
+    b0 = relax.Var("begin", R.Tensor((4,), "float32"))
+    e0 = relax.Var("end", R.Tensor((4,), "float32"))
+    s0 = relax.Var("stride", R.Tensor((4,), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.strided_slice(x0, b0, e0, s0))
+
+
+def test_dynamic_strided_slice_infer_struct_info_arg_wrong_shape_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32"))
+    m = tir.Var("m", "int64")
+    # invalid arg
+    b0 = relax.Var("begin", R.Tensor("int64", ndim=2))
+    b1 = relax.Var("begin", R.Tensor((1,), "int64"))
+    b2 = relax.Var("begin", R.Tensor((2, 2), "int64"))
+    b3 = relax.Var("begin", R.Tensor((m,), "int64"))
+    # valid args
+    e0 = relax.Var("end", R.Tensor((4,), "int64"))
+    s0 = relax.Var("stride", R.Tensor((4,), "int64"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.strided_slice(x0, b0, e0, s0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.strided_slice(x0, b1, e0, s0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.strided_slice(x0, b2, e0, s0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.strided_slice(x0, b3, e0, s0))
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git 
a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py 
b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index 39224240ef..d7c0b54af2 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -210,6 +210,495 @@ def test_strided_slice_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_dynamic_strided_slice():
+    """Legalize R.dynamic_strided_slice on an input with a fully static shape.
+
+    LegalizeOps is expected to emit two PrimFuncs plus a Relax ``main``:
+    ``shape_func`` computes the output extents from the runtime
+    begin/end/strides tensors, and ``dynamic_strided_slice`` performs the
+    slice. ``main`` turns the computed shape tensor into an ``R.Shape`` via
+    ``vm.builtin.tensor_to_shape``, binds the symbolic output dims with
+    ``R.match_cast``, and then calls the slice kernel.
+    """
+    # fmt: off
+    @tvm.script.ir_module
+    class DynamicStridedSlice:
+        @R.function
+        def main(x: R.Tensor((8, 9, 10, 10), "float32"), begin: R.Tensor((4,),"int64"), end: R.Tensor((4,),"int64"), strides: R.Tensor((4,),"int64")) -> R.Tensor("float32", ndim=4):
+            gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x, begin, end, strides)
+            return gv
+    @tvm.script.ir_module
+    class Expected:
+        # Buffer roles in both PrimFuncs (matching the call_tir argument order
+        # in main): rxplaceholder = x, rxplaceholder_1 = begin,
+        # rxplaceholder_2 = end, rxplaceholder_3 = strides.
+        #
+        # The slice kernel reads x[min(begin[k], dim_k - 1) + v_k * strides[k]].
+        # NOTE(review): begin is only clamped from above here; negative begin
+        # values are not wrapped as they are in shape_func -- confirm intended.
+        @T.prim_func
+        def dynamic_strided_slice(
+            rxplaceholder: T.Buffer(
+                (T.int64(8), T.int64(9), T.int64(10), T.int64(10)), "float32"
+            ),
+            rxplaceholder_1: T.Buffer((T.int64(4),), "int64"),
+            rxplaceholder_2: T.Buffer((T.int64(4),), "int64"),
+            rxplaceholder_3: T.Buffer((T.int64(4),), "int64"),
+            var_T_strided_slice_dynamic: T.handle,
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            s, s_1, s_2, s_3 = T.int64(), T.int64(), T.int64(), T.int64()
+            T_strided_slice_dynamic = T.match_buffer(
+                var_T_strided_slice_dynamic, (s, s_1, s_2, s_3)
+            )
+            # with T.block("root"):
+            for ax0, ax1, ax2, ax3 in T.grid(s, s_1, s_2, s_3):
+                with T.block("T_strided_slice_dynamic"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(
+                        rxplaceholder[
+                            T.min(rxplaceholder_1[T.int64(0)], T.int64(7))
+                            + v_ax0 * rxplaceholder_3[T.int64(0)],
+                            T.min(rxplaceholder_1[T.int64(1)], T.int64(8))
+                            + v_ax1 * rxplaceholder_3[T.int64(1)],
+                            T.min(rxplaceholder_1[T.int64(2)], T.int64(9))
+                            + v_ax2 * rxplaceholder_3[T.int64(2)],
+                            T.min(rxplaceholder_1[T.int64(3)], T.int64(9))
+                            + v_ax3 * rxplaceholder_3[T.int64(3)],
+                        ],
+                        rxplaceholder_1[T.int64(0) : T.int64(4)],
+                        rxplaceholder_3[T.int64(0) : T.int64(4)],
+                    )
+                    T.writes(T_strided_slice_dynamic[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_strided_slice_dynamic[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[
+                        T.min(rxplaceholder_1[T.int64(0)], T.int64(7))
+                        + v_ax0 * rxplaceholder_3[T.int64(0)],
+                        T.min(rxplaceholder_1[T.int64(1)], T.int64(8))
+                        + v_ax1 * rxplaceholder_3[T.int64(1)],
+                        T.min(rxplaceholder_1[T.int64(2)], T.int64(9))
+                        + v_ax2 * rxplaceholder_3[T.int64(2)],
+                        T.min(rxplaceholder_1[T.int64(3)], T.int64(9))
+                        + v_ax3 * rxplaceholder_3[T.int64(3)],
+                    ]
+
+        # Per axis i, a negative begin/end is first wrapped by adding the axis
+        # extent (the nested T.Select chain on v_i yields 8/9/10/10), then
+        # clamped to [-1, dim-1] for negative strides or [0, dim] otherwise:
+        #   stride < 0: (begin - end - stride - 1) // (-stride)
+        #   stride > 0: (end + stride - begin - 1) // stride
+        # NOTE(review): neither result is clamped at 0, so e.g. end < begin with
+        # a positive stride gives a negative extent where NumPy gives 0.
+        # Presumably this mirrors topi's shape function -- confirm intended.
+        @T.prim_func
+        def shape_func(
+            rxplaceholder: T.Buffer(
+                (T.int64(8), T.int64(9), T.int64(10), T.int64(10)), "float32"
+            ),
+            rxplaceholder_1: T.Buffer((T.int64(4),), "int64"),
+            rxplaceholder_2: T.Buffer((T.int64(4),), "int64"),
+            rxplaceholder_3: T.Buffer((T.int64(4),), "int64"),
+            T_shape_func_strided_slice_dynamic: T.Buffer((T.int64(4),), "int64"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i in range(T.int64(4)):
+                with T.block("T_shape_func_strided_slice_dynamic"):
+                    v_i = T.axis.spatial(T.int64(4), i)
+                    T.reads(
+                        rxplaceholder_3[v_i], rxplaceholder_1[v_i], rxplaceholder_2[v_i]
+                    )
+                    T.writes(T_shape_func_strided_slice_dynamic[v_i])
+                    T_shape_func_strided_slice_dynamic[v_i] = T.Select(
+                        rxplaceholder_3[v_i] < T.int64(0),
+                        (
+                            T.min(
+                                T.max(
+                                    T.Select(
+                                        rxplaceholder_1[v_i] < T.int64(0),
+                                        rxplaceholder_1[v_i]
+                                        + T.Select(
+                                            v_i == T.int64(3),
+                                            T.int64(10),
+                                            T.Select(
+                                                v_i == T.int64(2),
+                                                T.int64(10),
+                                                T.Select(
+                                                    v_i == T.int64(1),
+                                                    T.int64(9),
+                                                    T.Select(
+                                                        v_i == T.int64(0),
+                                                        T.int64(8),
+                                                        T.int64(-1),
+                                                    ),
+                                                ),
+                                            ),
+                                        ),
+                                        rxplaceholder_1[v_i],
+                                    ),
+                                    T.int64(-1),
+                                ),
+                                T.Select(
+                                    v_i == T.int64(3),
+                                    T.int64(10),
+                                    T.Select(
+                                        v_i == T.int64(2),
+                                        T.int64(10),
+                                        T.Select(
+                                            v_i == T.int64(1),
+                                            T.int64(9),
+                                            T.Select(
+                                                v_i == T.int64(0), T.int64(8), T.int64(-1)
+                                            ),
+                                        ),
+                                    ),
+                                )
+                                - T.int64(1),
+                            )
+                            - T.min(
+                                T.max(
+                                    T.Select(
+                                        rxplaceholder_2[v_i] < T.int64(0),
+                                        rxplaceholder_2[v_i]
+                                        + T.Select(
+                                            v_i == T.int64(3),
+                                            T.int64(10),
+                                            T.Select(
+                                                v_i == T.int64(2),
+                                                T.int64(10),
+                                                T.Select(
+                                                    v_i == T.int64(1),
+                                                    T.int64(9),
+                                                    T.Select(
+                                                        v_i == T.int64(0),
+                                                        T.int64(8),
+                                                        T.int64(-1),
+                                                    ),
+                                                ),
+                                            ),
+                                        ),
+                                        rxplaceholder_2[v_i],
+                                    ),
+                                    T.int64(-1),
+                                ),
+                                T.Select(
+                                    v_i == T.int64(3),
+                                    T.int64(10),
+                                    T.Select(
+                                        v_i == T.int64(2),
+                                        T.int64(10),
+                                        T.Select(
+                                            v_i == T.int64(1),
+                                            T.int64(9),
+                                            T.Select(
+                                                v_i == T.int64(0), T.int64(8), T.int64(-1)
+                                            ),
+                                        ),
+                                    ),
+                                )
+                                - T.int64(1),
+                            )
+                            - rxplaceholder_3[v_i]
+                            - T.int64(1)
+                        )
+                        // (rxplaceholder_3[v_i] * T.int64(-1)),
+                        (
+                            T.min(
+                                T.max(
+                                    T.Select(
+                                        rxplaceholder_2[v_i] < T.int64(0),
+                                        rxplaceholder_2[v_i]
+                                        + T.Select(
+                                            v_i == T.int64(3),
+                                            T.int64(10),
+                                            T.Select(
+                                                v_i == T.int64(2),
+                                                T.int64(10),
+                                                T.Select(
+                                                    v_i == T.int64(1),
+                                                    T.int64(9),
+                                                    T.Select(
+                                                        v_i == T.int64(0),
+                                                        T.int64(8),
+                                                        T.int64(-1),
+                                                    ),
+                                                ),
+                                            ),
+                                        ),
+                                        rxplaceholder_2[v_i],
+                                    ),
+                                    T.int64(0),
+                                ),
+                                T.Select(
+                                    v_i == T.int64(3),
+                                    T.int64(10),
+                                    T.Select(
+                                        v_i == T.int64(2),
+                                        T.int64(10),
+                                        T.Select(
+                                            v_i == T.int64(1),
+                                            T.int64(9),
+                                            T.Select(
+                                                v_i == T.int64(0), T.int64(8), T.int64(-1)
+                                            ),
+                                        ),
+                                    ),
+                                ),
+                            )
+                            + rxplaceholder_3[v_i]
+                            - T.min(
+                                T.max(
+                                    T.Select(
+                                        rxplaceholder_1[v_i] < T.int64(0),
+                                        rxplaceholder_1[v_i]
+                                        + T.Select(
+                                            v_i == T.int64(3),
+                                            T.int64(10),
+                                            T.Select(
+                                                v_i == T.int64(2),
+                                                T.int64(10),
+                                                T.Select(
+                                                    v_i == T.int64(1),
+                                                    T.int64(9),
+                                                    T.Select(
+                                                        v_i == T.int64(0),
+                                                        T.int64(8),
+                                                        T.int64(-1),
+                                                    ),
+                                                ),
+                                            ),
+                                        ),
+                                        rxplaceholder_1[v_i],
+                                    ),
+                                    T.int64(0),
+                                ),
+                                T.Select(
+                                    v_i == T.int64(3),
+                                    T.int64(10),
+                                    T.Select(
+                                        v_i == T.int64(2),
+                                        T.int64(10),
+                                        T.Select(
+                                            v_i == T.int64(1),
+                                            T.int64(9),
+                                            T.Select(
+                                                v_i == T.int64(0), T.int64(8), T.int64(-1)
+                                            ),
+                                        ),
+                                    ),
+                                ),
+                            )
+                            - T.int64(1)
+                        )
+                        // rxplaceholder_3[v_i],
+                    )
+
+        @R.function
+        def main(
+            x: R.Tensor((8, 9, 10, 10), dtype="float32"),
+            begin: R.Tensor((4,), dtype="int64"),
+            end: R.Tensor((4,), dtype="int64"),
+            strides: R.Tensor((4,), dtype="int64"),
+        ) -> R.Tensor(dtype="float32", ndim=4):
+            s = T.int64()
+            s_1 = T.int64()
+            s_2 = T.int64()
+            s_3 = T.int64()
+            # Shape tensor -> R.Shape -> symbolic dims (s, s_1, s_2, s_3), which
+            # then type the output of the slice kernel.
+            gv = R.call_tir(
+                Expected.shape_func,
+                (x, begin, end, strides),
+                out_sinfo=R.Tensor((4,), dtype="int64"),
+            )
+            gv1: R.Shape(ndim=4) = R.call_packed(
+                "vm.builtin.tensor_to_shape", gv, sinfo_args=(R.Shape(ndim=4),)
+            )
+            gv2: R.Shape([s, s_1, s_2, s_3]) = R.match_cast(
+                gv1, R.Shape([s, s_1, s_2, s_3])
+            )
+            gv_1 = R.call_tir(
+                Expected.dynamic_strided_slice,
+                (x, begin, end, strides),
+                out_sinfo=R.Tensor((s, s_1, s_2, s_3), dtype="float32"),
+            )
+            return gv_1
+    # fmt: on
+    mod = LegalizeOps()(DynamicStridedSlice)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_dynamic_strided_slice_symbolic():
+    """Legalize R.dynamic_strided_slice on an input with a symbolic dim (10, n).
+
+    Same lowering as the static case (shape_func + slice PrimFunc + Relax main
+    chaining them through vm.builtin.tensor_to_shape and R.match_cast), except
+    that the input is passed as a handle and matched against (10, n). The
+    extent n also appears in the shape formula for axis 1.
+    """
+    # fmt: off
+    @tvm.script.ir_module
+    class DynamicStridedSlice:
+        @R.function
+        def main(x: R.Tensor((10, "n"), "float32"), begin:R.Tensor((2,), "int64"), end:R.Tensor((2,), "int64"), strides:R.Tensor((2,), "int64")) -> R.Tensor("float32", ndim=2):
+            n = T.int64()
+            gv: R.Tensor("float32", ndim=2) = R.dynamic_strided_slice(x, begin, end, strides)
+            return gv
+    @tvm.script.ir_module
+    class Expected:
+        # Buffer roles in both PrimFuncs (matching the call_tir argument order
+        # in main): var_rxplaceholder / rxplaceholder_3 = x (10, n),
+        # rxplaceholder = begin, rxplaceholder_1 = end, rxplaceholder_2 = strides.
+        #
+        # The slice kernel reads x[min(begin[k], dim_k - 1) + v_k * strides[k]].
+        # NOTE(review): begin is only clamped from above here; negative begin
+        # values are not wrapped as they are in shape_func -- confirm intended.
+        @T.prim_func
+        def dynamic_strided_slice(
+            var_rxplaceholder: T.handle,
+            rxplaceholder: T.Buffer((T.int64(2),), "int64"),
+            rxplaceholder_1: T.Buffer((T.int64(2),), "int64"),
+            rxplaceholder_2: T.Buffer((T.int64(2),), "int64"),
+            var_T_strided_slice_dynamic: T.handle,
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            n = T.int64()
+            rxplaceholder_3 = T.match_buffer(var_rxplaceholder, (T.int64(10), n))
+            s, s_1 = T.int64(), T.int64()
+            T_strided_slice_dynamic = T.match_buffer(var_T_strided_slice_dynamic, (s, s_1))
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(s, s_1):
+                with T.block("T_strided_slice_dynamic"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(
+                        rxplaceholder_3[
+                            T.min(rxplaceholder[T.int64(0)], T.int64(9))
+                            + v_ax0 * rxplaceholder_2[T.int64(0)],
+                            T.min(rxplaceholder[T.int64(1)], n - T.int64(1))
+                            + v_ax1 * rxplaceholder_2[T.int64(1)],
+                        ],
+                        rxplaceholder[T.int64(0) : T.int64(2)],
+                        rxplaceholder_2[T.int64(0) : T.int64(2)],
+                    )
+                    T.writes(T_strided_slice_dynamic[v_ax0, v_ax1])
+                    T_strided_slice_dynamic[v_ax0, v_ax1] = rxplaceholder_3[
+                        T.min(rxplaceholder[T.int64(0)], T.int64(9))
+                        + v_ax0 * rxplaceholder_2[T.int64(0)],
+                        T.min(rxplaceholder[T.int64(1)], n - T.int64(1))
+                        + v_ax1 * rxplaceholder_2[T.int64(1)],
+                    ]
+
+        # Per axis i, a negative begin/end is first wrapped by adding the axis
+        # extent (the nested T.Select chain on v_i yields 10 or n), then clamped
+        # to [-1, dim-1] for negative strides or [0, dim] otherwise:
+        #   stride < 0: (begin - end - stride - 1) // (-stride)
+        #   stride > 0: (end + stride - begin - 1) // stride
+        # NOTE(review): neither result is clamped at 0, so e.g. end < begin with
+        # a positive stride gives a negative extent where NumPy gives 0.
+        # Presumably this mirrors topi's shape function -- confirm intended.
+        @T.prim_func
+        def shape_func(
+            var_rxplaceholder: T.handle,
+            rxplaceholder: T.Buffer((T.int64(2),), "int64"),
+            rxplaceholder_1: T.Buffer((T.int64(2),), "int64"),
+            rxplaceholder_2: T.Buffer((T.int64(2),), "int64"),
+            T_shape_func_strided_slice_dynamic: T.Buffer((T.int64(2),), "int64"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            n = T.int64()
+            rxplaceholder_3 = T.match_buffer(var_rxplaceholder, (T.int64(10), n))
+            # with T.block("root"):
+            for i in range(T.int64(2)):
+                with T.block("T_shape_func_strided_slice_dynamic"):
+                    v_i = T.axis.spatial(T.int64(2), i)
+                    T.reads(rxplaceholder_2[v_i], rxplaceholder[v_i], rxplaceholder_1[v_i])
+                    T.writes(T_shape_func_strided_slice_dynamic[v_i])
+                    T_shape_func_strided_slice_dynamic[v_i] = T.Select(
+                        rxplaceholder_2[v_i] < T.int64(0),
+                        (
+                            T.min(
+                                T.max(
+                                    T.Select(
+                                        rxplaceholder[v_i] < T.int64(0),
+                                        rxplaceholder[v_i]
+                                        + T.Select(
+                                            v_i == T.int64(1),
+                                            n,
+                                            T.Select(
+                                                v_i == T.int64(0), T.int64(10), T.int64(-1)
+                                            ),
+                                        ),
+                                        rxplaceholder[v_i],
+                                    ),
+                                    T.int64(-1),
+                                ),
+                                T.Select(
+                                    v_i == T.int64(1),
+                                    n,
+                                    T.Select(v_i == T.int64(0), T.int64(10), T.int64(-1)),
+                                )
+                                - T.int64(1),
+                            )
+                            - T.min(
+                                T.max(
+                                    T.Select(
+                                        rxplaceholder_1[v_i] < T.int64(0),
+                                        rxplaceholder_1[v_i]
+                                        + T.Select(
+                                            v_i == T.int64(1),
+                                            n,
+                                            T.Select(
+                                                v_i == T.int64(0), T.int64(10), T.int64(-1)
+                                            ),
+                                        ),
+                                        rxplaceholder_1[v_i],
+                                    ),
+                                    T.int64(-1),
+                                ),
+                                T.Select(
+                                    v_i == T.int64(1),
+                                    n,
+                                    T.Select(v_i == T.int64(0), T.int64(10), T.int64(-1)),
+                                )
+                                - T.int64(1),
+                            )
+                            - rxplaceholder_2[v_i]
+                            - T.int64(1)
+                        )
+                        // (rxplaceholder_2[v_i] * T.int64(-1)),
+                        (
+                            T.min(
+                                T.max(
+                                    T.Select(
+                                        rxplaceholder_1[v_i] < T.int64(0),
+                                        rxplaceholder_1[v_i]
+                                        + T.Select(
+                                            v_i == T.int64(1),
+                                            n,
+                                            T.Select(
+                                                v_i == T.int64(0), T.int64(10), T.int64(-1)
+                                            ),
+                                        ),
+                                        rxplaceholder_1[v_i],
+                                    ),
+                                    T.int64(0),
+                                ),
+                                T.Select(
+                                    v_i == T.int64(1),
+                                    n,
+                                    T.Select(v_i == T.int64(0), T.int64(10), T.int64(-1)),
+                                ),
+                            )
+                            + rxplaceholder_2[v_i]
+                            - T.min(
+                                T.max(
+                                    T.Select(
+                                        rxplaceholder[v_i] < T.int64(0),
+                                        rxplaceholder[v_i]
+                                        + T.Select(
+                                            v_i == T.int64(1),
+                                            n,
+                                            T.Select(
+                                                v_i == T.int64(0), T.int64(10), T.int64(-1)
+                                            ),
+                                        ),
+                                        rxplaceholder[v_i],
+                                    ),
+                                    T.int64(0),
+                                ),
+                                T.Select(
+                                    v_i == T.int64(1),
+                                    n,
+                                    T.Select(v_i == T.int64(0), T.int64(10), T.int64(-1)),
+                                ),
+                            )
+                            - T.int64(1)
+                        )
+                        // rxplaceholder_2[v_i],
+                    )
+
+        @R.function
+        def main(
+            x: R.Tensor((10, "n"), dtype="float32"),
+            begin: R.Tensor((2,), dtype="int64"),
+            end: R.Tensor((2,), dtype="int64"),
+            strides: R.Tensor((2,), dtype="int64"),
+        ) -> R.Tensor(dtype="float32", ndim=2):
+            n = T.int64()
+            s = T.int64()
+            s_1 = T.int64()
+            # Shape tensor -> R.Shape -> symbolic dims (s, s_1), which then type
+            # the output of the slice kernel.
+            gv = R.call_tir(
+                Expected.shape_func,
+                (x, begin, end, strides),
+                out_sinfo=R.Tensor((2,), dtype="int64"),
+            )
+            gv1: R.Shape(ndim=2) = R.call_packed(
+                "vm.builtin.tensor_to_shape", gv, sinfo_args=(R.Shape(ndim=2),)
+            )
+            gv2: R.Shape([s, s_1]) = R.match_cast(gv1, R.Shape([s, s_1]))
+            gv_1 = R.call_tir(
+                Expected.dynamic_strided_slice,
+                (x, begin, end, strides),
+                out_sinfo=R.Tensor((s, s_1), dtype="float32"),
+            )
+            return gv_1
+    # fmt: on
+
+    mod = LegalizeOps()(DynamicStridedSlice)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 ##################### Linear algebra #####################
 
 
diff --git a/tests/python/topi/python/test_topi_transform.py 
b/tests/python/topi/python/test_topi_transform.py
index 5866ffd5f7..ac69a2c85e 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -457,6 +457,51 @@ def verify_dynamic_strided_slice(in_shape, begin, end, 
strides=None):
         check_device(target)
 
 
+def verify_relax_dynamic_strided_slice(in_shape, begin, end, strides, output_shape):
+    """Check topi.dynamic_strided_slice and its shape function against NumPy.
+
+    ``output_shape`` is the caller's hand-computed result shape. It is passed
+    through to ``topi.dynamic_strided_slice``, compared against the output of
+    ``topi.shape_func_dynamic_strided_slice`` (llvm target only), and first
+    cross-checked against the NumPy reference. A wrong expectation is then
+    reported directly instead of being trusted by the shape-function check or
+    surfacing as an unrelated buffer mismatch.
+    """
+    A = te.placeholder(shape=in_shape, name="A")
+    Begin = te.placeholder(shape=[len(in_shape)], name="begin", dtype="int64")
+    End = te.placeholder(shape=[len(in_shape)], name="end", dtype="int64")
+    Strides = te.placeholder(shape=[len(in_shape)], name="strides", dtype="int64")
+
+    B = topi.dynamic_strided_slice(A, Begin, End, Strides, output_shape) + 1
+
+    OutShape = topi.shape_func_dynamic_strided_slice(A, Begin, End, Strides)
+
+    def check_device(target):
+        dev = tvm.device(target, 0)
+        if not tvm.testing.device_enabled(target):
+            print("Skip because %s is not enabled" % target)
+            return
+        print("Running on target: %s" % target)
+        x_np = np.random.uniform(size=in_shape).astype(A.dtype)
+        out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides) + 1
+        # The expected shape is hand-written by the caller; make sure it agrees
+        # with the NumPy reference before it is used to judge the kernels.
+        assert list(out_npy.shape) == list(output_shape), (
+            "expected output_shape %s does not match NumPy reference shape %s"
+            % (list(output_shape), list(out_npy.shape))
+        )
+        data_nd = tvm.nd.array(x_np, dev)
+        out_nd = tvm.nd.empty(out_npy.shape, device=dev, dtype=A.dtype)
+        begin_nd = tvm.nd.array(np.array(begin).astype("int64"), dev)
+        end_nd = tvm.nd.array(np.array(end).astype("int64"), dev)
+        strides_nd = tvm.nd.array(np.array(strides).astype("int64"), dev)
+
+        if target == "llvm":
+            # The shape function is only built and checked for the llvm target.
+            s = tvm.te.create_schedule(OutShape.op)
+            bar = tvm.build(
+                s, [A, Begin, End, Strides, OutShape], target, name="shape_func_stride_slice"
+            )
+            out_shape_nd = tvm.nd.empty((len(out_npy.shape),), device=dev, dtype="int64")
+            bar(data_nd, begin_nd, end_nd, strides_nd, out_shape_nd)
+
+            tvm.testing.assert_allclose(out_shape_nd.numpy(), output_shape)
+
+        with tvm.target.Target(target):
+            s = tvm.topi.testing.get_injective_schedule(target)(B)
+        foo = tvm.build(s, [A, Begin, End, Strides, B], target, name="stride_slice")
+        foo(data_nd, begin_nd, end_nd, strides_nd, out_nd)
+        tvm.testing.assert_allclose(out_nd.numpy(), out_npy)
+
+    for target in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
+        check_device(target)
+
+
 def verify_strided_set(in_shape, v_shape, begin, end, strides=None):
     A = te.placeholder(shape=in_shape, name="A")
     V = te.placeholder(shape=v_shape, name="V")
@@ -859,6 +904,15 @@ def test_dynamic_strided_slice():
     verify_dynamic_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3])
 
 
+@tvm.testing.uses_gpu
+def test_relax_dynamic_strided_slice():
+    """Run verify_relax_dynamic_strided_slice over a table of slicing cases."""
+    # Each row: (input shape, begin, end, strides, expected output shape).
+    cases = [
+        ((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], [3, 1, 2]),
+        ((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], [1, 3, 3]),
+        ((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2], [1, 2, 2]),
+        ((3, 4, 3), [1, 1, 0], [4, 4, 3], [1, 1, 1], [2, 3, 3]),
+        ((3, 4, 3), [0, 2, 0], [1, 2, 3], [1, 1, 1], [1, 0, 3]),
+    ]
+    for in_shape, begin, end, strides, output_shape in cases:
+        verify_relax_dynamic_strided_slice(in_shape, begin, end, strides, output_shape)
+
+
 @tvm.testing.uses_gpu
 def test_strided_set():
     verify_strided_set((3, 4, 3), (3, 2, 2), [0, 3, 0], [4, 1, 4], [1, -1, 2])

Reply via email to