masahi commented on code in PR #14548:
URL: https://github.com/apache/tvm/pull/14548#discussion_r1162025341


##########
include/tvm/topi/transform.h:
##########
@@ -2035,6 +2034,73 @@ inline Tensor adv_index(const Tensor& data, const 
Array<Tensor>& indices,
       name, tag);
 }
 
+namespace relax {
+// relax dynamic slice
+inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& 
begin,
+                                        const te::Tensor& end, const 
te::Tensor& strides,
+                                        Array<PrimExpr> output_shape,
+                                        std::string name = 
"T_strided_slice_dynamic",
+                                        std::string tag = kInjective) {
+  const size_t num_dynamic_axes = x.ndim();
+  ICHECK_EQ(begin.ndim(), 1);
+  ICHECK_EQ(end.ndim(), 1);
+  ICHECK_EQ(strides.ndim(), 1);
+  const auto* len_begin = begin->shape[0].as<IntImmNode>();
+  const auto* len_end = end->shape[0].as<IntImmNode>();
+  const auto* len_strides = strides->shape[0].as<IntImmNode>();
+  ICHECK(len_begin);
+  ICHECK(len_end);
+  ICHECK(len_strides);
+  ICHECK_EQ(len_begin->value, num_dynamic_axes);
+  ICHECK_EQ(len_end->value, num_dynamic_axes);
+  ICHECK_EQ(len_strides->value, num_dynamic_axes);
+
+  return te::compute(
+      output_shape,
+      [&](const Array<tvm::tir::Var>& indices) {
+        Array<PrimExpr> real_indices;
+        for (size_t i = 0; i < num_dynamic_axes; ++i) {
+          auto ind = make_const(DataType::Int(64), i);
+          real_indices.push_back(indices[i] * strides(ind) + 
tvm::min(begin(ind), x->shape[i] - 1));
+        }
+        return x(real_indices);
+      },
+      name, tag);
+}
+
+inline te::Tensor shape_func_dynamic_strided_slice(
+    const te::Tensor& data, const te::Tensor& begin, const te::Tensor& end,
+    const te::Tensor& strides, std::string name = 
"T_shape_func_strided_slice_dynamic") {
+  return te::compute(
+      {begin->shape[0]},
+      [&](const Array<tvm::tir::Var>& indices) {
+        ICHECK(indices.size() == 1);
+        auto CanonicalizeIndex = [&](PrimExpr index, PrimExpr extent, PrimExpr 
stride) {
+          PrimExpr begin_range = if_then_else(stride < 0, -1, 0);
+          PrimExpr end_range = if_then_else(stride < 0, extent - 1, extent);
+          index = if_then_else(index < 0, index + extent, index);
+          return min(max(index, begin_range), end_range);
+        };
+
+        auto GetLength = [&](PrimExpr begin, PrimExpr end, PrimExpr stride, 
PrimExpr length) {
+          begin = CanonicalizeIndex(begin, length, stride);
+          end = CanonicalizeIndex(end, length, stride);
+          PrimExpr len1 = ceildiv(begin - end, -stride);
+          PrimExpr len2 = ceildiv(end - begin, stride);
+          return if_then_else(stride < 0, len1, len2);
+        };
+        PrimExpr length(-1);
+        int ndim = data.ndim();
+        for (int i = 0; i < ndim; i++) {
+          length = if_then_else(indices[0] == i, data->shape[i], length);
+        }
+        return GetLength(begin(indices), end(indices), strides(indices), 
length);

Review Comment:
   I think the question is whether or not we consider shape funcs part of the "op 
definition". That is the case in Relay, and it is one of the reasons adding a new op 
in Relay is complicated. I don't know what the goal of #14278 is, but the 
current discussion in this PR suggests that shape funcs need not be implemented under 
topi. They seem more like an "implementation detail" of the legalizer.
   
    So unless #14278 is aiming to generate an op legalizer definition, I think 
this work is completely decoupled from #14278.   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to