kevinthesun commented on a change in pull request #7018:
URL: https://github.com/apache/tvm/pull/7018#discussion_r534584859



##########
File path: include/tvm/topi/transform.h
##########
@@ -598,17 +598,69 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& 
x, const te::Tensor& b
  *
  * \return A Tensor whose op member is the split operation
  */
-inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, 
const Array<Integer>& end,
-                            const Array<Integer>& strides, std::string 
slice_mode = "end",
-                            std::string name = "T_strided_slice", std::string 
tag = kInjective) {
+inline Tensor strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
+                            const Array<PrimExpr>& end, const Array<PrimExpr>& 
strides,
+                            std::string slice_mode = "end", std::string name = 
"T_strided_slice",
+                            std::string tag = kInjective) {
   size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
+  // Quick path for dynamic shape strided slice.
+  // This is for ease of use to dynamice strided slice in topi.
+  bool is_dyn = false;
+  for (size_t i = 0; i < src_tensor_dim; ++i) {
+    if (!IsConstInt(x->shape[i])) {
+      is_dyn = true;
+      break;
+    }
+  }
+  if (!is_dyn) {
+    for (size_t i = 0; i < begin.size(); ++i) {
+      if (begin[i].defined() && !IsConstInt(begin[i])) {
+        is_dyn = true;
+        break;
+      }
+    }
+  }
+  if (!is_dyn) {
+    for (size_t i = 0; i < end.size(); ++i) {
+      if (end[i].defined() && !IsConstInt(end[i])) {
+        is_dyn = true;
+        break;
+      }
+    }
+  }
+  if (!is_dyn) {
+    for (size_t i = 0; i < strides.size(); ++i) {
+      if (strides[i].defined() && !IsConstInt(strides[i])) {
+        is_dyn = true;
+        break;
+      }
+    }
+  }
+
+  Array<PrimExpr> out_shape;
+  if (is_dyn) {
+    for (size_t i = 0; i < src_tensor_dim; ++i) {
+      out_shape.push_back(indexdiv(end[i] - begin[i], strides[i]));
+    }
+    return te::compute(
+        out_shape,
+        [&](const Array<tvm::tir::Var>& indices) {
+          Array<PrimExpr> real_indices;
+          for (size_t i = 0; i < src_tensor_dim; ++i) {
+            real_indices.push_back(indices[i] * strides[i] + begin[i]);
+          }
+          return x(real_indices);
+        },
+        name, tag);
+  }
+

Review comment:
       That `dynamic_strided_slice` takes `begin`, `end`, and `strides` as 
tensors. As a result, we can't use them directly in PrimExpr 
computations.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to