This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch unity in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push: new 605a61450b [Unity][Op] Support symbolic shape inference for slice op. (#15450) 605a61450b is described below commit 605a61450b770e092da7cd4d50ee52cf6454c107 Author: Siyuan Feng <hzfen...@sjtu.edu.cn> AuthorDate: Wed Aug 2 09:06:15 2023 +0800 [Unity][Op] Support symbolic shape inference for slice op. (#15450) This PR improves two things: 1. Support symbolic `begin` and `end` for slice op. 2. Add a new attribute `assume_inbound` for slice op. If `assume_inbound` is set to True, the slice op will assume the `begin` and `end` are always inbound, which will simplify the shape deduction. --- include/tvm/relax/attrs/index.h | 6 +++ python/tvm/relax/op/index.py | 8 +++- src/relax/op/tensor/index.cc | 36 +++++++++-------- src/relax/op/tensor/index.h | 12 +++--- tests/python/relax/test_op_index.py | 78 +++++++++++++++++++++++++++++++++---- 5 files changed, 108 insertions(+), 32 deletions(-) diff --git a/include/tvm/relax/attrs/index.h b/include/tvm/relax/attrs/index.h index c95395a803..1043fe30ce 100644 --- a/include/tvm/relax/attrs/index.h +++ b/include/tvm/relax/attrs/index.h @@ -44,6 +44,7 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> { Array<PrimExpr> begin; Array<PrimExpr> end; Optional<Array<PrimExpr>> strides; + bool assume_inbound; TVM_DECLARE_ATTRS(StridedSliceAttrs, "relax.attrs.StridedSliceAttrs") { TVM_ATTR_FIELD(axes).describe("Axes along which slicing is applied."); @@ -53,6 +54,11 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> { "Specifies the stride values, it can be negative in that case, the input tensor will be " "reversed in that particular axis. If not specified, it by default is an list of ones of " "the same length as `axes`."); + TVM_ATTR_FIELD(assume_inbound) + .set_default(true) + .describe( + "Whether to assume the indices are in bound. If it is set to false, " + "out of bound indices will be clipped to the bound."); } }; // struct StridedSliceAttrs diff --git a/python/tvm/relax/op/index.py b/python/tvm/relax/op/index.py index 835c9350b0..8504b4d683 100644 --- a/python/tvm/relax/op/index.py +++ b/python/tvm/relax/op/index.py @@ -58,6 +58,7 @@ def strided_slice( begin: List[PrimExprLike], end: List[PrimExprLike], strides: Optional[List[PrimExprLike]] = None, + assume_inbound: bool = False, ) -> Expr: """Strided slice of a tensor. @@ -80,6 +81,9 @@ def strided_slice( the input tensor will be reversed in that particular axis. If not specified, it by default is an list of ones of the same length as `axes`. + assume_inbound : bool + Whether to assume the indices are in bound. If it is set to false, + out of bound indices will be clipped to the bound. Returns ------- ret : relax.Expr @@ -90,7 +94,7 @@ def strided_slice( strided_slice require the input `begin`, `end` and `strides` to have the same length as `axes`. """ - return _ffi_api.strided_slice(x, axes, begin, end, strides) # type: ignore + return _ffi_api.strided_slice(x, axes, begin, end, strides, assume_inbound) # type: ignore def dynamic_strided_slice( @@ -99,7 +103,7 @@ def dynamic_strided_slice( end: Expr, strides: Expr, ) -> Expr: - """Dynamic strided slice of a tensor. `begin`, `end`, `strids` can be computed at runtime. + """Dynamic strided slice of a tensor. `begin`, `end`, `strides` can be computed at runtime. Parameters ---------- diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index a9c61bb56a..5f1d5149b3 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -101,11 +101,12 @@ TVM_REGISTER_OP("relax.take") /* relax.strided_slice */ TVM_REGISTER_NODE_TYPE(StridedSliceAttrs); -Expr strided_slice(Expr x, // - Array<Integer> axes, // - Array<PrimExpr> begin, // - Array<PrimExpr> end, // - Optional<Array<PrimExpr>> strides) { +Expr strided_slice(Expr x, // + Array<Integer> axes, // + Array<PrimExpr> begin, // + Array<PrimExpr> end, // + Optional<Array<PrimExpr>> strides, // + bool assume_inbound) { int n_axis = axes.size(); CHECK_EQ(static_cast<int>(begin.size()), n_axis) << "StridedSlice requires the number of begin indices to equal the number of axes."; @@ -141,6 +142,7 @@ Expr strided_slice(Expr x, // attrs->begin = begin.Map(f_convert_to_int64); attrs->end = end.Map(f_convert_to_int64); attrs->strides = strides.defined() ? strides.value().Map(f_convert_to_int64) : strides; + attrs->assume_inbound = assume_inbound; static const Op& op = Op::Get("relax.strided_slice"); return Call(op, {std::move(x)}, Attrs(attrs), {}); @@ -148,23 +150,25 @@ Expr strided_slice(Expr x, // TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice); -inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, int64_t stride) { +inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, int64_t stride, + bool assume_inbound) { // Same as topi strided slice CanonicalizeIndex function in // include/tvm/topi/detail/strided_slice.h PrimExpr begin_range = stride < 0 ? -1 : 0; PrimExpr end_range = stride < 0 ? extent - 1 : extent; index = if_then_else(index < 0, index + extent, index); - return min(max(index, begin_range), end_range); // NOLINT + return assume_inbound ? index : min(max(index, begin_range), end_range); // NOLINT } -PrimExpr GetLength(PrimExpr begin, PrimExpr end, const int64_t stride, const PrimExpr& length) { - begin = CanonicalizeIndex(begin, length, stride); - end = CanonicalizeIndex(end, length, stride); - +PrimExpr GetLength(PrimExpr begin, PrimExpr end, const int64_t stride, const PrimExpr& length, + bool assume_inbound) { + begin = CanonicalizeIndex(begin, length, stride, assume_inbound); + end = CanonicalizeIndex(end, length, stride, assume_inbound); + arith::Analyzer ana; if (stride < 0) { - return ceildiv(begin - end, IntImm(DataType::Int(64), -stride)); + return ana.Simplify(ceildiv(begin - end, IntImm(DataType::Int(64), -stride))); } else { - return ceildiv(end - begin, IntImm(DataType::Int(64), stride)); + return ana.Simplify(ceildiv(end - begin, IntImm(DataType::Int(64), stride))); } } @@ -193,10 +197,8 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx int_strides.reserve(n_axis); // Only do output shape inference when all the begin/end/strides values are integers. for (int i = 0; i < n_axis; ++i) { - const auto* int_begin = attrs->begin[i].as<IntImmNode>(); - const auto* int_end = attrs->end[i].as<IntImmNode>(); const auto* int_stride = strides[i].as<IntImmNode>(); - if (!int_begin || !int_end || !int_stride) { + if (!int_stride) { return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim); } int_strides.push_back(int_stride->value); @@ -207,7 +209,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx ICHECK_NE(int_strides[i], 0) << "Strided slice requires strides to be non-zero but got 0 for axis " << axes[i] << "."; output_shape.Set(axes[i], GetLength(attrs->begin[i], attrs->end[i], int_strides[i], - data_shape->values[axes[i]])); + data_shape->values[axes[i]], attrs->assume_inbound)); } return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype); } diff --git a/src/relax/op/tensor/index.h b/src/relax/op/tensor/index.h index 6944493a0f..c8c7428f48 100644 --- a/src/relax/op/tensor/index.h +++ b/src/relax/op/tensor/index.h @@ -51,13 +51,15 @@ Expr take(Expr x, Expr indices, Optional<Integer> axis); * \param strides Specifies the stride values, it can be negative in that case, * the input tensor will be reversed in that particular axis. * If it is `NullOpt`, it by default is an list of ones of the same length as `axes`. + * \param assume_inbound Whether to assume the indices are in bound. * \return The sliced result */ -Expr strided_slice(Expr x, // - Array<Integer> axes, // - Array<PrimExpr> begin, // - Array<PrimExpr> end, // - Optional<Array<PrimExpr>> strides); +Expr strided_slice(Expr x, // + Array<Integer> axes, // + Array<PrimExpr> begin, // + Array<PrimExpr> end, // + Optional<Array<PrimExpr>> strides, // + bool assume_inbound = false); } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_op_index.py b/tests/python/relax/test_op_index.py index 8b2f8c0b29..cc09b266f5 100644 --- a/tests/python/relax/test_op_index.py +++ b/tests/python/relax/test_op_index.py @@ -461,22 +461,22 @@ def test_strided_slice_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.strided_slice(x0, axes=[0], begin=[1], end=[3]), - relax.TensorStructInfo((tir.min(3, m) - tir.min(1, m) + 1 - 1, n), "float32"), + relax.TensorStructInfo((tir.min(3, m) - tir.min(1, m), n), "float32"), ) _check_inference( bb, relax.op.strided_slice(x0, axes=[0], begin=[1], end=[8], strides=[3]), - relax.TensorStructInfo(((tir.min(8, m) - tir.min(1, m) + 3 - 1) // 3, n), "float32"), + relax.TensorStructInfo(((tir.min(8, m) + 2 - tir.min(1, m)) // 3, n), "float32"), ) _check_inference( bb, relax.op.strided_slice(x1, axes=[0], begin=[1], end=[3]), - relax.TensorStructInfo((tir.min(3, m) - tir.min(1, m) + 1 - 1, n), dtype=""), + relax.TensorStructInfo((tir.min(3, m) - tir.min(1, m), n), dtype=""), ) _check_inference( bb, relax.op.strided_slice(x1, axes=[0], begin=[1], end=[8], strides=[3]), - relax.TensorStructInfo(((tir.min(8, m) - tir.min(1, m) + 3 - 1) // 3, n), dtype=""), + relax.TensorStructInfo(((tir.min(8, m) + 2 - tir.min(1, m)) // 3, n), dtype=""), ) @@ -549,22 +549,84 @@ def test_strided_slice_infer_struct_info_more_input_dtype(): def test_strided_slice_infer_struct_info_symbolic_begin_end_strides(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") + var = tir.Var("var", "int64") + size_var = tir.SizeVar("size_var", "int64") x = relax.Var("x", R.Tensor((8, 9), "float32")) _check_inference( bb, - relax.op.strided_slice(x, axes=[0], begin=[a], end=[8]), + relax.op.strided_slice(x, axes=[0], begin=[var], end=[8]), + relax.TensorStructInfo( + (tir.max(8 - tir.max(tir.if_then_else(var < 0, var + 8, var), 0), 0), 9), + dtype="float32", + ), + ) + _check_inference( + bb, + relax.op.strided_slice(x, axes=[0], begin=[size_var], end=[8]), + relax.TensorStructInfo((tir.max(8 - size_var, 0), 9), dtype="float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x, axes=[0], begin=[0], end=[var]), + relax.TensorStructInfo( + (tir.min(tir.max(tir.if_then_else(var < 0, var + 8, var), 0), 8), 9), dtype="float32" + ), + ) + _check_inference( + bb, + relax.op.strided_slice(x, axes=[0], begin=[0], end=[size_var]), + relax.TensorStructInfo((tir.min(size_var, 8), 9), dtype="float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[var]), + relax.TensorStructInfo(dtype="float32", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[size_var]), relax.TensorStructInfo(dtype="float32", ndim=2), ) + + +def test_strided_slice_infer_struct_info_symbolic_begin_end_strides_inbound(): + bb = relax.BlockBuilder() + var = tir.Var("var", "int64") + size_var = tir.SizeVar("size_var", "int64") + x = relax.Var("x", R.Tensor((8, 9), "float32")) + + _check_inference( + bb, + relax.op.strided_slice(x, axes=[0], begin=[var], end=[8], assume_inbound=True), + relax.TensorStructInfo( + (8 - tir.if_then_else(var < 0, var + 8, var), 9), + dtype="float32", + ), + ) + _check_inference( + bb, + relax.op.strided_slice(x, axes=[0], begin=[size_var], end=[8], assume_inbound=True), + relax.TensorStructInfo((8 - size_var, 9), dtype="float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x, axes=[0], begin=[0], end=[var], assume_inbound=True), + relax.TensorStructInfo((tir.if_then_else(var < 0, var + 8, var), 9), dtype="float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x, axes=[0], begin=[0], end=[size_var], assume_inbound=True), + relax.TensorStructInfo((size_var, 9), dtype="float32"), + ) _check_inference( bb, - relax.op.strided_slice(x, axes=[0], begin=[0], end=[a]), + relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[var], assume_inbound=True), relax.TensorStructInfo(dtype="float32", ndim=2), ) _check_inference( bb, - relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[a]), + relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[var], assume_inbound=True), relax.TensorStructInfo(dtype="float32", ndim=2), )