https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/82250
This commit adds a new `ConversionConfig` struct that allows users to customize the dialect conversion. This configuration is similar to `GreedyRewriteConfig` for the greedy pattern rewrite driver. A few existing options are moved to this object, simplifying the dialect conversion API. >From 819e5f95ed0857e88972501cb10b9931b1e91a1c Mon Sep 17 00:00:00 2001 From: Matthias Springer <spring...@google.com> Date: Mon, 19 Feb 2024 14:02:14 +0000 Subject: [PATCH] [mlir][Transforms] Encapsulate dialect conversion options in `ConversionConfig` This commit adds a new `ConversionConfig` struct that allows users to customize the dialect conversion. This configuration is similar to `GreedyRewriteConfig` for the greedy pattern rewrite driver. A few existing options are moved to this object, simplifying the dialect conversion API. --- .../mlir/Transforms/DialectConversion.h | 75 ++++++---- .../Transforms/Utils/DialectConversion.cpp | 129 +++++++++--------- mlir/test/lib/Dialect/Test/TestPatterns.cpp | 14 +- 3 files changed, 118 insertions(+), 100 deletions(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 5c91a9498b35d4..8da5dcb0be3fd0 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -24,6 +24,7 @@ namespace mlir { // Forward declarations. class Attribute; class Block; +struct ConversionConfig; class ConversionPatternRewriter; class MLIRContext; class Operation; @@ -770,7 +771,8 @@ class ConversionPatternRewriter final : public PatternRewriter { /// Conversion pattern rewriters must not be used outside of dialect /// conversions. They apply some IR rewrites in a delayed fashion and could /// bring the IR into an inconsistent state when used standalone. - explicit ConversionPatternRewriter(MLIRContext *ctx); + explicit ConversionPatternRewriter(MLIRContext *ctx, + const ConversionConfig &config); // Hide unsupported pattern rewriter API.
using OpBuilder::setListener; @@ -1070,6 +1072,30 @@ class PDLConversionConfig final { #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH +//===----------------------------------------------------------------------===// +// ConversionConfig +//===----------------------------------------------------------------------===// + +/// Dialect conversion configuration. +struct ConversionConfig { + /// An optional callback used to notify about match failure diagnostics during + /// the conversion. Diagnostics are only reported to this callback may only be + /// available in debug mode. + function_ref<void(Diagnostic &)> notifyCallback = nullptr; + + /// Partial conversion only. All operations that are found not to be + /// legalizable are placed in this set. (Note that if there is an op + /// explicitly marked as illegal, the conversion terminates and the set will + /// not necessarily be complete.) + DenseSet<Operation *> *unlegalizedOps = nullptr; + + /// Analysis conversion only. All operations that are found to be legalizable + /// are placed in this set. Note that no actual rewrites are applied to the + /// IR during an analysis conversion and only pre-existing operations are + /// added to the set. + DenseSet<Operation *> *legalizableOps = nullptr; +}; + //===----------------------------------------------------------------------===// // Op Conversion Entry Points //===----------------------------------------------------------------------===// @@ -1083,19 +1109,16 @@ class PDLConversionConfig final { /// Apply a partial conversion on the given operations and all nested /// operations. This method converts as many operations to the target as /// possible, ignoring operations that failed to legalize. This method only -/// returns failure if there ops explicitly marked as illegal. If an -/// `unconvertedOps` set is provided, all operations that are found not to be -/// legalizable to the given `target` are placed within that set. 
(Note that if -/// there is an op explicitly marked as illegal, the conversion terminates and -/// the `unconvertedOps` set will not necessarily be complete.) +/// returns failure if there ops explicitly marked as illegal. LogicalResult -applyPartialConversion(ArrayRef<Operation *> ops, const ConversionTarget &target, +applyPartialConversion(ArrayRef<Operation *> ops, + const ConversionTarget &target, const FrozenRewritePatternSet &patterns, - DenseSet<Operation *> *unconvertedOps = nullptr); + ConversionConfig config = ConversionConfig()); LogicalResult applyPartialConversion(Operation *op, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, - DenseSet<Operation *> *unconvertedOps = nullptr); + ConversionConfig config = ConversionConfig()); /// Apply a complete conversion on the given operations, and all nested /// operations. This method returns failure if the conversion of any operation @@ -1103,31 +1126,27 @@ applyPartialConversion(Operation *op, const ConversionTarget &target, /// within 'ops'. LogicalResult applyFullConversion(ArrayRef<Operation *> ops, const ConversionTarget &target, - const FrozenRewritePatternSet &patterns); + const FrozenRewritePatternSet &patterns, + ConversionConfig config = ConversionConfig()); LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target, - const FrozenRewritePatternSet &patterns); + const FrozenRewritePatternSet &patterns, + ConversionConfig config = ConversionConfig()); /// Apply an analysis conversion on the given operations, and all nested /// operations. This method analyzes which operations would be successfully /// converted to the target if a conversion was applied. All operations that /// were found to be legalizable to the given 'target' are placed within the -/// provided 'convertedOps' set; note that no actual rewrites are applied to the -/// operations on success and only pre-existing operations are added to the set. 
-/// This method only returns failure if there are unreachable blocks in any of -/// the regions nested within 'ops'. There's an additional argument -/// `notifyCallback` which is used for collecting match failure diagnostics -/// generated during the conversion. Diagnostics are only reported to this -/// callback may only be available in debug mode. -LogicalResult applyAnalysisConversion( - ArrayRef<Operation *> ops, ConversionTarget &target, - const FrozenRewritePatternSet &patterns, - DenseSet<Operation *> &convertedOps, - function_ref<void(Diagnostic &)> notifyCallback = nullptr); -LogicalResult applyAnalysisConversion( - Operation *op, ConversionTarget &target, - const FrozenRewritePatternSet &patterns, - DenseSet<Operation *> &convertedOps, - function_ref<void(Diagnostic &)> notifyCallback = nullptr); +/// provided 'config.legalizableOps' set; note that no actual rewrites are +/// applied to the operations on success. This method only returns failure if +/// there are unreachable blocks in any of the regions nested within 'ops'. +LogicalResult +applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target, + const FrozenRewritePatternSet &patterns, + ConversionConfig config = ConversionConfig()); +LogicalResult +applyAnalysisConversion(Operation *op, ConversionTarget &target, + const FrozenRewritePatternSet &patterns, + ConversionConfig config = ConversionConfig()); } // namespace mlir #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_ diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 6cf178e149be7f..30fc2298b3deb3 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -224,6 +224,8 @@ class IRRewrite { /// Erase the given block (unless it was already erased). 
void eraseBlock(Block *block); + const ConversionConfig &getConfig() const; + const Kind kind; ConversionPatternRewriterImpl &rewriterImpl; }; @@ -723,9 +725,10 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) { namespace mlir { namespace detail { struct ConversionPatternRewriterImpl : public RewriterBase::Listener { - explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter) + explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter, + const ConversionConfig &config) : rewriter(rewriter), eraseRewriter(rewriter.getContext()), - notifyCallback(nullptr) {} + config(config) {} //===--------------------------------------------------------------------===// // State Management @@ -931,10 +934,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// converting the arguments of blocks within that region. DenseMap<Region *, const TypeConverter *> regionToConverter; - /// This allows the user to collect the match failure message. - function_ref<void(Diagnostic &)> notifyCallback; - - DenseSet<Operation *> *trackedOps = nullptr; + /// Dialect conversion configuration. + const ConversionConfig &config; #ifndef NDEBUG /// A set of operations that have pending updates. This tracking isn't @@ -957,6 +958,10 @@ void IRRewrite::eraseBlock(Block *block) { rewriterImpl.eraseRewriter.eraseBlock(block); } +const ConversionConfig &IRRewrite::getConfig() const { + return rewriterImpl.config; +} + void BlockTypeConversionRewrite::commit() { // Process the remapping for each of the original arguments. for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) { @@ -1074,8 +1079,8 @@ void ReplaceOperationRewrite::commit() { if (Value newValue = rewriterImpl.mapping.lookupOrNull(result, result.getType())) result.replaceAllUsesWith(newValue); - if (rewriterImpl.trackedOps) - rewriterImpl.trackedOps->erase(op); + if (getConfig().unlegalizedOps) + getConfig().unlegalizedOps->erase(op); // Do not erase the operation yet. 
It may still be referenced in `mapping`. op->getBlock()->getOperations().remove(op); } @@ -1510,8 +1515,8 @@ void ConversionPatternRewriterImpl::notifyMatchFailure( Diagnostic diag(loc, DiagnosticSeverity::Remark); reasonCallback(diag); logger.startLine() << "** Failure : " << diag.str() << "\n"; - if (notifyCallback) - notifyCallback(diag); + if (config.notifyCallback) + config.notifyCallback(diag); }); } @@ -1519,9 +1524,10 @@ void ConversionPatternRewriterImpl::notifyMatchFailure( // ConversionPatternRewriter //===----------------------------------------------------------------------===// -ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx) +ConversionPatternRewriter::ConversionPatternRewriter( + MLIRContext *ctx, const ConversionConfig &config) : PatternRewriter(ctx), - impl(new detail::ConversionPatternRewriterImpl(*this)) { + impl(new detail::ConversionPatternRewriterImpl(*this, config)) { setListener(impl.get()); } @@ -1972,12 +1978,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op, assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); LLVM_DEBUG({ logFailure(rewriterImpl.logger, "pattern failed to match"); - if (rewriterImpl.notifyCallback) { + if (rewriterImpl.config.notifyCallback) { Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark); diag << "Failed to apply pattern \"" << pattern.getDebugName() << "\" on op:\n" << *op; - rewriterImpl.notifyCallback(diag); + rewriterImpl.config.notifyCallback(diag); } }); rewriterImpl.resetState(curState); @@ -2365,14 +2371,12 @@ namespace mlir { struct OperationConverter { explicit OperationConverter(const ConversionTarget &target, const FrozenRewritePatternSet &patterns, - OpConversionMode mode, - DenseSet<Operation *> *trackedOps = nullptr) - : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {} + const ConversionConfig &config, + OpConversionMode mode) + : opLegalizer(target, patterns), config(config), mode(mode) {} /// Converts the given 
operations to the conversion target. - LogicalResult - convertOperations(ArrayRef<Operation *> ops, - function_ref<void(Diagnostic &)> notifyCallback = nullptr); + LogicalResult convertOperations(ArrayRef<Operation *> ops); private: /// Converts an operation with the given rewriter. @@ -2409,14 +2413,11 @@ struct OperationConverter { /// The legalizer to use when converting operations. OperationLegalizer opLegalizer; + /// Dialect conversion configuration. + ConversionConfig config; + /// The conversion mode to use when legalizing operations. OpConversionMode mode; - - /// A set of pre-existing operations. When mode == OpConversionMode::Analysis, - /// this is populated with ops found to be legalizable to the target. - /// When mode == OpConversionMode::Partial, this is populated with ops found - /// *not* to be legalizable to the target. - DenseSet<Operation *> *trackedOps; }; } // namespace mlir @@ -2430,28 +2431,27 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, return op->emitError() << "failed to legalize operation '" << op->getName() << "'"; // Partial conversions allow conversions to fail iff the operation was not - // explicitly marked as illegal. If the user provided a nonlegalizableOps - // set, non-legalizable ops are included. + // explicitly marked as illegal. If the user provided a `unlegalizedOps` + // set, non-legalizable ops are added to that set. if (mode == OpConversionMode::Partial) { if (opLegalizer.isIllegal(op)) return op->emitError() << "failed to legalize operation '" << op->getName() << "' that was explicitly marked illegal"; - if (trackedOps) - trackedOps->insert(op); + if (config.unlegalizedOps) + config.unlegalizedOps->insert(op); } } else if (mode == OpConversionMode::Analysis) { // Analysis conversions don't fail if any operations fail to legalize, // they are only interested in the operations that were successfully // legalized. 
- trackedOps->insert(op); + if (config.legalizableOps) + config.legalizableOps->insert(op); } return success(); } -LogicalResult OperationConverter::convertOperations( - ArrayRef<Operation *> ops, - function_ref<void(Diagnostic &)> notifyCallback) { +LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { if (ops.empty()) return success(); const ConversionTarget &target = opLegalizer.getTarget(); @@ -2472,10 +2472,8 @@ LogicalResult OperationConverter::convertOperations( } // Convert each operation and discard rewrites on failure. - ConversionPatternRewriter rewriter(ops.front()->getContext()); + ConversionPatternRewriter rewriter(ops.front()->getContext(), config); ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); - rewriterImpl.notifyCallback = notifyCallback; - rewriterImpl.trackedOps = trackedOps; for (auto *op : toConvert) if (failed(convert(rewriter, op))) @@ -3461,56 +3459,51 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) { //===----------------------------------------------------------------------===// // Partial Conversion -LogicalResult -mlir::applyPartialConversion(ArrayRef<Operation *> ops, - const ConversionTarget &target, - const FrozenRewritePatternSet &patterns, - DenseSet<Operation *> *unconvertedOps) { - OperationConverter opConverter(target, patterns, OpConversionMode::Partial, - unconvertedOps); +LogicalResult mlir::applyPartialConversion( + ArrayRef<Operation *> ops, const ConversionTarget &target, + const FrozenRewritePatternSet &patterns, ConversionConfig config) { + OperationConverter opConverter(target, patterns, config, + OpConversionMode::Partial); return opConverter.convertOperations(ops); } LogicalResult mlir::applyPartialConversion(Operation *op, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, - DenseSet<Operation *> *unconvertedOps) { - return applyPartialConversion(llvm::ArrayRef(op), target, patterns, - unconvertedOps); + ConversionConfig 
config) { + return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config); } //===----------------------------------------------------------------------===// // Full Conversion -LogicalResult -mlir::applyFullConversion(ArrayRef<Operation *> ops, const ConversionTarget &target, - const FrozenRewritePatternSet &patterns) { - OperationConverter opConverter(target, patterns, OpConversionMode::Full); +LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops, + const ConversionTarget &target, + const FrozenRewritePatternSet &patterns, + ConversionConfig config) { + OperationConverter opConverter(target, patterns, config, + OpConversionMode::Full); return opConverter.convertOperations(ops); } -LogicalResult -mlir::applyFullConversion(Operation *op, const ConversionTarget &target, - const FrozenRewritePatternSet &patterns) { - return applyFullConversion(llvm::ArrayRef(op), target, patterns); +LogicalResult mlir::applyFullConversion(Operation *op, + const ConversionTarget &target, + const FrozenRewritePatternSet &patterns, + ConversionConfig config) { + return applyFullConversion(llvm::ArrayRef(op), target, patterns, config); } //===----------------------------------------------------------------------===// // Analysis Conversion -LogicalResult -mlir::applyAnalysisConversion(ArrayRef<Operation *> ops, - ConversionTarget &target, - const FrozenRewritePatternSet &patterns, - DenseSet<Operation *> &convertedOps, - function_ref<void(Diagnostic &)> notifyCallback) { - OperationConverter opConverter(target, patterns, OpConversionMode::Analysis, - &convertedOps); - return opConverter.convertOperations(ops, notifyCallback); +LogicalResult mlir::applyAnalysisConversion( + ArrayRef<Operation *> ops, ConversionTarget &target, + const FrozenRewritePatternSet &patterns, ConversionConfig config) { + OperationConverter opConverter(target, patterns, config, + OpConversionMode::Analysis); + return opConverter.convertOperations(ops); } LogicalResult 
mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target, const FrozenRewritePatternSet &patterns, - DenseSet<Operation *> &convertedOps, - function_ref<void(Diagnostic &)> notifyCallback) { - return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, - convertedOps, notifyCallback); + ConversionConfig config) { + return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config); } diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 1c02232b8adbb1..c04d5cc80446f4 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1135,8 +1135,10 @@ struct TestLegalizePatternDriver // Handle a partial conversion. if (mode == ConversionMode::Partial) { DenseSet<Operation *> unlegalizedOps; - if (failed(applyPartialConversion( - getOperation(), target, std::move(patterns), &unlegalizedOps))) { + ConversionConfig config; + config.unlegalizedOps = &unlegalizedOps; + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns), config))) { getOperation()->emitRemark() << "applyPartialConversion failed"; } // Emit remarks for each legalizable operation. @@ -1164,8 +1166,10 @@ struct TestLegalizePatternDriver // Analyze the convertible operations. DenseSet<Operation *> legalizedOps; + ConversionConfig config; + config.legalizableOps = &legalizedOps; if (failed(applyAnalysisConversion(getOperation(), target, - std::move(patterns), legalizedOps))) + std::move(patterns), config))) return signalPassFailure(); // Emit remarks for each legalizable operation. 
@@ -1789,8 +1793,10 @@ struct TestMergeBlocksPatternDriver }); DenseSet<Operation *> unlegalizedOps; + ConversionConfig config; + config.unlegalizedOps = &unlegalizedOps; (void)applyPartialConversion(getOperation(), target, std::move(patterns), - &unlegalizedOps); + config); for (auto *op : unlegalizedOps) op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits