Author: Frederik Gossen Date: 2020-11-27T10:08:56+01:00 New Revision: 6484567f14881003a7c46d1587dbb0cf8082282a
URL: https://github.com/llvm/llvm-project/commit/6484567f14881003a7c46d1587dbb0cf8082282a DIFF: https://github.com/llvm/llvm-project/commit/6484567f14881003a7c46d1587dbb0cf8082282a.diff LOG: [MLIR][SCF] Find all innermost loops for parallel loop tiling Overcome the assumption that parallel loops are only nested in other parallel loops. Differential Revision: https://reviews.llvm.org/D92188 Added: Modified: mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp mlir/test/Dialect/SCF/parallel-loop-tiling.mlir Removed: ################################################################################ diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp index 7bcc989a5b28..7bd589214f4c 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp @@ -22,15 +22,15 @@ using namespace mlir::scf; /// Tile a parallel loop of the form /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) -/// step (%arg4, %arg5) +/// step (%arg4, %arg5) /// /// into /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) -/// step (%arg4*tileSize[0], -/// %arg5*tileSize[1]) +/// step (%arg4*tileSize[0], +/// %arg5*tileSize[1]) /// scf.parallel (%j0, %j1) = (0, 0) to (min(%arg4*tileSize[0], %arg2-%i0) -/// min(%arg5*tileSize[1], %arg3-%i1)) -/// step (%arg4, %arg5) +/// min(%arg5*tileSize[1], %arg3-%i1)) +/// step (%arg4, %arg5) /// /// where the uses of %i0 and %i1 in the loop body are replaced by /// %i0 + j0 and %i1 + %j1. @@ -126,17 +126,27 @@ void mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) { op.erase(); } -/// Get a list of most nested parallel loops. Assumes that ParallelOps are -/// only directly nested. 
-static bool getInnermostNestedLoops(Block *block, - SmallVectorImpl<ParallelOp> &loops) { - bool hasInnerLoop = false; - for (auto parallelOp : block->getOps<ParallelOp>()) { - hasInnerLoop = true; - if (!getInnermostNestedLoops(parallelOp.getBody(), loops)) - loops.push_back(parallelOp); +/// Get a list of most nested parallel loops. +static bool getInnermostPloops(Operation *rootOp, + SmallVectorImpl<ParallelOp> &result) { + assert(rootOp != nullptr && "Root operation must not be a nullptr."); + bool rootEnclosesPloops = false; + for (Region &region : rootOp->getRegions()) { + for (Block &block : region.getBlocks()) { + for (Operation &op : block) { + bool enclosesPloops = getInnermostPloops(&op, result); + rootEnclosesPloops |= enclosesPloops; + if (auto ploop = dyn_cast<ParallelOp>(op)) { + rootEnclosesPloops = true; + + // Collect ploop if it is an innermost one. + if (!enclosesPloops) + result.push_back(ploop); + } + } + } } - return hasInnerLoop; + return rootEnclosesPloops; } namespace { @@ -148,14 +158,12 @@ struct ParallelLoopTiling } void runOnFunction() override { - SmallVector<ParallelOp, 2> mostNestedParallelOps; - for (Block &block : getFunction()) { - getInnermostNestedLoops(&block, mostNestedParallelOps); - } - for (ParallelOp pLoop : mostNestedParallelOps) { + SmallVector<ParallelOp, 2> innermostPloops; + getInnermostPloops(getFunction().getOperation(), innermostPloops); + for (ParallelOp ploop : innermostPloops) { // FIXME: Add reduction support. 
- if (pLoop.getNumReductions() == 0) - tileParallelLoop(pLoop, tileSizes); + if (ploop.getNumReductions() == 0) + tileParallelLoop(ploop, tileSizes); } } }; diff --git a/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir b/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir index e0dc8344f14d..5d3a676f58ab 100644 --- a/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir +++ b/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir @@ -112,3 +112,29 @@ func @tile_nested_innermost() { // CHECK: } // CHECK: return // CHECK: } + +// ----- + +func @tile_nested_in_non_ploop() { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + scf.for %i = %c0 to %c2 step %c1 { + scf.for %j = %c0 to %c2 step %c1 { + scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + } + } + } + return +} + +// CHECK-LABEL: func @tile_nested_in_non_ploop +// CHECK: scf.for +// CHECK: scf.for +// CHECK: scf.parallel +// CHECK: scf.parallel +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits