Author: Alex Zinenko
Date: 2020-11-23T15:04:31+01:00
New Revision: 31a233d46367636f94c487b51aa2931a1cc9cf79

URL: 
https://github.com/llvm/llvm-project/commit/31a233d46367636f94c487b51aa2931a1cc9cf79
DIFF: 
https://github.com/llvm/llvm-project/commit/31a233d46367636f94c487b51aa2931a1cc9cf79.diff

LOG: [mlir] canonicalize away zero-iteration SCF for loops

An SCF 'for' loop does not iterate if its lower bound is equal to its upper
bound. Remove loops where both bounds are the same SSA value, as such bounds are
guaranteed to be equal. Similarly, remove 'parallel' loops where at least one
pair of respective lower/upper bounds is specified by the same SSA value.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D91880

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 5da9f7c29cab..48b1b473f86d 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -521,6 +521,13 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
 
   LogicalResult matchAndRewrite(ForOp op,
                                 PatternRewriter &rewriter) const override {
+    // If the upper bound is the same as the lower bound, the loop does not
+    // iterate, just remove it.
+    if (op.lowerBound() == op.upperBound()) {
+      rewriter.replaceOp(op, op.getIterOperands());
+      return success();
+    }
+
     auto lb = op.lowerBound().getDefiningOp<ConstantOp>();
     auto ub = op.upperBound().getDefiningOp<ConstantOp>();
     if (!lb || !ub)
@@ -1066,11 +1073,30 @@ struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
     return success();
   }
 };
+
+/// Removes parallel loops in which at least one lower/upper bound pair consists
+/// of the same values - such loops have an empty iteration domain.
+struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
+  using OpRewritePattern<ParallelOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ParallelOp op,
+                                PatternRewriter &rewriter) const override {
+    for (auto dim : llvm::zip(op.lowerBound(), op.upperBound())) {
+      if (std::get<0>(dim) == std::get<1>(dim)) {
+        rewriter.replaceOp(op, op.initVals());
+        return success();
+      }
+    }
+    return failure();
+  }
+};
+
 } // namespace
 
 void ParallelOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
-  results.insert<CollapseSingleIterationLoops>(context);
+  results.insert<CollapseSingleIterationLoops, RemoveEmptyParallelLoops>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index faac86b94cdb..d57563461241 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -32,30 +32,6 @@ func @single_iteration(%A: memref<?x?x?xi32>) {
 
 // -----
 
-func @no_iteration(%A: memref<?x?xi32>) {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  scf.parallel (%i0, %i1) = (%c0, %c0) to (%c1, %c0) step (%c1, %c1) {
-    %c42 = constant 42 : i32
-    store %c42, %A[%i0, %i1] : memref<?x?xi32>
-    scf.yield
-  }
-  return
-}
-
-// CHECK-LABEL:   func @no_iteration(
-// CHECK-SAME:                        [[ARG0:%.*]]: memref<?x?xi32>) {
-// CHECK:           [[C0:%.*]] = constant 0 : index
-// CHECK:           [[C1:%.*]] = constant 1 : index
-// CHECK:           [[C42:%.*]] = constant 42 : i32
-// CHECK:           scf.parallel ([[V1:%.*]]) = ([[C0]]) to ([[C0]]) step ([[C1]]) {
-// CHECK:             store [[C42]], [[ARG0]]{{\[}}[[C0]], [[V1]]] : memref<?x?xi32>
-// CHECK:             scf.yield
-// CHECK:           }
-// CHECK:           return
-
-// -----
-
 func @one_unused(%cond: i1) -> (index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -241,6 +217,22 @@ func @remove_zero_iteration_loop() {
   return
 }
 
+// CHECK-LABEL: @remove_zero_iteration_loop_vals
+func @remove_zero_iteration_loop_vals(%arg0: index) {
+  %c2 = constant 2 : index
+  // CHECK: %[[INIT:.*]] = "test.init"
+  %init = "test.init"() : () -> i32
+  // CHECK-NOT: scf.for
+  // CHECK-NOT: test.op
+  %0 = scf.for %i = %arg0 to %arg0 step %c2 iter_args(%arg = %init) -> (i32) {
+    %1 = "test.op"(%i, %arg) : (index, i32) -> i32
+    scf.yield %1 : i32
+  }
+  // CHECK: "test.consume"(%[[INIT]])
+  "test.consume"(%0) : (i32) -> ()
+  return
+}
+
 // CHECK-LABEL: @replace_single_iteration_loop
 func @replace_single_iteration_loop() {
   // CHECK: %[[LB:.*]] = constant 42
@@ -278,3 +270,24 @@ func @replace_single_iteration_loop_non_unit_step() {
   "test.consume"(%0) : (i32) -> ()
   return
 }
+
+// CHECK-LABEL: @remove_empty_parallel_loop
+func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) {
+  // CHECK: %[[INIT:.*]] = "test.init"
+  %init = "test.init"() : () -> f32
+  // CHECK-NOT: scf.parallel
+  // CHECK-NOT: test.produce
+  // CHECK-NOT: test.transform
+  %0 = scf.parallel (%i, %j, %k) = (%lb, %ub, %lb) to (%ub, %ub, %ub) step (%s, %s, %s) init(%init) -> f32 {
+    %1 = "test.produce"() : () -> f32
+    scf.reduce(%1) : f32 {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %2 = "test.transform"(%lhs, %rhs) : (f32, f32) -> f32
+      scf.reduce.return %2 : f32
+    }
+    scf.yield
+  }
+  // CHECK: "test.consume"(%[[INIT]])
+  "test.consume"(%0) : (f32) -> ()
+  return
+}


        
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to