masahi commented on code in PR #14548:
URL: https://github.com/apache/tvm/pull/14548#discussion_r1162238558
##########
src/relax/op/tensor/index.cc:
##########
@@ -239,5 +239,78 @@ TVM_REGISTER_OP("relax.strided_slice")
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStridedSlice)
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
+/* relax.dynamic_strided_slice */
+Expr dynamic_strided_slice(Expr x,      //
+                           Expr begin,  //
+                           Expr end,    //
+                           Expr strides) {
+  static const Op& op = Op::Get("relax.dynamic_strided_slice");
+  return Call(op, {std::move(x), std::move(begin), std::move(end), std::move(strides)}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.dynamic_strided_slice").set_body_typed(dynamic_strided_slice);
+
+StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder& ctx) {
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* begin_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+  const auto* end_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
+  const auto* strides_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[3]);
+
+ ICHECK(data_sinfo);
+  if (data_sinfo->IsUnknownNdim()) {
+    LOG(WARNING) << "When the data rank is unknown, dynamic strided slice assumes that the "
+                    "begin/end/strides tensors are well-formed. It may produce a runtime error "
+                    "when this assumption turns out to be wrong.";
+    return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
+  }
+  if (data_sinfo->IsUnknownDtype()) {
+    LOG(WARNING) << "When the data type is unknown, dynamic strided slice assumes the data has "
+                    "a valid dtype. It may produce a runtime error when this assumption turns "
+                    "out to be wrong.";
+  }
+
+  int n_axis = data_sinfo->ndim;
+  auto diag_def = [&](const TensorStructInfoNode* sinfo, String name) {
+    ICHECK(sinfo) << "Dynamic strided slice requires the input " << name
+                  << " to have struct info. Please try normalizing the inputs.";
+    CHECK_EQ(sinfo->ndim, 1) << "Dynamic strided slice requires " << name
+                             << " to be a 1d tensor (list of values).";
+    const auto* shape = sinfo->shape.as<ShapeExprNode>();
+    ICHECK(shape) << "Dynamic strided slice requires the input " << name
+                  << " to have a well-defined shape.";
+    // NOTE(tvm-team): This strong restriction seems necessary for now until we have a generic
+    // solution for converting a 1d Tensor with unknown num_elem to Array<PrimExpr>.
+    const auto* num_elem = shape->values[0].as<IntImmNode>();
+    ICHECK(num_elem) << "Dynamic strided slice requires the input " << name
+                     << " to have a known integer shape value.";
+    CHECK_EQ(num_elem->value, n_axis) << "Dynamic strided slice requires the number of indices "
+                                      << "in " << name << " to equal the number of axes.";
+    if (sinfo->IsUnknownDtype()) {
+      LOG(WARNING) << "Dynamic strided slice assumes " << name
+                   << " to be int64 when its dtype is not specified.";
+    } else {
+      CHECK(sinfo->dtype == DataType::Int(64))
+          << "Dynamic strided slice expects the input " << name << " values to be all int64. "
+          << "However, " << name << " has dtype " << sinfo->dtype << ".";
+    }
+  };
+ diag_def(begin_sinfo, "begin");
+ diag_def(end_sinfo, "end");
+ diag_def(strides_sinfo, "stride");
+
+  // The output shape will depend on the runtime values in the begin/end/strides tensors.
+  // TODO(tvm-team): Extract more compile-time info when those tensors are constants.
+  return TensorStructInfo(data_sinfo->dtype, n_axis);
+}
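For intuition, the rank/length/dtype checks that `diag_def` performs on each of `begin`, `end`, and `strides` can be sketched in plain Python. This is an editor's illustration, not the TVM API; the function name and the use of Python lists in place of 1d int64 tensors are assumptions for the sketch.

```python
def check_slice_indices(data_ndim, begin, end, strides):
    """Sketch of the diag_def checks: each index list stands in for a 1d
    int64 tensor, must have exactly one entry per data axis, and all
    entries must be integers. Returns the output rank, which equals the
    data rank (the extents are only known at runtime)."""
    for name, vals in (("begin", begin), ("end", end), ("strides", strides)):
        if len(vals) != data_ndim:
            raise ValueError(
                f"number of indices in {name} must equal the number of axes")
        if not all(isinstance(v, int) for v in vals):
            raise TypeError(f"{name} values must all be int64")
    return data_ndim
```

As in the C++ code, only the rank of the result is inferred here; the per-axis extents depend on runtime values.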
Review Comment:
Assuming the "symbolic variable" you mentioned is just another `PrimExpr`, I
expect we can fill a tensor with that value, just like any other `PrimExpr`.
Why not try the `Can begin(ind) be symbolic in practice?` experiment?

> Can't we access the value by using the index in the above example?

What I meant was: as soon as we put a symbolic expression into a TE tensor, we
lose all symbolic-specific information. You can index it, but what you get back
is an opaque value, so we cannot exploit any symbolic information about it.
An output shape represented by a TE Tensor therefore loses symbolic or constant
shape information. That's why I think it is better to directly generate the
shape func via TIR.
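For concreteness, the per-axis computation such a shape func would perform can be sketched in plain Python (an editor's illustration of standard strided-slice semantics for positive strides, not TVM code):

```python
def slice_output_extent(begin, end, stride):
    """Number of elements selected by a strided slice along one axis,
    assuming 0 <= begin <= end and stride > 0: ceil((end - begin) / stride).
    Matches len(range(begin, end, stride))."""
    assert stride > 0, "this sketch handles positive strides only"
    # -(-x // y) is ceiling division for non-negative x and positive y
    return max(0, -(-(end - begin) // stride))
```

A TIR shape func would evaluate this expression per axis at runtime from the begin/end/strides tensors, whereas baking it into a TE tensor hides whether any of these quantities are symbolic or constant.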
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]