sunggg commented on code in PR #14548:
URL: https://github.com/apache/tvm/pull/14548#discussion_r1161859857


##########
src/relax/op/tensor/index.cc:
##########
@@ -239,5 +239,78 @@ TVM_REGISTER_OP("relax.strided_slice")
     .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStridedSlice)
     .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", 
MixedPrecisionPolicyKind::kFollow);
 
+/*!
+ * \brief Construct a call to the relax.dynamic_strided_slice operator.
+ * \param x The tensor to be sliced.
+ * \param begin Tensor with the start index for each axis of x.
+ * \param end Tensor with the end index for each axis of x.
+ * \param strides Tensor with the step for each axis of x.
+ * \return The operator Call. No attributes are attached: every slicing
+ *         parameter is supplied as a tensor argument instead.
+ */
+Expr dynamic_strided_slice(Expr x, Expr begin, Expr end, Expr strides) {
+  static const Op& dyn_slice_op = Op::Get("relax.dynamic_strided_slice");
+  return Call(dyn_slice_op,
+              {std::move(x), std::move(begin), std::move(end), std::move(strides)},
+              {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.dynamic_strided_slice")
+    .set_body_typed(dynamic_strided_slice);
+
+StructInfo InferStructInfoDynStridedSlice(const Call& call, const 
BlockBuilder& ctx) {
+  const auto* data_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* begin_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+  const auto* end_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
+  const auto* strides_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[3]);
+
+  ICHECK(data_sinfo);
+  if (data_sinfo->IsUnknownNdim()) {
+    LOG(WARNING) << "When data rank is unknown, dynamic strided slice assumes 
begin/end/stride "
+                    "tensors are well-formed. It could produce runtime error 
when this assumption "
+                    "turns out to be wrong.";
+    return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
+  }
+  if (data_sinfo->IsUnknownDtype()) {
+    LOG(WARNING) << "When data type is unknown, dynamic strided slice assumes 
to have a valid "
+                    "dtype. It could produce runtime error when this 
assumption "
+                    "turns out to be wrong.";
+  }
+
+  int n_axis = data_sinfo->ndim;
+  auto diag_def = [&](const TensorStructInfoNode* sinfo, String name) {
+    ICHECK(sinfo) << "Dynamic strided slice requires the input " << name
+                  << " to be have the struct info. Please try normalizing the 
inputs.";
+    CHECK_EQ(sinfo->ndim, 1) << "Dynamic strided slice requires " << name
+                             << " to be 1d tensor (list of values).";
+    const auto* shape = sinfo->shape.as<ShapeExprNode>();
+    ICHECK(shape) << "Dynamic strided slice requires the input " << name
+                  << " to have well-defined shape.";
+    // NOTE(tvm-team): This strong restriction seems necessary for now until 
we have a generic
+    // solution in converting 1d Tensor with unknown num_elem to 
Array<PrimExpr>.
+    const auto* num_elem = shape->values[0].as<IntImmNode>();
+    ICHECK(num_elem) << "Dynamic strided slice requires the input " << name
+                     << " to have a known integer shape value.";
+    CHECK_EQ(num_elem->value, n_axis) << "Dynamic strided slice requires the 
number of indices in "
+                                      << name << " to equal the number of 
axes.";
+    if (sinfo->IsUnknownDtype()) {
+      LOG(WARNING) << "Dynamic strided slice assumes " << name
+                   << " to be int64 when it is not specified.";
+    } else {
+      CHECK(sinfo->dtype == DataType::Int(64))
+          << "Dynamic strided_slice expects the input " << name
+          << "values to be all int64. However, " << name << " has dtype " << 
sinfo->dtype << ".";
+    }
+  };
+  diag_def(begin_sinfo, "begin");
+  diag_def(end_sinfo, "end");
+  diag_def(strides_sinfo, "stride");
+
+  // The output shape will depend on the runtime value in begin/end/stride 
tensors.
+  // TODO(tvm-team): Extract more compile-time info when those tensors are 
constants.
+  return TensorStructInfo(data_sinfo->dtype, n_axis);

Review Comment:
   Yes, that is a current limitation of this PR, so I left it as a TODO.
   
   Actually, I have a question regarding https://github.com/apache/tvm/pull/8165. Can a `te::Tensor` contain symbolic variables? I may have missed it, but I couldn't find any relevant test cases.
   Please correct me if I'm wrong, but based on my current understanding, the data within a `te::Tensor` will be concrete values rather than symbolic variables. So I cannot see how a partially static output shape would be represented if we go through the F1 -> F2 route below.
   ```C++
   // F1
   // Sketch of the te::Tensor-based overload (parts elided with `...`): it reads
   // each element of the begin/end/strides tensors into Array<PrimExpr> and then
   // forwards to the Array<PrimExpr>-based overload F2 below.
   inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin,
                                           const te::Tensor& end, const te::Tensor& strides, ...)
   {
        // ...
        Array<PrimExpr> begin_expr, end_expr, strides_expr;
        for (int64_t i = 0; i < num_dynamic_axes; ++i) {
            auto ind = make_const(index_dtype, i);
            // Each push_back stores a tensor-element access expression; whether
            // that can carry a symbolic value is the open question in this review.
            begin_expr.push_back(begin(ind));    // <- Can `begin(ind)` be symbolic in practice?
            end_expr.push_back(end(ind));
            strides_expr.push_back(strides(ind));
         } 
         // Call F2
         return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag);
   }
   
   // F2
   inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>& 
begin,
                                       const Array<PrimExpr>& end, const 
Array<PrimExpr>& strides, ...)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to