This is an automated email from the ASF dual-hosted git repository.

kparzysz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6a6093bc18 fold const or empty iter partition (#12080)
6a6093bc18 is described below

commit 6a6093bc180ed762b3e0d19eb37fcf10d97289c1
Author: wrongtest <wrongte...@gmail.com>
AuthorDate: Wed Jul 13 22:52:35 2022 +0800

    fold const or empty iter partition (#12080)
---
 src/tir/transforms/loop_partition.cc | 30 ++++++++++++++++--------------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc
index 59ac339006..677506889e 100644
--- a/src/tir/transforms/loop_partition.cc
+++ b/src/tir/transforms/loop_partition.cc
@@ -587,16 +587,17 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
   if (middle_interval_i->HasLowerBound()) {
     body_begin = analyzer_.Simplify(middle_interval.min());
     if (!analyzer_.CanProve(body_begin == min)) {
-      PrimExpr cond = (body_begin - min >= 0);
-      if (!analyzer_.CanProve(cond)) {
-        LOG(WARNING) << "Cannot prove: " << cond << ", when generating the pre 
doubt loop";
-        body_begin = Max(body_begin, min);
+      PrimExpr extent = analyzer_.Simplify(body_begin - min);
+      if (!analyzer_.CanProve(extent > 0)) {
+        body_begin = tvm::max(body_begin, min);
         // stop recursing on this interval if we can't prove it has 
non-negative length
         pre_stmt_recurse = false;
       }
-      if (!partition_thread_scope) {
-        Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
-        pre_stmt = MakeFor(stmt.get(), body_begin - min, pre_body);
+      if (!analyzer_.CanProve(extent <= 0)) {
+        if (!partition_thread_scope) {
+          Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
+          pre_stmt = MakeFor(stmt.get(), body_begin - min, pre_body);
+        }
       }
     }
   } else {
@@ -612,16 +613,17 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
     post_doubt_begin = analyzer_.Simplify(middle_interval.max() + 1);
     if (!analyzer_.CanProve(middle_interval.max() == max)) {
       // require the extent to be non-negative
-      PrimExpr cond = (max - post_doubt_begin + 1 >= 0);
-      if (!analyzer_.CanProve(cond)) {
-        LOG(WARNING) << "Cannot prove: " << cond << ", when generating the 
post doubt loop";
-        post_doubt_begin = Min(post_doubt_begin, max + 1);
+      PrimExpr extent = analyzer_.Simplify(max - post_doubt_begin + 1);
+      if (!analyzer_.CanProve(extent > 0)) {
+        post_doubt_begin = tvm::min(post_doubt_begin, max + 1);
         // stop recursing on this interval if we can't prove it has 
non-negative length
         post_stmt_recurse = false;
       }
-      if (!partition_thread_scope) {
-        Stmt post_body = Substitute(body, {{Var{var}, var + 
post_doubt_begin}});
-        post_stmt = MakeFor(stmt.get(), max - post_doubt_begin + 1, post_body);
+      if (!analyzer_.CanProve(extent <= 0)) {
+        if (!partition_thread_scope) {
+          Stmt post_body = Substitute(body, {{Var{var}, var + 
post_doubt_begin}});
+          post_stmt = MakeFor(stmt.get(), extent, post_body);
+        }
       }
     }
   } else {

Reply via email to