Author: Alex Zinenko Date: 2020-11-23T23:28:02+01:00 New Revision: f7d033f4d80f476246a70f165e7455639818f907
URL: https://github.com/llvm/llvm-project/commit/f7d033f4d80f476246a70f165e7455639818f907 DIFF: https://github.com/llvm/llvm-project/commit/f7d033f4d80f476246a70f165e7455639818f907.diff LOG: [mlir] Support WsLoopOp in OpenMP to LLVM dialect conversion This is a simple conversion that only requires changing the region argument types, so the existing ParallelOp pattern is generalized to handle it. Reviewed By: kiranchandramohan Differential Revision: https://reviews.llvm.org/D91989 Added: Modified: mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir Removed: ################################################################################ diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index cfb553da407c..91e97ca1ec50 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -16,18 +16,23 @@ using namespace mlir; namespace { -struct ParallelOpConversion : public ConvertToLLVMPattern { - explicit ParallelOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(omp::ParallelOp::getOperationName(), context, +/// A pattern that converts the region arguments in a single-region OpenMP +/// operation to the LLVM dialect. The body of the region is not modified and is +/// expected to either be processed by the conversion infrastructure or already +/// contain ops compatible with LLVM dialect types.
+template <typename OpType> +struct RegionOpConversion : public ConvertToLLVMPattern { + explicit RegionOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(OpType::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { - auto curOp = cast<omp::ParallelOp>(op); - auto newOp = rewriter.create<omp::ParallelOp>(curOp.getLoc(), TypeRange(), - operands, curOp.getAttrs()); + auto curOp = cast<OpType>(op); + auto newOp = rewriter.create<OpType>(curOp.getLoc(), TypeRange(), operands, + curOp.getAttrs()); rewriter.inlineRegionBefore(curOp.region(), newOp.region(), newOp.region().end()); if (failed(rewriter.convertRegionTypes(&newOp.region(), typeConverter))) @@ -42,7 +47,8 @@ struct ParallelOpConversion : public ConvertToLLVMPattern { void mlir::populateOpenMPToLLVMConversionPatterns( MLIRContext *context, LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { - patterns.insert<ParallelOpConversion>(context, converter); + patterns.insert<RegionOpConversion<omp::ParallelOp>, + RegionOpConversion<omp::WsLoopOp>>(context, converter); } namespace { @@ -63,8 +69,8 @@ void ConvertOpenMPToLLVMPass::runOnOperation() { populateOpenMPToLLVMConversionPatterns(context, converter, patterns); LLVMConversionTarget target(getContext()); - target.addDynamicallyLegalOp<omp::ParallelOp>( - [&](omp::ParallelOp op) { return converter.isLegal(&op.getRegion()); }); + target.addDynamicallyLegalOp<omp::ParallelOp, omp::WsLoopOp>( + [&](Operation *op) { return converter.isLegal(&op->getRegion(0)); }); target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp, omp::BarrierOp, omp::TaskwaitOp>(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir index 
d38a6ea7e3a9..62ea39f078b2 100644 --- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir @@ -28,3 +28,22 @@ func @branch_loop() { } return } + +// CHECK-LABEL: @wsloop +// CHECK: (%[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64, %[[ARG3:.*]]: !llvm.i64, %[[ARG4:.*]]: !llvm.i64, %[[ARG5:.*]]: !llvm.i64) +func @wsloop(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) { + // CHECK: omp.parallel + omp.parallel { + // CHECK: omp.wsloop + // CHECK: (%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]]) + "omp.wsloop"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) ( { + // CHECK: ^{{.*}}(%[[ARG6:.*]]: !llvm.i64, %[[ARG7:.*]]: !llvm.i64): + ^bb0(%arg6: index, %arg7: index): // no predecessors + // CHECK: "test.payload"(%[[ARG6]], %[[ARG7]]) : (!llvm.i64, !llvm.i64) -> () + "test.payload"(%arg6, %arg7) : (index, index) -> () + omp.yield + }) {operand_segment_sizes = dense<[2, 2, 2, 0, 0, 0, 0, 0, 0]> : vector<9xi32>} : (index, index, index, index, index, index) -> () + omp.terminator + } + return +} _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits