Author: Alex Zinenko Date: 2020-11-23T15:04:31+01:00 New Revision: 31a233d46367636f94c487b51aa2931a1cc9cf79
URL: https://github.com/llvm/llvm-project/commit/31a233d46367636f94c487b51aa2931a1cc9cf79 DIFF: https://github.com/llvm/llvm-project/commit/31a233d46367636f94c487b51aa2931a1cc9cf79.diff LOG: [mlir] canonicalize away zero-iteration SCF for loops An SCF 'for' loop does not iterate if its lower bound is equal to its upper bound. Remove loops where both bounds are the same SSA value as such bounds are guaranteed to be equal. Similarly, remove 'parallel' loops where at least one pair of respective lower/upper bounds is specified by the same SSA value. Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D91880 Added: Modified: mlir/lib/Dialect/SCF/SCF.cpp mlir/test/Dialect/SCF/canonicalize.mlir Removed: ################################################################################ diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp index 5da9f7c29cab..48b1b473f86d 100644 --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -521,6 +521,13 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> { LogicalResult matchAndRewrite(ForOp op, PatternRewriter &rewriter) const override { + // If the upper bound is the same as the lower bound, the loop does not + // iterate, just remove it. + if (op.lowerBound() == op.upperBound()) { + rewriter.replaceOp(op, op.getIterOperands()); + return success(); + } + auto lb = op.lowerBound().getDefiningOp<ConstantOp>(); auto ub = op.upperBound().getDefiningOp<ConstantOp>(); if (!lb || !ub) @@ -1066,11 +1073,30 @@ struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> { return success(); } }; + +/// Removes parallel loops in which at least one lower/upper bound pair consists +/// of the same values - such loops have an empty iteration domain. +struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> { + using OpRewritePattern<ParallelOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(ParallelOp op, + PatternRewriter &rewriter) const override { + for (auto dim : llvm::zip(op.lowerBound(), op.upperBound())) { + if (std::get<0>(dim) == std::get<1>(dim)) { + rewriter.replaceOp(op, op.initVals()); + return success(); + } + } + return failure(); + } +}; + } // namespace void ParallelOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert<CollapseSingleIterationLoops>(context); + results.insert<CollapseSingleIterationLoops, RemoveEmptyParallelLoops>( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index faac86b94cdb..d57563461241 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -32,30 +32,6 @@ func @single_iteration(%A: memref<?x?x?xi32>) { // ----- -func @no_iteration(%A: memref<?x?xi32>) { - %c0 = constant 0 : index - %c1 = constant 1 : index - scf.parallel (%i0, %i1) = (%c0, %c0) to (%c1, %c0) step (%c1, %c1) { - %c42 = constant 42 : i32 - store %c42, %A[%i0, %i1] : memref<?x?xi32> - scf.yield - } - return -} - -// CHECK-LABEL: func @no_iteration( -// CHECK-SAME: [[ARG0:%.*]]: memref<?x?xi32>) { -// CHECK: [[C0:%.*]] = constant 0 : index -// CHECK: [[C1:%.*]] = constant 1 : index -// CHECK: [[C42:%.*]] = constant 42 : i32 -// CHECK: scf.parallel ([[V1:%.*]]) = ([[C0]]) to ([[C0]]) step ([[C1]]) { -// CHECK: store [[C42]], [[ARG0]]{{\[}}[[C0]], [[V1]]] : memref<?x?xi32> -// CHECK: scf.yield -// CHECK: } -// CHECK: return - -// ----- - func @one_unused(%cond: i1) -> (index) { %c0 = constant 0 : index %c1 = constant 1 : index @@ -241,6 +217,22 @@ func @remove_zero_iteration_loop() { return } +// CHECK-LABEL: @remove_zero_iteration_loop_vals +func @remove_zero_iteration_loop_vals(%arg0: index) { + %c2 = constant 2 : index + // CHECK: %[[INIT:.*]] = "test.init" + %init = "test.init"() : () -> i32 + // CHECK-NOT: scf.for + // CHECK-NOT: test.op + %0 = scf.for %i = %arg0 to %arg0 step %c2 iter_args(%arg = %init) -> (i32) { + %1 = "test.op"(%i, %arg) : (index, i32) -> i32 + scf.yield %1 : i32 + } + // CHECK: "test.consume"(%[[INIT]]) + "test.consume"(%0) : (i32) -> () + return +} + // CHECK-LABEL: @replace_single_iteration_loop func @replace_single_iteration_loop() { // CHECK: %[[LB:.*]] = constant 42 @@ -278,3 +270,24 @@ func @replace_single_iteration_loop_non_unit_step() { "test.consume"(%0) : (i32) -> () return } + +// CHECK-LABEL: @remove_empty_parallel_loop +func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) { + // CHECK: %[[INIT:.*]] = "test.init" + %init = "test.init"() : () -> f32 + // CHECK-NOT: scf.parallel + // CHECK-NOT: test.produce + // CHECK-NOT: test.transform + %0 = scf.parallel (%i, %j, %k) = (%lb, %ub, %lb) to (%ub, %ub, %ub) step (%s, %s, %s) init(%init) -> f32 { + %1 = "test.produce"() : () -> f32 + scf.reduce(%1) : f32 { + ^bb0(%lhs: f32, %rhs: f32): + %2 = "test.transform"(%lhs, %rhs) : (f32, f32) -> f32 + scf.reduce.return %2 : f32 + } + scf.yield + } + // CHECK: "test.consume"(%[[INIT]]) + "test.consume"(%0) : (f32) -> () + return +} _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits