This is an automated email from the ASF dual-hosted git repository. kparzysz pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 6a6093bc18 fold const or empty iter partition (#12080)
6a6093bc18 is described below

commit 6a6093bc180ed762b3e0d19eb37fcf10d97289c1
Author: wrongtest <wrongte...@gmail.com>
AuthorDate: Wed Jul 13 22:52:35 2022 +0800

    fold const or empty iter partition (#12080)
---
 src/tir/transforms/loop_partition.cc | 30 ++++++++++++++++--------------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc
index 59ac339006..677506889e 100644
--- a/src/tir/transforms/loop_partition.cc
+++ b/src/tir/transforms/loop_partition.cc
@@ -587,16 +587,17 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
   if (middle_interval_i->HasLowerBound()) {
     body_begin = analyzer_.Simplify(middle_interval.min());
     if (!analyzer_.CanProve(body_begin == min)) {
-      PrimExpr cond = (body_begin - min >= 0);
-      if (!analyzer_.CanProve(cond)) {
-        LOG(WARNING) << "Cannot prove: " << cond << ", when generating the pre doubt loop";
-        body_begin = Max(body_begin, min);
+      PrimExpr extent = analyzer_.Simplify(body_begin - min);
+      if (!analyzer_.CanProve(extent > 0)) {
+        body_begin = tvm::max(body_begin, min);
         // stop recursing on this interval if we can't prove it has non-negative length
         pre_stmt_recurse = false;
       }
-      if (!partition_thread_scope) {
-        Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
-        pre_stmt = MakeFor(stmt.get(), body_begin - min, pre_body);
+      if (!analyzer_.CanProve(extent <= 0)) {
+        if (!partition_thread_scope) {
+          Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
+          pre_stmt = MakeFor(stmt.get(), body_begin - min, pre_body);
+        }
       }
     }
   } else {
@@ -612,16 +613,17 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
     post_doubt_begin = analyzer_.Simplify(middle_interval.max() + 1);
     if (!analyzer_.CanProve(middle_interval.max() == max)) {
       // require the extent to be non-negative
-      PrimExpr cond = (max - post_doubt_begin + 1 >= 0);
-      if (!analyzer_.CanProve(cond)) {
-        LOG(WARNING) << "Cannot prove: " << cond << ", when generating the post doubt loop";
-        post_doubt_begin = Min(post_doubt_begin, max + 1);
+      PrimExpr extent = analyzer_.Simplify(max - post_doubt_begin + 1);
+      if (!analyzer_.CanProve(extent > 0)) {
+        post_doubt_begin = tvm::min(post_doubt_begin, max + 1);
         // stop recursing on this interval if we can't prove it has non-negative length
         post_stmt_recurse = false;
       }
-      if (!partition_thread_scope) {
-        Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
-        post_stmt = MakeFor(stmt.get(), max - post_doubt_begin + 1, post_body);
+      if (!analyzer_.CanProve(extent <= 0)) {
+        if (!partition_thread_scope) {
+          Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
+          post_stmt = MakeFor(stmt.get(), extent, post_body);
+        }
       }
     }
   } else {