masahi commented on code in PR #14548:
URL: https://github.com/apache/tvm/pull/14548#discussion_r1161648422
##########
src/relax/op/tensor/index.cc:
##########
@@ -239,5 +239,78 @@ TVM_REGISTER_OP("relax.strided_slice")
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStridedSlice)
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
+/* relax.dynamic_strided_slice */
+Expr dynamic_strided_slice(Expr x, //
+ Expr begin, //
+ Expr end, //
+ Expr strides) {
+ static const Op& op = Op::Get("relax.dynamic_strided_slice");
+ return Call(op, {std::move(x), std::move(begin), std::move(end),
std::move(strides)}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.dynamic_strided_slice").set_body_typed(dynamic_strided_slice);
+
+StructInfo InferStructInfoDynStridedSlice(const Call& call, const
BlockBuilder& ctx) {
+ const auto* data_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ const auto* begin_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+ const auto* end_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
+ const auto* strides_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[3]);
+
+ ICHECK(data_sinfo);
+ if (data_sinfo->IsUnknownNdim()) {
+ LOG(WARNING) << "When data rank is unknown, dynamic strided slice assumes
begin/end/stride "
+ "tensors are well-formed. It could produce runtime error
when this assumption "
+ "turns out to be wrong.";
+ return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
+ }
+ if (data_sinfo->IsUnknownDtype()) {
+ LOG(WARNING) << "When data type is unknown, dynamic strided slice assumes
to have a valid "
+ "dtype. It could produce runtime error when this
assumption "
+ "turns out to be wrong.";
+ }
+
+ int n_axis = data_sinfo->ndim;
+ auto diag_def = [&](const TensorStructInfoNode* sinfo, String name) {
+ ICHECK(sinfo) << "Dynamic strided slice requires the input " << name
+ << " to be have the struct info. Please try normalizing the
inputs.";
+ CHECK_EQ(sinfo->ndim, 1) << "Dynamic strided slice requires " << name
+ << " to be 1d tensor (list of values).";
+ const auto* shape = sinfo->shape.as<ShapeExprNode>();
+ ICHECK(shape) << "Dynamic strided slice requires the input " << name
+ << " to have well-defined shape.";
+ // NOTE(tvm-team): This strong restriction seems necessary for now until
we have a generic
+ // solution in converting 1d Tensor with unknown num_elem to
Array<PrimExpr>.
+ const auto* num_elem = shape->values[0].as<IntImmNode>();
+ ICHECK(num_elem) << "Dynamic strided slice requires the input " << name
+ << " to have a known integer shape value.";
+ CHECK_EQ(num_elem->value, n_axis) << "Dynamic strided slice requires the
number of indices in "
+ << name << " to equal the number of
axes.";
+ if (sinfo->IsUnknownDtype()) {
+ LOG(WARNING) << "Dynamic strided slice assumes " << name
+ << " to be int64 when it is not specified.";
+ } else {
+ CHECK(sinfo->dtype == DataType::Int(64))
+ << "Dynamic strided_slice expects the input " << name
+ << "values to be all int64. However, " << name << " has dtype " <<
sinfo->dtype << ".";
+ }
+ };
+ diag_def(begin_sinfo, "begin");
+ diag_def(end_sinfo, "end");
+ diag_def(strides_sinfo, "stride");
+
+ // The output shape will depend on the runtime value in begin/end/stride
tensors.
+ // TODO(tvm-team): Extract more compile-time info when those tensors are
constants.
+ return TensorStructInfo(data_sinfo->dtype, n_axis);
Review Comment:
I'm a bit concerned about this. It seems we can only express slicing with
either "all static" or "all dynamic" axes. But partially static / dynamic
slicing is very common in practice (e.g., slicing only along the dynamic
"batch" dimension — the number of detected boxes — in object detection
models). If the shape func needs to be implemented via topi, I don't see a way
to express such a partially-static output shape.
For Relay I added a hacky workaround https://github.com/apache/tvm/pull/8165
for this issue.
##########
python/tvm/relax/transform/legalize_ops/index.py:
##########
@@ -59,3 +60,39 @@ def _strided_slice(bb: BlockBuilder, call: Call) -> Expr:
call.attrs.axes,
slice_mode="end",
)
+
+
+@register_legalize("relax.dynamic_strided_slice")
+def _dynamic_strided_slice(bb: BlockBuilder, call: Call) -> Expr:
+ # 1. Insert shape function
+ output_shape = bb.normalize(
+ bb.call_te(
+ topi.shape_func_dynamic_strided_slice,
Review Comment:
Could there be an alternative way to implement shape funcs? I ask because a
TE tensor cannot express "constant"-ness.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]