This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch unity in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push: new 7db0b984de [Unity][Op] Dynamic Strided Slice (#14548) 7db0b984de is described below commit 7db0b984de6dd7d9ee7db79681476bc2bb7f26d5 Author: Sunghyun Park <49998730+sun...@users.noreply.github.com> AuthorDate: Wed Apr 12 13:04:04 2023 -0700 [Unity][Op] Dynamic Strided Slice (#14548) * feat: dyn_strided_slice op * feat: shape computation * feat: legalizer for dynamic strided slice * remove whitespace * reflect feedback * fix * fix * remove whitespace --- include/tvm/topi/transform.h | 37 +- python/tvm/relax/op/index.py | 37 ++ python/tvm/relax/transform/legalize_ops/index.py | 68 ++- python/tvm/script/ir_builder/relax/ir.py | 3 +- python/tvm/topi/transform.py | 35 ++ src/relax/op/tensor/index.cc | 78 +++- src/relax/transform/legalize_ops.cc | 10 +- src/topi/transform.cc | 8 + tests/python/relax/test_e2e_op_dynamic.py | 104 +++++ tests/python/relax/test_op_index.py | 197 +++++++++ ..._transform_legalize_ops_index_linear_algebra.py | 489 +++++++++++++++++++++ tests/python/topi/python/test_topi_transform.py | 54 +++ 12 files changed, 1111 insertions(+), 9 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 0490ae7f1e..579dbb5833 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -2021,7 +2021,6 @@ inline Tensor adv_index(const Tensor& data, const Array<Tensor>& indices, for (size_t i = 0; i < broadcast_shape.size(); ++i) { tensor_indices.push_back(iter_var[i]); } - Array<PrimExpr> real_indices; for (size_t i = 0; i < bindices.size(); ++i) { real_indices.push_back(bindices[i](tensor_indices)); @@ -2035,6 +2034,42 @@ inline Tensor adv_index(const Tensor& data, const Array<Tensor>& indices, name, tag); } +namespace relax { +// relax dynamic slice +inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin, + const te::Tensor& end, const te::Tensor& strides, + Array<PrimExpr> output_shape, + std::string name 
= "T_strided_slice_dynamic", + std::string tag = kInjective) { + const size_t num_dynamic_axes = x.ndim(); + ICHECK_EQ(begin.ndim(), 1); + ICHECK_EQ(end.ndim(), 1); + ICHECK_EQ(strides.ndim(), 1); + const auto* len_begin = begin->shape[0].as<IntImmNode>(); + const auto* len_end = end->shape[0].as<IntImmNode>(); + const auto* len_strides = strides->shape[0].as<IntImmNode>(); + ICHECK(len_begin); + ICHECK(len_end); + ICHECK(len_strides); + ICHECK_EQ(len_begin->value, num_dynamic_axes); + ICHECK_EQ(len_end->value, num_dynamic_axes); + ICHECK_EQ(len_strides->value, num_dynamic_axes); + + return te::compute( + output_shape, + [&](const Array<tvm::tir::Var>& indices) { + Array<PrimExpr> real_indices; + for (size_t i = 0; i < num_dynamic_axes; ++i) { + auto ind = make_const(DataType::Int(64), i); + real_indices.push_back(indices[i] * strides(ind) + tvm::min(begin(ind), x->shape[i] - 1)); + } + return x(real_indices); + }, + name, tag); +} + +} // namespace relax + } // namespace topi } // namespace tvm #endif // TVM_TOPI_TRANSFORM_H_ diff --git a/python/tvm/relax/op/index.py b/python/tvm/relax/op/index.py index b9acf11b9f..835c9350b0 100644 --- a/python/tvm/relax/op/index.py +++ b/python/tvm/relax/op/index.py @@ -91,3 +91,40 @@ def strided_slice( same length as `axes`. """ return _ffi_api.strided_slice(x, axes, begin, end, strides) # type: ignore + + +def dynamic_strided_slice( + x: Expr, + begin: Expr, + end: Expr, + strides: Expr, +) -> Expr: + """Dynamic strided slice of a tensor. `begin`, `end`, and `strides` can be computed at runtime. + + Parameters + ---------- + x : Expr + The source tensor to be sliced. + + begin : Expr + The indices to begin with in the slicing, inclusive. + + end : Expr + The indices indicating the end of the slice, exclusive. + + strides : Expr + Specifies the stride values. They can be negative, in which case + the input tensor will be reversed in that particular axis. + Unlike `strided_slice`, this argument is required and must have one entry per axis of `x`.
+ + Returns + ------- + ret : relax.Expr + The sliced result. + + Note + ---- + dynamic_strided_slice requires the inputs `begin`, `end` and `strides` to have the + same length as the rank of the `x` tensor. + """ + return _ffi_api.dynamic_strided_slice(x, begin, end, strides) # type: ignore diff --git a/python/tvm/relax/transform/legalize_ops/index.py b/python/tvm/relax/transform/legalize_ops/index.py index eccccc7c6d..8ee1bed9b9 100644 --- a/python/tvm/relax/transform/legalize_ops/index.py +++ b/python/tvm/relax/transform/legalize_ops/index.py @@ -18,9 +18,10 @@ """Default legalization function for index operators.""" import logging -from tvm import topi, tir +from tvm import topi, tir, te from ...block_builder import BlockBuilder -from ...expr import Call, Expr +from ...expr import Call, Expr, ExternFunc +from ...struct_info import ShapeStructInfo from .common import register_legalize @@ -59,3 +60,66 @@ def _strided_slice(bb: BlockBuilder, call: Call) -> Expr: call.attrs.axes, slice_mode="end", ) + + +@register_legalize("relax.dynamic_strided_slice") +def _dynamic_strided_slice(bb: BlockBuilder, call: Call) -> Expr: + assert len(call.args) == 4 + data, begin, end, strides = call.args + + # 1.
Insert shape function + def shape_func(data, begin, end, strides): + def _compute(i): + def canonicalize_index(index, extent, strides): + begin_range = tir.Select(strides < 0, tir.const(-1, "int64"), tir.const(0, "int64")) + end_range = tir.Select(strides < 0, extent - 1, extent) + index = tir.Select(index < 0, index + extent, index) + return tir.Min(tir.Max(index, begin_range), end_range) + + def get_length(begin, end, strides, length): + begin = canonicalize_index(begin, length, strides) + end = canonicalize_index(end, length, strides) + len1 = tir.ceildiv(begin - end, -strides) + len2 = tir.ceildiv(end - begin, strides) + return tir.Select(strides < 0, len1, len2) + + length = tir.const(-1, "int64") + for idx in range(data.ndim): + length = tir.Select(i == tir.const(idx, "int64"), data.shape[idx], length) + + return get_length(begin[i], end[i], strides[i], length) + + return te.compute((begin.shape[0],), _compute, name="T_shape_func_strided_slice_dynamic") + + output_shape = bb.normalize( + bb.call_te( + shape_func, + data, + begin, + end, + strides, + ) + ) + + # 2. Convert tensor to shape and match cast with new symbolic vars + # Get shape length + ndim = int(output_shape.struct_info.shape[0]) + output_shape = bb.emit( + Call( + ExternFunc("vm.builtin.tensor_to_shape"), + [output_shape], + sinfo_args=[ShapeStructInfo(ndim=ndim)], + ) + ) + output_shape_vars = [tir.Var("s", "int64") for i in range(ndim)] + bb.match_cast(output_shape, ShapeStructInfo(output_shape_vars)) + + # 3. 
Pass the output shape vars to TOPI + return bb.call_te( + topi.dynamic_strided_slice, + call.args[0], + call.args[1], + call.args[2], + call.args[3], + output_shape=output_shape_vars, + ) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index e390658c8f..39327c4b4a 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -111,6 +111,7 @@ from tvm.relax.op import ( shape_of, std, strided_slice, + dynamic_strided_slice, sum, take, variance, @@ -639,7 +640,6 @@ __all__ = [ "ShapeExpr", "std", "str", - "strided_slice", "sum", "sigmoid", "sign", @@ -652,6 +652,7 @@ __all__ = [ "stop_lift_params", "str", "strided_slice", + "dynamic_strided_slice", "subtract", "take", "tan", diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index e4fe3c5839..7807351e90 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -227,6 +227,41 @@ def strided_slice(a, begin, end, strides=None, axes=None, slice_mode="end"): return cpp.strided_slice(a, begin, end, strides, axes, slice_mode) +def dynamic_strided_slice(a, begin, end, strides, output_shape): + """Slice of an array. + + Parameters + ---------- + a : tvm.te.Tensor + The tensor to be sliced. + + begin : tvm.te.Tensor + The indices to begin with in the slicing. + + end : tvm.te.Tensor + Indices indicating end of the slice. + + strides : tvm.te.Tensor + Specifies the stride values, it can be negative + in that case, the input tensor will be reversed + in that particular axis. 
+ + output_shape: list of PrimExpr + Specifies the output shape + + Returns + ------- + ret : tvm.te.Tensor + """ + if not isinstance(begin, tvm.te.Tensor): + begin = const_vector(begin) + if not isinstance(end, tvm.te.Tensor): + end = const_vector(end) + if not isinstance(strides, tvm.te.Tensor): + strides = const_vector(strides) + return cpp.relax_dynamic_strided_slice(a, begin, end, strides, output_shape) + + @tvm.te.tag_scope(tag=tag.INJECTIVE + ",strided_set") def strided_set(a, v, begin, end, strides=None): """Set slice of an array. diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index b627e1425b..c3d38db4e1 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -190,7 +190,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx : Array<PrimExpr>(n_axis, IntImm(DataType::Int(64), 1)); std::vector<int64_t> int_strides; int_strides.reserve(n_axis); - // Only do output shape inference when all the begin/end/stride values are integers. + // Only do output shape inference when all the begin/end/strides values are integers. 
for (int i = 0; i < n_axis; ++i) { const auto* int_begin = attrs->begin[i].as<IntImmNode>(); const auto* int_end = attrs->end[i].as<IntImmNode>(); @@ -204,7 +204,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx Array<PrimExpr> output_shape = data_shape->values; for (int i = 0; i < n_axis; ++i) { ICHECK_NE(int_strides[i], 0) - << "Strided slice requires stride to be non-zero but got 0 for axis " << axes[i] << "."; + << "Strided slice requires strides to be non-zero but got 0 for axis " << axes[i] << "."; output_shape.Set(axes[i], GetLength(attrs->begin[i], attrs->end[i], int_strides[i], data_shape->values[axes[i]])); } @@ -239,5 +239,79 @@ TVM_REGISTER_OP("relax.strided_slice") .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStridedSlice) .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); +/* relax.dynamic_strided_slice */ +Expr dynamic_strided_slice(Expr x, // + Expr begin, // + Expr end, // + Expr strides) { + static const Op& op = Op::Get("relax.dynamic_strided_slice"); + return Call(op, {std::move(x), std::move(begin), std::move(end), std::move(strides)}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.dynamic_strided_slice").set_body_typed(dynamic_strided_slice); + +StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder& ctx) { + const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]); + const auto* begin_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]); + const auto* end_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]); + const auto* strides_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[3]); + + ICHECK(data_sinfo); + if (data_sinfo->IsUnknownNdim()) { + LOG(WARNING) << "When data rank is unknown, dynamic strided slice assumes begin/end/strides " + "tensors are well-formed. 
It could produce runtime error when this assumption " + "turns out to be wrong."; + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + if (data_sinfo->IsUnknownDtype()) { + LOG(WARNING) << "When data type is unknown, dynamic strided slice assumes to have a valid " + "dtype. It could produce runtime error when this assumption " + "turns out to be wrong."; + } + + int n_axis = data_sinfo->ndim; + auto diag_def = [&](const TensorStructInfoNode* sinfo, String name) { + ICHECK(sinfo) << "Dynamic strided slice requires the input " << name + << " to be have the struct info. Please try normalizing the inputs."; + CHECK_EQ(sinfo->ndim, 1) << "Dynamic strided slice requires " << name + << " to be 1d tensor (list of values)."; + const auto* shape = sinfo->shape.as<ShapeExprNode>(); + ICHECK(shape) << "Dynamic strided slice requires the input " << name + << " to have well-defined shape."; + // NOTE(tvm-team): This strong restriction seems necessary for now until we have a generic + // solution in converting 1d Tensor with unknown num_elem to Array<PrimExpr>. + const auto* num_elem = shape->values[0].as<IntImmNode>(); + ICHECK(num_elem) << "Dynamic strided slice requires the input " << name + << " to have a known integer shape value."; + CHECK_EQ(num_elem->value, n_axis) << "Dynamic strided slice requires the number of indices in " + << name << " to equal the number of axes."; + if (sinfo->IsUnknownDtype()) { + LOG(WARNING) << "Dynamic strided slice assumes " << name + << " to be int64 when it is not specified."; + } else { + CHECK(sinfo->dtype == DataType::Int(64)) + << "Dynamic strided_slice expects the input " << name + << "values to be all int64. However, " << name << " has dtype " << sinfo->dtype << "."; + } + }; + diag_def(begin_sinfo, "begin"); + diag_def(end_sinfo, "end"); + diag_def(strides_sinfo, "strides"); + + // The output shape will depend on the runtime value in begin/end/strides tensors. 
+ // TODO(tvm-team): Currently, partially-static shapes cannot be expressed. Revisit when + // PrimValue lands. + return TensorStructInfo(data_sinfo->dtype, n_axis); +} + +// TODO(tvm-team): Register FRelaxInferLayout, TMixedPrecisionPolicy +TVM_REGISTER_OP("relax.dynamic_strided_slice") + .set_num_inputs(4) + .add_argument("x", "Tensor", "The source tensor to be sliced.") + .add_argument("begin", "Tensor", "The indices to begin with in the slicing.") + .add_argument("end", "Tensor", "Indices indicating end of the slice.") + .add_argument("strides", "Tensor", "The stride values.") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoDynStridedSlice); + } // namespace relax } // namespace tvm diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 350a40c37b..7c5393c6ca 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -83,15 +83,19 @@ class LegalizeMutator : public ExprMutator { return visited_call; } + auto op = GetRef<Op>(op_node); + std::string op_name(op->name); + bool is_data_dependent_op = (op_name.find("dynamic") != std::string::npos); // Not all shape values are known + // Data-dependent ops are an exception, since their output shape is only known at runtime. + // The legalizer inserts their manually registered shape functions and a match_cast + // that defines the symbolic output shape at compile time. if (!std::all_of(visited_call->args.begin(), visited_call->args.end(), [](Expr arg) { return KnowAllShapeValues(GetStructInfo(arg)); }) || - !KnowAllShapeValues(GetStructInfo(visited_call))) { + (!is_data_dependent_op && !KnowAllShapeValues(GetStructInfo(visited_call)))) { return visited_call; } - auto op = GetRef<Op>(op_node); - // Priority: customize > default. // Check if it has customize legalization registered.
if (cmap_.defined() && cmap_.value().count(op->name)) { diff --git a/src/topi/transform.cc b/src/topi/transform.cc index bbefa19c20..906f2b9115 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -201,6 +201,14 @@ TVM_REGISTER_GLOBAL("topi.dynamic_strided_slice").set_body([](TVMArgs args, TVMR *rv = dynamic_strided_slice(args[0], begin, end, strides); }); +TVM_REGISTER_GLOBAL("topi.relax_dynamic_strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) { + te::Tensor begin = args[1]; + te::Tensor end = args[2]; + te::Tensor strides = args[3]; + Array<PrimExpr> output_shape = args[4]; + *rv = relax::dynamic_strided_slice(args[0], begin, end, strides, output_shape); +}); + TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) { int depth = args[3]; int axis = args[4]; diff --git a/tests/python/relax/test_e2e_op_dynamic.py b/tests/python/relax/test_e2e_op_dynamic.py new file mode 100644 index 0000000000..1e9414c15d --- /dev/null +++ b/tests/python/relax/test_e2e_op_dynamic.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import numpy as np +import pytest +import tvm +from tvm import relax +import tvm.topi.testing +from tvm.relax.transform import LegalizeOps +from tvm.script import relax as R, tir as T +import tvm.testing + +# TODO(tvm-team): `tir.transform.DefaultGPUSchedule` does not work. +target, dev = "llvm", tvm.cpu() + + +def build(mod): + exe = relax.build(mod, target=target) + return relax.VirtualMachine(exe, dev) + + +@pytest.mark.parametrize( + "begin, end, strides", + [ + ([0, 2, 4, 4], [5, 5, 7, 8], [1, 1, 2, 3]), + ([0, 2, 4, 4], [5, 5, 11, 10], [1, 1, 1, 1]), + ([0, 2, 10, 14], [0, 5, 1, 1], [1, 1, -1, -2]), + ], +) +def test_dynamic_strided_slice(begin, end, strides): + # fmt: off + @tvm.script.ir_module + class DynamicStridedSlice: + @R.function + def main(x: R.Tensor((8, 9, 10, 10), "float32"), begin: R.Tensor((4,),"int64"), end: R.Tensor((4,),"int64"), strides: R.Tensor((4,),"int64")) -> R.Tensor("float32", ndim=4): + gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x, begin, end, strides) + return gv + # fmt: on + mod = LegalizeOps()(DynamicStridedSlice) + with tvm.target.Target(target): + mod = tvm.tir.transform.DefaultGPUSchedule()(mod) + vm = build(mod) + + x_np = np.random.rand(8, 9, 10, 10).astype(np.float32) + data_nd = tvm.nd.array(x_np, dev) + begin_nd = tvm.nd.array(np.array(begin).astype("int64"), dev) + end_nd = tvm.nd.array(np.array(end).astype("int64"), dev) + strides_nd = tvm.nd.array(np.array(strides).astype("int64"), dev) + + # Reference implementation + out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides) + out_nd = vm["main"](data_nd, begin_nd, end_nd, strides_nd) + tvm.testing.assert_allclose(out_nd.numpy(), out_npy) + + +@pytest.mark.parametrize( + "begin, end, strides", + [ + ([0, 2, 4, 4], [5, 5, 7, 8], [1, 1, 2, 3]), + ([0, 2, 4, 4], [5, 5, 11, 10], [1, 1, 1, 1]), + ([0, 2, 10, 14], [0, 5, 1, 1], [1, 1, -1, -2]), + ], +) +def test_dynamic_strided_slice_symbolic(begin, end, strides): + # fmt: off + 
@tvm.script.ir_module + class DynamicStridedSlice: + @R.function + def main(x: R.Tensor(("m", "n", 10, 10), "float32"), begin: R.Tensor((4,),"int64"), end: R.Tensor((4,),"int64"), strides: R.Tensor((4,),"int64")) -> R.Tensor("float32", ndim=4): + m = T.int64() + n = T.int64() + gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x, begin, end, strides) + return gv + # fmt: on + mod = LegalizeOps()(DynamicStridedSlice) + vm = build(mod) + + x_np = np.random.rand(8, 9, 10, 10).astype(np.float32) + data_nd = tvm.nd.array(x_np, dev) + begin_nd = tvm.nd.array(np.array(begin).astype("int64"), dev) + end_nd = tvm.nd.array(np.array(end).astype("int64"), dev) + strides_nd = tvm.nd.array(np.array(strides).astype("int64"), dev) + + # Reference implementation + out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides) + out_nd = vm["main"](data_nd, begin_nd, end_nd, strides_nd) + tvm.testing.assert_allclose(out_nd.numpy(), out_npy) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_index.py b/tests/python/relax/test_op_index.py index 9390a2c9b0..8b2f8c0b29 100644 --- a/tests/python/relax/test_op_index.py +++ b/tests/python/relax/test_op_index.py @@ -30,6 +30,7 @@ def test_op_correctness(): assert relax.op.strided_slice(x, axes=[0], begin=[0], end=[2]).op == Op.get( "relax.strided_slice" ) + assert relax.op.dynamic_strided_slice(x, x, x, x).op == Op.get("relax.dynamic_strided_slice") def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): @@ -673,5 +674,201 @@ def test_strided_slice_infer_struct_info_wrong_input_type(): bb.normalize(relax.op.strided_slice(x1, axes=[0], begin=[0], end=[8])) +def test_dynamic_strided_slice_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((8, 9, 10, 10))) + x4 = 
relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + + b0 = relax.Var("begin", R.Tensor((4,), "int64")) + e0 = relax.Var("end", R.Tensor((4,), "int64")) + s0 = relax.Var("strides", R.Tensor((4,), "int64")) + b1 = relax.Var("begin", R.Tensor((4,))) + e1 = relax.Var("end", R.Tensor((4,))) + s1 = relax.Var("stride", R.Tensor((4,))) + + _check_inference( + bb, + relax.op.dynamic_strided_slice(x0, b0, e0, s0), + R.Tensor("float32", ndim=4), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x1, b0, e0, s0), + R.Tensor("float32", ndim=4), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x2, b0, e0, s0), + R.Tensor("float32", ndim=-1), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x3, b0, e0, s0), + R.Tensor(ndim=4), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x4, b0, e0, s0), + R.Tensor(ndim=4), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x5, b0, e0, s0), + R.Tensor(ndim=-1), + ) + + _check_inference( + bb, + relax.op.dynamic_strided_slice(x0, b1, e1, s1), + R.Tensor("float32", ndim=4), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x1, b1, e1, s1), + R.Tensor("float32", ndim=4), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x2, b1, e1, s1), + R.Tensor("float32", ndim=-1), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x3, b1, e1, s1), + R.Tensor(ndim=4), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x4, b1, e1, s1), + R.Tensor(ndim=4), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x5, b1, e1, s1), + R.Tensor(ndim=-1), + ) + + +def test_dynamic_strided_slice_infer_struct_info_symbolic(): + bb = relax.BlockBuilder() + i = tir.Var("i", "int64") + j = tir.Var("j", "int64") + k = tir.Var("k", "int64") + l = tir.Var("l", "int64") + x0 = relax.Var("x", R.Tensor((i, j, k, l), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", 
R.Tensor((i, j, k, l))) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + + b0 = relax.Var("begin", R.Tensor((4,), "int64")) + e0 = relax.Var("end", R.Tensor((4,), "int64")) + s0 = relax.Var("stride", R.Tensor((4,), "int64")) + b1 = relax.Var("begin", R.Tensor((4,))) + e1 = relax.Var("end", R.Tensor((4,))) + s1 = relax.Var("stride", R.Tensor((4,))) + + _check_inference( + bb, + relax.op.dynamic_strided_slice(x0, b0, e0, s0), + R.Tensor("float32", ndim=4), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x1, b0, e0, s0), + R.Tensor("float32", ndim=4), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x2, b0, e0, s0), + R.Tensor("float32", ndim=-1), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x3, b0, e0, s0), + R.Tensor(ndim=4), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x4, b0, e0, s0), + R.Tensor(ndim=4), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x5, b0, e0, s0), + R.Tensor(ndim=-1), + ) + + _check_inference( + bb, + relax.op.dynamic_strided_slice(x0, b1, e1, s1), + R.Tensor("float32", ndim=4), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x1, b1, e1, s1), + R.Tensor("float32", ndim=4), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x2, b1, e1, s1), + R.Tensor("float32", ndim=-1), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x3, b1, e1, s1), + R.Tensor(ndim=4), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x4, b1, e1, s1), + R.Tensor(ndim=4), + ) + _check_inference( + bb, + relax.op.dynamic_strided_slice(x5, b1, e1, s1), + R.Tensor(ndim=-1), + ) + + +def test_dynamic_strided_slice_infer_struct_info_arg_wrong_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32")) + b0 = relax.Var("begin", R.Tensor((4,), "float32")) + e0 = relax.Var("end", R.Tensor((4,), "float32")) + s0 = relax.Var("stride", R.Tensor((4,), "float32")) + + with 
pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x0, b0, e0, s0)) + + +def test_dynamic_strided_slice_infer_struct_info_arg_wrong_shape_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32")) + m = tir.Var("m", "int64") + # invalid arg + b0 = relax.Var("begin", R.Tensor("int64", ndim=2)) + b1 = relax.Var("begin", R.Tensor((1,), "int64")) + b2 = relax.Var("begin", R.Tensor((2, 2), "int64")) + b3 = relax.Var("begin", R.Tensor((m,), "int64")) + # valid args + e0 = relax.Var("end", R.Tensor((4,), "int64")) + s0 = relax.Var("stride", R.Tensor((4,), "int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x0, b0, e0, s0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x0, b1, e0, s0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x0, b2, e0, s0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x0, b3, e0, s0)) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py index 39224240ef..d7c0b54af2 100644 --- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py +++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py @@ -210,6 +210,495 @@ def test_strided_slice_symbolic(): tvm.ir.assert_structural_equal(mod, Expected) +def test_dynamic_strided_slice(): + # fmt: off + @tvm.script.ir_module + class DynamicStridedSlice: + @R.function + def main(x: R.Tensor((8, 9, 10, 10), "float32"), begin: R.Tensor((4,),"int64"), end: R.Tensor((4,),"int64"), strides: R.Tensor((4,),"int64")) -> R.Tensor("float32", ndim=4): + gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x, begin, end, strides) + return gv + @tvm.script.ir_module + class Expected: + @T.prim_func + def dynamic_strided_slice( + rxplaceholder: T.Buffer( + (T.int64(8), T.int64(9), 
T.int64(10), T.int64(10)), "float32" + ), + rxplaceholder_1: T.Buffer((T.int64(4),), "int64"), + rxplaceholder_2: T.Buffer((T.int64(4),), "int64"), + rxplaceholder_3: T.Buffer((T.int64(4),), "int64"), + var_T_strided_slice_dynamic: T.handle, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + s, s_1, s_2, s_3 = T.int64(), T.int64(), T.int64(), T.int64() + T_strided_slice_dynamic = T.match_buffer( + var_T_strided_slice_dynamic, (s, s_1, s_2, s_3) + ) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(s, s_1, s_2, s_3): + with T.block("T_strided_slice_dynamic"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads( + rxplaceholder[ + T.min(rxplaceholder_1[T.int64(0)], T.int64(7)) + + v_ax0 * rxplaceholder_3[T.int64(0)], + T.min(rxplaceholder_1[T.int64(1)], T.int64(8)) + + v_ax1 * rxplaceholder_3[T.int64(1)], + T.min(rxplaceholder_1[T.int64(2)], T.int64(9)) + + v_ax2 * rxplaceholder_3[T.int64(2)], + T.min(rxplaceholder_1[T.int64(3)], T.int64(9)) + + v_ax3 * rxplaceholder_3[T.int64(3)], + ], + rxplaceholder_1[T.int64(0) : T.int64(4)], + rxplaceholder_3[T.int64(0) : T.int64(4)], + ) + T.writes(T_strided_slice_dynamic[v_ax0, v_ax1, v_ax2, v_ax3]) + T_strided_slice_dynamic[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[ + T.min(rxplaceholder_1[T.int64(0)], T.int64(7)) + + v_ax0 * rxplaceholder_3[T.int64(0)], + T.min(rxplaceholder_1[T.int64(1)], T.int64(8)) + + v_ax1 * rxplaceholder_3[T.int64(1)], + T.min(rxplaceholder_1[T.int64(2)], T.int64(9)) + + v_ax2 * rxplaceholder_3[T.int64(2)], + T.min(rxplaceholder_1[T.int64(3)], T.int64(9)) + + v_ax3 * rxplaceholder_3[T.int64(3)], + ] + + @T.prim_func + def shape_func( + rxplaceholder: T.Buffer( + (T.int64(8), T.int64(9), T.int64(10), T.int64(10)), "float32" + ), + rxplaceholder_1: T.Buffer((T.int64(4),), "int64"), + rxplaceholder_2: T.Buffer((T.int64(4),), "int64"), + rxplaceholder_3: T.Buffer((T.int64(4),), "int64"), + T_shape_func_strided_slice_dynamic: T.Buffer((T.int64(4),), "int64"), + 
): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + for i in range(T.int64(4)): + with T.block("T_shape_func_strided_slice_dynamic"): + v_i = T.axis.spatial(T.int64(4), i) + T.reads( + rxplaceholder_3[v_i], rxplaceholder_1[v_i], rxplaceholder_2[v_i] + ) + T.writes(T_shape_func_strided_slice_dynamic[v_i]) + T_shape_func_strided_slice_dynamic[v_i] = T.Select( + rxplaceholder_3[v_i] < T.int64(0), + ( + T.min( + T.max( + T.Select( + rxplaceholder_1[v_i] < T.int64(0), + rxplaceholder_1[v_i] + + T.Select( + v_i == T.int64(3), + T.int64(10), + T.Select( + v_i == T.int64(2), + T.int64(10), + T.Select( + v_i == T.int64(1), + T.int64(9), + T.Select( + v_i == T.int64(0), + T.int64(8), + T.int64(-1), + ), + ), + ), + ), + rxplaceholder_1[v_i], + ), + T.int64(-1), + ), + T.Select( + v_i == T.int64(3), + T.int64(10), + T.Select( + v_i == T.int64(2), + T.int64(10), + T.Select( + v_i == T.int64(1), + T.int64(9), + T.Select( + v_i == T.int64(0), T.int64(8), T.int64(-1) + ), + ), + ), + ) + - T.int64(1), + ) + - T.min( + T.max( + T.Select( + rxplaceholder_2[v_i] < T.int64(0), + rxplaceholder_2[v_i] + + T.Select( + v_i == T.int64(3), + T.int64(10), + T.Select( + v_i == T.int64(2), + T.int64(10), + T.Select( + v_i == T.int64(1), + T.int64(9), + T.Select( + v_i == T.int64(0), + T.int64(8), + T.int64(-1), + ), + ), + ), + ), + rxplaceholder_2[v_i], + ), + T.int64(-1), + ), + T.Select( + v_i == T.int64(3), + T.int64(10), + T.Select( + v_i == T.int64(2), + T.int64(10), + T.Select( + v_i == T.int64(1), + T.int64(9), + T.Select( + v_i == T.int64(0), T.int64(8), T.int64(-1) + ), + ), + ), + ) + - T.int64(1), + ) + - rxplaceholder_3[v_i] + - T.int64(1) + ) + // (rxplaceholder_3[v_i] * T.int64(-1)), + ( + T.min( + T.max( + T.Select( + rxplaceholder_2[v_i] < T.int64(0), + rxplaceholder_2[v_i] + + T.Select( + v_i == T.int64(3), + T.int64(10), + T.Select( + v_i == T.int64(2), + T.int64(10), + T.Select( + v_i == T.int64(1), + T.int64(9), + T.Select( + v_i == T.int64(0), + 
T.int64(8), + T.int64(-1), + ), + ), + ), + ), + rxplaceholder_2[v_i], + ), + T.int64(0), + ), + T.Select( + v_i == T.int64(3), + T.int64(10), + T.Select( + v_i == T.int64(2), + T.int64(10), + T.Select( + v_i == T.int64(1), + T.int64(9), + T.Select( + v_i == T.int64(0), T.int64(8), T.int64(-1) + ), + ), + ), + ), + ) + + rxplaceholder_3[v_i] + - T.min( + T.max( + T.Select( + rxplaceholder_1[v_i] < T.int64(0), + rxplaceholder_1[v_i] + + T.Select( + v_i == T.int64(3), + T.int64(10), + T.Select( + v_i == T.int64(2), + T.int64(10), + T.Select( + v_i == T.int64(1), + T.int64(9), + T.Select( + v_i == T.int64(0), + T.int64(8), + T.int64(-1), + ), + ), + ), + ), + rxplaceholder_1[v_i], + ), + T.int64(0), + ), + T.Select( + v_i == T.int64(3), + T.int64(10), + T.Select( + v_i == T.int64(2), + T.int64(10), + T.Select( + v_i == T.int64(1), + T.int64(9), + T.Select( + v_i == T.int64(0), T.int64(8), T.int64(-1) + ), + ), + ), + ), + ) + - T.int64(1) + ) + // rxplaceholder_3[v_i], + ) + + @R.function + def main( + x: R.Tensor((8, 9, 10, 10), dtype="float32"), + begin: R.Tensor((4,), dtype="int64"), + end: R.Tensor((4,), dtype="int64"), + strides: R.Tensor((4,), dtype="int64"), + ) -> R.Tensor(dtype="float32", ndim=4): + s = T.int64() + s_1 = T.int64() + s_2 = T.int64() + s_3 = T.int64() + gv = R.call_tir( + Expected.shape_func, + (x, begin, end, strides), + out_sinfo=R.Tensor((4,), dtype="int64"), + ) + gv1: R.Shape(ndim=4) = R.call_packed( + "vm.builtin.tensor_to_shape", gv, sinfo_args=(R.Shape(ndim=4),) + ) + gv2: R.Shape([s, s_1, s_2, s_3]) = R.match_cast( + gv1, R.Shape([s, s_1, s_2, s_3]) + ) + gv_1 = R.call_tir( + Expected.dynamic_strided_slice, + (x, begin, end, strides), + out_sinfo=R.Tensor((s, s_1, s_2, s_3), dtype="float32"), + ) + return gv_1 + # fmt: on + mod = LegalizeOps()(DynamicStridedSlice) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_dynamic_strided_slice_symbolic(): + # fmt: off + @tvm.script.ir_module + class DynamicStridedSlice: + 
@R.function + def main(x: R.Tensor((10, "n"), "float32"), begin:R.Tensor((2,), "int64"), end:R.Tensor((2,), "int64"), strides:R.Tensor((2,), "int64")) -> R.Tensor("float32", ndim=2): + n = T.int64() + gv: R.Tensor("float32", ndim=2) = R.dynamic_strided_slice(x, begin, end, strides) + return gv + @tvm.script.ir_module + class Expected: + @T.prim_func + def dynamic_strided_slice( + var_rxplaceholder: T.handle, + rxplaceholder: T.Buffer((T.int64(2),), "int64"), + rxplaceholder_1: T.Buffer((T.int64(2),), "int64"), + rxplaceholder_2: T.Buffer((T.int64(2),), "int64"), + var_T_strided_slice_dynamic: T.handle, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + rxplaceholder_3 = T.match_buffer(var_rxplaceholder, (T.int64(10), n)) + s, s_1 = T.int64(), T.int64() + T_strided_slice_dynamic = T.match_buffer(var_T_strided_slice_dynamic, (s, s_1)) + # with T.block("root"): + for ax0, ax1 in T.grid(s, s_1): + with T.block("T_strided_slice_dynamic"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads( + rxplaceholder_3[ + T.min(rxplaceholder[T.int64(0)], T.int64(9)) + + v_ax0 * rxplaceholder_2[T.int64(0)], + T.min(rxplaceholder[T.int64(1)], n - T.int64(1)) + + v_ax1 * rxplaceholder_2[T.int64(1)], + ], + rxplaceholder[T.int64(0) : T.int64(2)], + rxplaceholder_2[T.int64(0) : T.int64(2)], + ) + T.writes(T_strided_slice_dynamic[v_ax0, v_ax1]) + T_strided_slice_dynamic[v_ax0, v_ax1] = rxplaceholder_3[ + T.min(rxplaceholder[T.int64(0)], T.int64(9)) + + v_ax0 * rxplaceholder_2[T.int64(0)], + T.min(rxplaceholder[T.int64(1)], n - T.int64(1)) + + v_ax1 * rxplaceholder_2[T.int64(1)], + ] + + @T.prim_func + def shape_func( + var_rxplaceholder: T.handle, + rxplaceholder: T.Buffer((T.int64(2),), "int64"), + rxplaceholder_1: T.Buffer((T.int64(2),), "int64"), + rxplaceholder_2: T.Buffer((T.int64(2),), "int64"), + T_shape_func_strided_slice_dynamic: T.Buffer((T.int64(2),), "int64"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + rxplaceholder_3 = 
T.match_buffer(var_rxplaceholder, (T.int64(10), n)) + # with T.block("root"): + for i in range(T.int64(2)): + with T.block("T_shape_func_strided_slice_dynamic"): + v_i = T.axis.spatial(T.int64(2), i) + T.reads(rxplaceholder_2[v_i], rxplaceholder[v_i], rxplaceholder_1[v_i]) + T.writes(T_shape_func_strided_slice_dynamic[v_i]) + T_shape_func_strided_slice_dynamic[v_i] = T.Select( + rxplaceholder_2[v_i] < T.int64(0), + ( + T.min( + T.max( + T.Select( + rxplaceholder[v_i] < T.int64(0), + rxplaceholder[v_i] + + T.Select( + v_i == T.int64(1), + n, + T.Select( + v_i == T.int64(0), T.int64(10), T.int64(-1) + ), + ), + rxplaceholder[v_i], + ), + T.int64(-1), + ), + T.Select( + v_i == T.int64(1), + n, + T.Select(v_i == T.int64(0), T.int64(10), T.int64(-1)), + ) + - T.int64(1), + ) + - T.min( + T.max( + T.Select( + rxplaceholder_1[v_i] < T.int64(0), + rxplaceholder_1[v_i] + + T.Select( + v_i == T.int64(1), + n, + T.Select( + v_i == T.int64(0), T.int64(10), T.int64(-1) + ), + ), + rxplaceholder_1[v_i], + ), + T.int64(-1), + ), + T.Select( + v_i == T.int64(1), + n, + T.Select(v_i == T.int64(0), T.int64(10), T.int64(-1)), + ) + - T.int64(1), + ) + - rxplaceholder_2[v_i] + - T.int64(1) + ) + // (rxplaceholder_2[v_i] * T.int64(-1)), + ( + T.min( + T.max( + T.Select( + rxplaceholder_1[v_i] < T.int64(0), + rxplaceholder_1[v_i] + + T.Select( + v_i == T.int64(1), + n, + T.Select( + v_i == T.int64(0), T.int64(10), T.int64(-1) + ), + ), + rxplaceholder_1[v_i], + ), + T.int64(0), + ), + T.Select( + v_i == T.int64(1), + n, + T.Select(v_i == T.int64(0), T.int64(10), T.int64(-1)), + ), + ) + + rxplaceholder_2[v_i] + - T.min( + T.max( + T.Select( + rxplaceholder[v_i] < T.int64(0), + rxplaceholder[v_i] + + T.Select( + v_i == T.int64(1), + n, + T.Select( + v_i == T.int64(0), T.int64(10), T.int64(-1) + ), + ), + rxplaceholder[v_i], + ), + T.int64(0), + ), + T.Select( + v_i == T.int64(1), + n, + T.Select(v_i == T.int64(0), T.int64(10), T.int64(-1)), + ), + ) + - T.int64(1) + ) + // 
rxplaceholder_2[v_i], + ) + + @R.function + def main( + x: R.Tensor((10, "n"), dtype="float32"), + begin: R.Tensor((2,), dtype="int64"), + end: R.Tensor((2,), dtype="int64"), + strides: R.Tensor((2,), dtype="int64"), + ) -> R.Tensor(dtype="float32", ndim=2): + n = T.int64() + s = T.int64() + s_1 = T.int64() + gv = R.call_tir( + Expected.shape_func, + (x, begin, end, strides), + out_sinfo=R.Tensor((2,), dtype="int64"), + ) + gv1: R.Shape(ndim=2) = R.call_packed( + "vm.builtin.tensor_to_shape", gv, sinfo_args=(R.Shape(ndim=2),) + ) + gv2: R.Shape([s, s_1]) = R.match_cast(gv1, R.Shape([s, s_1])) + gv_1 = R.call_tir( + Expected.dynamic_strided_slice, + (x, begin, end, strides), + out_sinfo=R.Tensor((s, s_1), dtype="float32"), + ) + return gv_1 + # fmt: on + + mod = LegalizeOps()(DynamicStridedSlice) + tvm.ir.assert_structural_equal(mod, Expected) + + ##################### Linear algebra ##################### diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index 5866ffd5f7..ac69a2c85e 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -457,6 +457,51 @@ def verify_dynamic_strided_slice(in_shape, begin, end, strides=None): check_device(target) +def verify_relax_dynamic_strided_slice(in_shape, begin, end, strides, output_shape): + A = te.placeholder(shape=in_shape, name="A") + Begin = te.placeholder(shape=[len(in_shape)], name="begin", dtype="int64") + End = te.placeholder(shape=[len(in_shape)], name="end", dtype="int64") + Strides = te.placeholder(shape=[len(in_shape)], name="strides", dtype="int64") + + B = topi.dynamic_strided_slice(A, Begin, End, Strides, output_shape) + 1 + + OutShape = topi.shape_func_dynamic_strided_slice(A, Begin, End, Strides) + + def check_device(target): + dev = tvm.device(target, 0) + if not tvm.testing.device_enabled(target): + print("Skip because %s is not enabled" % target) + return + print("Running on target: %s" 
% target) + x_np = np.random.uniform(size=in_shape).astype(A.dtype) + out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides) + 1 + data_nd = tvm.nd.array(x_np, dev) + out_nd = tvm.nd.empty(out_npy.shape, device=dev, dtype=A.dtype) + begin_nd = tvm.nd.array(np.array(begin).astype("int64"), dev) + end_nd = tvm.nd.array(np.array(end).astype("int64"), dev) + strides_nd = tvm.nd.array(np.array(strides).astype("int64"), dev) + + if target == "llvm": + # Check shape func + s = tvm.te.create_schedule(OutShape.op) + bar = tvm.build( + s, [A, Begin, End, Strides, OutShape], target, name="shape_func_stride_slice" + ) + out_shape_nd = tvm.nd.empty((len(out_npy.shape),), device=dev, dtype="int64") + bar(data_nd, begin_nd, end_nd, strides_nd, out_shape_nd) + + tvm.testing.assert_allclose(out_shape_nd.numpy(), output_shape) + + with tvm.target.Target(target): + s = tvm.topi.testing.get_injective_schedule(target)(B) + foo = tvm.build(s, [A, Begin, End, Strides, B], target, name="stride_slice") + foo(data_nd, begin_nd, end_nd, strides_nd, out_nd) + tvm.testing.assert_allclose(out_nd.numpy(), out_npy) + + for target in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]: + check_device(target) + + def verify_strided_set(in_shape, v_shape, begin, end, strides=None): A = te.placeholder(shape=in_shape, name="A") V = te.placeholder(shape=v_shape, name="V") @@ -859,6 +904,15 @@ def test_dynamic_strided_slice(): verify_dynamic_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3]) +@tvm.testing.uses_gpu +def test_relax_dynamic_strided_slice(): + verify_relax_dynamic_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], [3, 1, 2]) + verify_relax_dynamic_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], [1, 3, 3]) + verify_relax_dynamic_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2], [1, 2, 2]) + verify_relax_dynamic_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [1, 1, 1], [2, 3, 3]) + verify_relax_dynamic_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3], [1, 1, 
1], [1, 0, 3]) + + @tvm.testing.uses_gpu def test_strided_set(): verify_strided_set((3, 4, 3), (3, 2, 2), [0, 3, 0], [4, 1, 4], [1, -1, 2])