This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 605a61450b [Unity][Op] Support symbolic shape inference for slice op. (#15450)
605a61450b is described below

commit 605a61450b770e092da7cd4d50ee52cf6454c107
Author: Siyuan Feng <hzfen...@sjtu.edu.cn>
AuthorDate: Wed Aug 2 09:06:15 2023 +0800

    [Unity][Op] Support symbolic shape inference for slice op. (#15450)
    
    This PR makes two improvements:
    1. Support symbolic `begin` and `end` values for the slice op.
    2. Add a new attribute `assume_inbound` to the slice op. If `assume_inbound`
    is set to True, the slice op assumes that `begin` and `end` are always
    in bounds, which simplifies shape deduction. If it is False, out-of-bound
    indices are clipped to the valid range. The `strided_slice` builder
    functions default it to False.
---
 include/tvm/relax/attrs/index.h     |  6 +++
 python/tvm/relax/op/index.py        |  8 +++-
 src/relax/op/tensor/index.cc        | 36 +++++++++--------
 src/relax/op/tensor/index.h         | 12 +++---
 tests/python/relax/test_op_index.py | 78 +++++++++++++++++++++++++++++++++----
 5 files changed, 108 insertions(+), 32 deletions(-)

diff --git a/include/tvm/relax/attrs/index.h b/include/tvm/relax/attrs/index.h
index c95395a803..1043fe30ce 100644
--- a/include/tvm/relax/attrs/index.h
+++ b/include/tvm/relax/attrs/index.h
@@ -44,6 +44,7 @@ struct StridedSliceAttrs : public 
tvm::AttrsNode<StridedSliceAttrs> {
   Array<PrimExpr> begin;
   Array<PrimExpr> end;
   Optional<Array<PrimExpr>> strides;
+  bool assume_inbound;
 
   TVM_DECLARE_ATTRS(StridedSliceAttrs, "relax.attrs.StridedSliceAttrs") {
     TVM_ATTR_FIELD(axes).describe("Axes along which slicing is applied.");
@@ -53,6 +54,11 @@ struct StridedSliceAttrs : public 
tvm::AttrsNode<StridedSliceAttrs> {
         "Specifies the stride values, it can be negative in that case, the 
input tensor will be "
         "reversed in that particular axis. If not specified, it by default is 
an list of ones of "
         "the same length as `axes`.");
+    TVM_ATTR_FIELD(assume_inbound)
+        .set_default(true)
+        .describe(
+            "Whether to assume the indices are in bound. If it is set to 
false, "
+            "out of bound indices will be clipped to the bound.");
   }
 };  // struct StridedSliceAttrs
 
diff --git a/python/tvm/relax/op/index.py b/python/tvm/relax/op/index.py
index 835c9350b0..8504b4d683 100644
--- a/python/tvm/relax/op/index.py
+++ b/python/tvm/relax/op/index.py
@@ -58,6 +58,7 @@ def strided_slice(
     begin: List[PrimExprLike],
     end: List[PrimExprLike],
     strides: Optional[List[PrimExprLike]] = None,
+    assume_inbound: bool = False,
 ) -> Expr:
     """Strided slice of a tensor.
 
@@ -80,6 +81,9 @@ def strided_slice(
         the input tensor will be reversed in that particular axis.
         If not specified, it by default is an list of ones of the same length 
as `axes`.
 
+    assume_inbound : bool
+        Whether to assume the indices are in bound. If it is set to false,
+        out of bound indices will be clipped to the bound.
     Returns
     -------
     ret : relax.Expr
@@ -90,7 +94,7 @@ def strided_slice(
     strided_slice require the input `begin`, `end` and `strides` to have the
     same length as `axes`.
     """
-    return _ffi_api.strided_slice(x, axes, begin, end, strides)  # type: ignore
+    return _ffi_api.strided_slice(x, axes, begin, end, strides, 
assume_inbound)  # type: ignore
 
 
 def dynamic_strided_slice(
@@ -99,7 +103,7 @@ def dynamic_strided_slice(
     end: Expr,
     strides: Expr,
 ) -> Expr:
-    """Dynamic strided slice of a tensor. `begin`, `end`, `strids` can be 
computed at runtime.
+    """Dynamic strided slice of a tensor. `begin`, `end`, `strides` can be 
computed at runtime.
 
     Parameters
     ----------
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index a9c61bb56a..5f1d5149b3 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -101,11 +101,12 @@ TVM_REGISTER_OP("relax.take")
 /* relax.strided_slice */
 TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);
 
-Expr strided_slice(Expr x,                 //
-                   Array<Integer> axes,    //
-                   Array<PrimExpr> begin,  //
-                   Array<PrimExpr> end,    //
-                   Optional<Array<PrimExpr>> strides) {
+Expr strided_slice(Expr x,                             //
+                   Array<Integer> axes,                //
+                   Array<PrimExpr> begin,              //
+                   Array<PrimExpr> end,                //
+                   Optional<Array<PrimExpr>> strides,  //
+                   bool assume_inbound) {
   int n_axis = axes.size();
   CHECK_EQ(static_cast<int>(begin.size()), n_axis)
       << "StridedSlice requires the number of begin indices to equal the 
number of axes.";
@@ -141,6 +142,7 @@ Expr strided_slice(Expr x,                 //
   attrs->begin = begin.Map(f_convert_to_int64);
   attrs->end = end.Map(f_convert_to_int64);
   attrs->strides = strides.defined() ? strides.value().Map(f_convert_to_int64) 
: strides;
+  attrs->assume_inbound = assume_inbound;
 
   static const Op& op = Op::Get("relax.strided_slice");
   return Call(op, {std::move(x)}, Attrs(attrs), {});
@@ -148,23 +150,25 @@ Expr strided_slice(Expr x,                 //
 
 TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice);
 
-inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, int64_t 
stride) {
+inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, int64_t 
stride,
+                                  bool assume_inbound) {
   // Same as topi strided slice CanonicalizeIndex function in
   // include/tvm/topi/detail/strided_slice.h
   PrimExpr begin_range = stride < 0 ? -1 : 0;
   PrimExpr end_range = stride < 0 ? extent - 1 : extent;
   index = if_then_else(index < 0, index + extent, index);
-  return min(max(index, begin_range), end_range);  // NOLINT
+  return assume_inbound ? index : min(max(index, begin_range), end_range);  // 
NOLINT
 }
 
-PrimExpr GetLength(PrimExpr begin, PrimExpr end, const int64_t stride, const 
PrimExpr& length) {
-  begin = CanonicalizeIndex(begin, length, stride);
-  end = CanonicalizeIndex(end, length, stride);
-
+PrimExpr GetLength(PrimExpr begin, PrimExpr end, const int64_t stride, const 
PrimExpr& length,
+                   bool assume_inbound) {
+  begin = CanonicalizeIndex(begin, length, stride, assume_inbound);
+  end = CanonicalizeIndex(end, length, stride, assume_inbound);
+  arith::Analyzer ana;
   if (stride < 0) {
-    return ceildiv(begin - end, IntImm(DataType::Int(64), -stride));
+    return ana.Simplify(ceildiv(begin - end, IntImm(DataType::Int(64), 
-stride)));
   } else {
-    return ceildiv(end - begin, IntImm(DataType::Int(64), stride));
+    return ana.Simplify(ceildiv(end - begin, IntImm(DataType::Int(64), 
stride)));
   }
 }
 
@@ -193,10 +197,8 @@ StructInfo InferStructInfoStridedSlice(const Call& call, 
const BlockBuilder& ctx
   int_strides.reserve(n_axis);
   // Only do output shape inference when all the begin/end/strides values are 
integers.
   for (int i = 0; i < n_axis; ++i) {
-    const auto* int_begin = attrs->begin[i].as<IntImmNode>();
-    const auto* int_end = attrs->end[i].as<IntImmNode>();
     const auto* int_stride = strides[i].as<IntImmNode>();
-    if (!int_begin || !int_end || !int_stride) {
+    if (!int_stride) {
       return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
     }
     int_strides.push_back(int_stride->value);
@@ -207,7 +209,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, 
const BlockBuilder& ctx
     ICHECK_NE(int_strides[i], 0)
         << "Strided slice requires strides to be non-zero but got 0 for axis " 
<< axes[i] << ".";
     output_shape.Set(axes[i], GetLength(attrs->begin[i], attrs->end[i], 
int_strides[i],
-                                        data_shape->values[axes[i]]));
+                                        data_shape->values[axes[i]], 
attrs->assume_inbound));
   }
   return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
 }
diff --git a/src/relax/op/tensor/index.h b/src/relax/op/tensor/index.h
index 6944493a0f..c8c7428f48 100644
--- a/src/relax/op/tensor/index.h
+++ b/src/relax/op/tensor/index.h
@@ -51,13 +51,15 @@ Expr take(Expr x, Expr indices, Optional<Integer> axis);
  * \param strides Specifies the stride values, it can be negative in that case,
  * the input tensor will be reversed in that particular axis.
  * If it is `NullOpt`, it by default is an list of ones of the same length as 
`axes`.
+ * \param assume_inbound Whether to assume the indices are in bound.
  * \return The sliced result
  */
-Expr strided_slice(Expr x,                 //
-                   Array<Integer> axes,    //
-                   Array<PrimExpr> begin,  //
-                   Array<PrimExpr> end,    //
-                   Optional<Array<PrimExpr>> strides);
+Expr strided_slice(Expr x,                             //
+                   Array<Integer> axes,                //
+                   Array<PrimExpr> begin,              //
+                   Array<PrimExpr> end,                //
+                   Optional<Array<PrimExpr>> strides,  //
+                   bool assume_inbound = false);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/tests/python/relax/test_op_index.py 
b/tests/python/relax/test_op_index.py
index 8b2f8c0b29..cc09b266f5 100644
--- a/tests/python/relax/test_op_index.py
+++ b/tests/python/relax/test_op_index.py
@@ -461,22 +461,22 @@ def test_strided_slice_infer_struct_info_shape_symbolic():
     _check_inference(
         bb,
         relax.op.strided_slice(x0, axes=[0], begin=[1], end=[3]),
-        relax.TensorStructInfo((tir.min(3, m) - tir.min(1, m) + 1 - 1, n), 
"float32"),
+        relax.TensorStructInfo((tir.min(3, m) - tir.min(1, m), n), "float32"),
     )
     _check_inference(
         bb,
         relax.op.strided_slice(x0, axes=[0], begin=[1], end=[8], strides=[3]),
-        relax.TensorStructInfo(((tir.min(8, m) - tir.min(1, m) + 3 - 1) // 3, 
n), "float32"),
+        relax.TensorStructInfo(((tir.min(8, m) + 2 - tir.min(1, m)) // 3, n), 
"float32"),
     )
     _check_inference(
         bb,
         relax.op.strided_slice(x1, axes=[0], begin=[1], end=[3]),
-        relax.TensorStructInfo((tir.min(3, m) - tir.min(1, m) + 1 - 1, n), 
dtype=""),
+        relax.TensorStructInfo((tir.min(3, m) - tir.min(1, m), n), dtype=""),
     )
     _check_inference(
         bb,
         relax.op.strided_slice(x1, axes=[0], begin=[1], end=[8], strides=[3]),
-        relax.TensorStructInfo(((tir.min(8, m) - tir.min(1, m) + 3 - 1) // 3, 
n), dtype=""),
+        relax.TensorStructInfo(((tir.min(8, m) + 2 - tir.min(1, m)) // 3, n), 
dtype=""),
     )
 
 
@@ -549,22 +549,84 @@ def 
test_strided_slice_infer_struct_info_more_input_dtype():
 
 def test_strided_slice_infer_struct_info_symbolic_begin_end_strides():
     bb = relax.BlockBuilder()
-    a = tir.Var("a", "int64")
+    var = tir.Var("var", "int64")
+    size_var = tir.SizeVar("size_var", "int64")
     x = relax.Var("x", R.Tensor((8, 9), "float32"))
 
     _check_inference(
         bb,
-        relax.op.strided_slice(x, axes=[0], begin=[a], end=[8]),
+        relax.op.strided_slice(x, axes=[0], begin=[var], end=[8]),
+        relax.TensorStructInfo(
+            (tir.max(8 - tir.max(tir.if_then_else(var < 0, var + 8, var), 0), 
0), 9),
+            dtype="float32",
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.strided_slice(x, axes=[0], begin=[size_var], end=[8]),
+        relax.TensorStructInfo((tir.max(8 - size_var, 0), 9), dtype="float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.strided_slice(x, axes=[0], begin=[0], end=[var]),
+        relax.TensorStructInfo(
+            (tir.min(tir.max(tir.if_then_else(var < 0, var + 8, var), 0), 8), 
9), dtype="float32"
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.strided_slice(x, axes=[0], begin=[0], end=[size_var]),
+        relax.TensorStructInfo((tir.min(size_var, 8), 9), dtype="float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[var]),
+        relax.TensorStructInfo(dtype="float32", ndim=2),
+    )
+    _check_inference(
+        bb,
+        relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], 
strides=[size_var]),
         relax.TensorStructInfo(dtype="float32", ndim=2),
     )
+
+
+def test_strided_slice_infer_struct_info_symbolic_begin_end_strides_inbound():
+    bb = relax.BlockBuilder()
+    var = tir.Var("var", "int64")
+    size_var = tir.SizeVar("size_var", "int64")
+    x = relax.Var("x", R.Tensor((8, 9), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.strided_slice(x, axes=[0], begin=[var], end=[8], 
assume_inbound=True),
+        relax.TensorStructInfo(
+            (8 - tir.if_then_else(var < 0, var + 8, var), 9),
+            dtype="float32",
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.strided_slice(x, axes=[0], begin=[size_var], end=[8], 
assume_inbound=True),
+        relax.TensorStructInfo((8 - size_var, 9), dtype="float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.strided_slice(x, axes=[0], begin=[0], end=[var], 
assume_inbound=True),
+        relax.TensorStructInfo((tir.if_then_else(var < 0, var + 8, var), 9), 
dtype="float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.strided_slice(x, axes=[0], begin=[0], end=[size_var], 
assume_inbound=True),
+        relax.TensorStructInfo((size_var, 9), dtype="float32"),
+    )
     _check_inference(
         bb,
-        relax.op.strided_slice(x, axes=[0], begin=[0], end=[a]),
+        relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[var], 
assume_inbound=True),
         relax.TensorStructInfo(dtype="float32", ndim=2),
     )
     _check_inference(
         bb,
-        relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[a]),
+        relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[var], 
assume_inbound=True),
         relax.TensorStructInfo(dtype="float32", ndim=2),
     )
 

Reply via email to