Author: Lei Zhang Date: 2021-01-11T09:13:06-05:00 New Revision: 55225471d9838e452cfb31e0edae6162b7226221
URL: https://github.com/llvm/llvm-project/commit/55225471d9838e452cfb31e0edae6162b7226221 DIFF: https://github.com/llvm/llvm-project/commit/55225471d9838e452cfb31e0edae6162b7226221.diff LOG: [mlir][linalg] Support permutation when lowering to loop nests Linalg ops are perfect loop nests. When materializing the concrete loop nest, the default order specified by the Linalg op's iterators may not be the best for further CodeGen: targets frequently need to plan the loop order in order to gain better data access. And different targets can have different preferences. So there should exist a way to control the order. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D91795 Added: mlir/test/Dialect/Linalg/loop-order.mlir Modified: mlir/include/mlir/Dialect/Linalg/Passes.td mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h mlir/lib/Dialect/Linalg/Transforms/Loops.cpp Removed: ################################################################################ diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td index 14f845589a6f..a20289af3054 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -28,8 +28,8 @@ def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> { let options = [ Option<"foldOneTripLoopsOnly", "fold-one-trip-loops-only", "bool", /*default=*/"false", - "Only folds the one-trip loops from Linalg ops on tensors " - "(for testing purposes only)"> + "Only folds the one-trip loops from Linalg ops on tensors " + "(for testing purposes only)"> ]; let dependentDialects = ["linalg::LinalgDialect"]; } @@ -52,12 +52,24 @@ def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> { let summary = "Lower the operations from the linalg dialect into affine " "loops"; let constructor = "mlir::createConvertLinalgToAffineLoopsPass()"; + let options = [ + ListOption<"interchangeVector", "interchange-vector", "unsigned", + "Permute the loops in the nest following the given " + "interchange vector", + "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> + ]; let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"]; } def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> { let summary = "Lower the operations from the linalg dialect into loops"; let constructor = "mlir::createConvertLinalgToLoopsPass()"; + let options = [ + ListOption<"interchangeVector", "interchange-vector", "unsigned", + "Permute the loops in the nest following the given " + "interchange vector", + "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> + ]; let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect", "AffineDialect"]; } @@ -72,6 +84,12 @@ def LinalgLowerToParallelLoops let summary = "Lower the operations from the linalg dialect into parallel " "loops"; let constructor = "mlir::createConvertLinalgToParallelLoopsPass()"; + let options = [ + ListOption<"interchangeVector", "interchange-vector", "unsigned", + "Permute the loops in the nest following the given " + "interchange vector", + "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> + ]; let dependentDialects = ["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"]; } diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index dc82569aac38..d816414ef8b4 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -267,16 +267,28 @@ void vectorizeLinalgOp(OpBuilder &builder, Operation *op); /// Emits a loop nest of `LoopTy` with the proper body for `op`. template <typename LoopTy> -Optional<LinalgLoops> linalgLowerOpToLoops(OpBuilder &builder, Operation *op); - -/// Emits a loop nest of `scf.for` with the proper body for `op`. -LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op); - -/// Emits a loop nest of `scf.parallel` with the proper body for `op`. -LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op); +Optional<LinalgLoops> +linalgLowerOpToLoops(OpBuilder &builder, Operation *op, + ArrayRef<unsigned> interchangeVector = {}); + +/// Emits a loop nest of `scf.for` with the proper body for `op`. The generated +/// loop nest will follow the `interchangeVector`-permutated iterator order. If +/// `interchangeVector` is empty, then no permutation happens. +LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op, + ArrayRef<unsigned> interchangeVector = {}); + +/// Emits a loop nest of `scf.parallel` with the proper body for `op`. The +/// generated loop nest will follow the `interchangeVector`-permutated +// iterator order. If `interchangeVector` is empty, then no permutation happens. +LogicalResult +linalgOpToParallelLoops(OpBuilder &builder, Operation *op, + ArrayRef<unsigned> interchangeVector = {}); -/// Emits a loop nest of `affine.for` with the proper body for `op`. -LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op); +/// Emits a loop nest of `affine.for` with the proper body for `op`. The +/// generated loop nest will follow the `interchangeVector`-permutated +// iterator order. If `interchangeVector` is empty, then no permutation happens. +LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op, + ArrayRef<unsigned> interchangeVector = {}); //===----------------------------------------------------------------------===// // Preconditions that ensure the corresponding transformation succeeds and can @@ -587,13 +599,17 @@ enum class LinalgLoweringType { AffineLoops = 2, ParallelLoops = 3 }; + template <typename OpTy> struct LinalgLoweringPattern : public RewritePattern { LinalgLoweringPattern(MLIRContext *context, LinalgLoweringType loweringType, LinalgMarker marker = LinalgMarker(), + ArrayRef<unsigned> interchangeVector = {}, PatternBenefit benefit = 1) : RewritePattern(OpTy::getOperationName(), {}, benefit, context), - marker(marker), loweringType(loweringType) {} + marker(marker), loweringType(loweringType), + interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} + // TODO: Move implementation to .cpp once named ops are auto-generated. LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { @@ -603,18 +619,24 @@ struct LinalgLoweringPattern : public RewritePattern { if (failed(marker.checkAndNotify(rewriter, linalgOp))) return failure(); - if (loweringType == LinalgLoweringType::LibraryCall) { + switch (loweringType) { + case LinalgLoweringType::LibraryCall: // TODO: Move lowering to library calls here. return failure(); - } else if (loweringType == LinalgLoweringType::Loops) { - if (failed(linalgOpToLoops(rewriter, op))) + case LinalgLoweringType::Loops: + if (failed(linalgOpToLoops(rewriter, op, interchangeVector))) return failure(); - } else if (loweringType == LinalgLoweringType::AffineLoops) { - if (failed(linalgOpToAffineLoops(rewriter, op))) + break; + case LinalgLoweringType::AffineLoops: + if (failed(linalgOpToAffineLoops(rewriter, op, interchangeVector))) return failure(); - } else if (failed(linalgOpToParallelLoops(rewriter, op))) { - return failure(); + break; + case LinalgLoweringType::ParallelLoops: + if (failed(linalgOpToParallelLoops(rewriter, op, interchangeVector))) + return failure(); + break; } + rewriter.eraseOp(op); return success(); } @@ -625,6 +647,8 @@ struct LinalgLoweringPattern : public RewritePattern { /// Controls whether the pattern lowers to library calls, scf.for, affine.for /// or scf.parallel. LinalgLoweringType loweringType; + /// Permutated loop order in the generated loop nest. + SmallVector<unsigned, 4> interchangeVector; }; /// Linalg generalization patterns diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index 3a5b79176959..09b5c5ee562b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -23,7 +23,6 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" - #include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -505,10 +504,10 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, } template <typename LoopTy> -static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, - OpBuilder &builder) { +static Optional<LinalgLoops> +linalgOpToLoopsImpl(Operation *op, OpBuilder &builder, + ArrayRef<unsigned> interchangeVector) { using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy; - ScopedContext scope(builder, op->getLoc()); // The flattened loopToOperandRangesMaps is expected to be an invertible @@ -516,10 +515,20 @@ static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, auto linalgOp = cast<LinalgOp>(op); assert(linalgOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); + auto loopRanges = linalgOp.createLoopRanges(builder, op->getLoc()); + auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue()); + + if (!interchangeVector.empty()) { + assert(interchangeVector.size() == loopRanges.size()); + assert(interchangeVector.size() == iteratorTypes.size()); + applyPermutationToVector(loopRanges, interchangeVector); + applyPermutationToVector(iteratorTypes, interchangeVector); + } + SmallVector<Value, 4> allIvs; GenerateLoopNest<LoopTy>::doit( - loopRanges, /*iterInitArgs*/ {}, linalgOp.iterator_types().getValue(), + loopRanges, /*iterInitArgs=*/{}, iteratorTypes, [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector { assert(iterArgs.empty() && "unexpected iterArgs"); allIvs.append(ivs.begin(), ivs.end()); @@ -552,26 +561,33 @@ namespace { template <typename LoopType> class LinalgRewritePattern : public RewritePattern { public: - LinalgRewritePattern() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} + LinalgRewritePattern(ArrayRef<unsigned> interchangeVector) + : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()), + interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (!isa<LinalgOp>(op)) return failure(); - if (!linalgOpToLoopsImpl<LoopType>(op, rewriter)) + if (!linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector)) return failure(); rewriter.eraseOp(op); return success(); } + +private: + SmallVector<unsigned, 4> interchangeVector; }; struct FoldAffineOp; } // namespace template <typename LoopType> -static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) { +static void lowerLinalgToLoopsImpl(FuncOp funcOp, + ArrayRef<unsigned> interchangeVector) { + MLIRContext *context = funcOp.getContext(); OwningRewritePatternList patterns; - patterns.insert<LinalgRewritePattern<LoopType>>(); + patterns.insert<LinalgRewritePattern<LoopType>>(interchangeVector); DimOp::getCanonicalizationPatterns(patterns, context); AffineApplyOp::getCanonicalizationPatterns(patterns, context); patterns.insert<FoldAffineOp>(context); @@ -620,20 +636,20 @@ struct FoldAffineOp : public RewritePattern { struct LowerToAffineLoops : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> { void runOnFunction() override { - lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext()); + lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), interchangeVector); } }; struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> { void runOnFunction() override { - lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), &getContext()); + lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), interchangeVector); } }; struct LowerToParallelLoops : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> { void runOnFunction() override { - lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction(), &getContext()); + lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction(), interchangeVector); } }; } // namespace @@ -654,38 +670,43 @@ mlir::createConvertLinalgToAffineLoopsPass() { /// Emits a loop nest with the proper body for `op`. template <typename LoopTy> -Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, - Operation *op) { - return linalgOpToLoopsImpl<LoopTy>(op, builder); +Optional<LinalgLoops> +mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, Operation *op, + ArrayRef<unsigned> interchangeVector) { + return linalgOpToLoopsImpl<LoopTy>(op, builder, interchangeVector); } +template Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops<AffineForOp>( + OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector); +template Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops<scf::ForOp>( + OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector); template Optional<LinalgLoops> -mlir::linalg::linalgLowerOpToLoops<AffineForOp>(OpBuilder &builder, - Operation *op); -template Optional<LinalgLoops> -mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(OpBuilder &builder, - Operation *op); -template Optional<LinalgLoops> -mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(OpBuilder &builder, - Operation *op); +mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>( + OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector); /// Emits a loop nest of `affine.for` with the proper body for `op`. -LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, - Operation *op) { - Optional<LinalgLoops> loops = linalgLowerOpToLoops<AffineForOp>(builder, op); +LogicalResult +mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, Operation *op, + ArrayRef<unsigned> interchangeVector) { + Optional<LinalgLoops> loops = + linalgLowerOpToLoops<AffineForOp>(builder, op, interchangeVector); return loops ? success() : failure(); } /// Emits a loop nest of `scf.for` with the proper body for `op`. -LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) { - Optional<LinalgLoops> loops = linalgLowerOpToLoops<scf::ForOp>(builder, op); +LogicalResult +mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op, + ArrayRef<unsigned> interchangeVector) { + Optional<LinalgLoops> loops = + linalgLowerOpToLoops<scf::ForOp>(builder, op, interchangeVector); return loops ? success() : failure(); } /// Emits a loop nest of `scf.parallel` with the proper body for `op`. -LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, - Operation *op) { +LogicalResult +mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, Operation *op, + ArrayRef<unsigned> interchangeVector) { Optional<LinalgLoops> loops = - linalgLowerOpToLoops<scf::ParallelOp>(builder, op); + linalgLowerOpToLoops<scf::ParallelOp>(builder, op, interchangeVector); return loops ? success() : failure(); } diff --git a/mlir/test/Dialect/Linalg/loop-order.mlir b/mlir/test/Dialect/Linalg/loop-order.mlir new file mode 100644 index 000000000000..d1ff47977c35 --- /dev/null +++ b/mlir/test/Dialect/Linalg/loop-order.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt %s -convert-linalg-to-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=LOOP %s +// RUN: mlir-opt %s -convert-linalg-to-parallel-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=PARALLEL %s +// RUN: mlir-opt %s -convert-linalg-to-affine-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=AFFINE %s + +func @copy(%input: memref<1x2x3x4x5xf32>, %output: memref<1x2x3x4x5xf32>) { + linalg.copy(%input, %output): memref<1x2x3x4x5xf32>, memref<1x2x3x4x5xf32> + return +} + +// LOOP: scf.for %{{.*}} = %c0 to %c5 step %c1 +// LOOP: scf.for %{{.*}} = %c0 to %c1 step %c1 +// LOOP: scf.for %{{.*}} = %c0 to %c4 step %c1 +// LOOP: scf.for %{{.*}} = %c0 to %c2 step %c1 +// LOOP: scf.for %{{.*}} = %c0 to %c3 step %c1 + +// PARALLEL: scf.parallel +// PARALLEL-SAME: to (%c5, %c1, %c4, %c2, %c3) + +// AFFINE: affine.for %{{.*}} = 0 to 5 +// AFFINE: affine.for %{{.*}} = 0 to 1 +// AFFINE: affine.for %{{.*}} = 0 to 4 +// AFFINE: affine.for %{{.*}} = 0 to 2 +// AFFINE: affine.for %{{.*}} = 0 to 3 + _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits