https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/84131
>From 0aef4b91f6aad0335e7eae2849edffd4338f4c40 Mon Sep 17 00:00:00 2001 From: Matthias Springer <spring...@google.com> Date: Fri, 8 Mar 2024 03:42:15 +0000 Subject: [PATCH] [mlir][IR] Add listener notifications for pattern begin/end --- mlir/include/mlir/IR/PatternMatch.h | 30 ++++++++-- mlir/include/mlir/Rewrite/PatternApplicator.h | 6 +- mlir/lib/Rewrite/PatternApplicator.cpp | 6 +- .../Transforms/Utils/DialectConversion.cpp | 29 +++++++--- .../Utils/GreedyPatternRewriteDriver.cpp | 57 +++++++++++-------- 5 files changed, 85 insertions(+), 43 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index e3500b3f9446d8..49544c42790d4d 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder { /// Note: This notification is not triggered when unlinking an operation. virtual void notifyOperationErased(Operation *op) {} - /// Notify the listener that the pattern failed to match the given - /// operation, and provide a callback to populate a diagnostic with the - /// reason why the failure occurred. This method allows for derived - /// listeners to optionally hook into the reason why a rewrite failed, and - /// display it to users. + /// Notify the listener that the specified pattern is about to be applied + /// at the specified root operation. + virtual void notifyPatternBegin(const Pattern &pattern, Operation *op) {} + + /// Notify the listener that a pattern application finished with the + /// specified status. "success" indicates that the pattern was applied + /// successfully. "failure" indicates that the pattern could not be + /// applied. The pattern may have communicated the reason for the failure + /// with `notifyMatchFailure`. + virtual void notifyPatternEnd(const Pattern &pattern, + LogicalResult status) {} + + /// Notify the listener that the pattern failed to match, and provide a + /// callback to populate a diagnostic with the reason why the failure + /// occurred. This method allows for derived listeners to optionally hook + /// into the reason why a rewrite failed, and display it to users. virtual void notifyMatchFailure(Location loc, function_ref<void(Diagnostic &)> reasonCallback) {} @@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder { if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener)) rewriteListener->notifyOperationErased(op); } + void notifyPatternBegin(const Pattern &pattern, Operation *op) override { + if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener)) + rewriteListener->notifyPatternBegin(pattern, op); + } + void notifyPatternEnd(const Pattern &pattern, + LogicalResult status) override { + if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener)) + rewriteListener->notifyPatternEnd(pattern, status); + } void notifyMatchFailure( Location loc, function_ref<void(Diagnostic &)> reasonCallback) override { diff --git a/mlir/include/mlir/Rewrite/PatternApplicator.h b/mlir/include/mlir/Rewrite/PatternApplicator.h index f7871f819a273b..c767bf8fee9073 100644 --- a/mlir/include/mlir/Rewrite/PatternApplicator.h +++ b/mlir/include/mlir/Rewrite/PatternApplicator.h @@ -68,9 +68,9 @@ class PatternApplicator { /// invalidate the match and try another pattern. LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, - function_ref<bool(const Pattern &)> canApply = {}, - function_ref<void(const Pattern &)> onFailure = {}, - function_ref<LogicalResult(const Pattern &)> onSuccess = {}); + std::function<bool(const Pattern &)> canApply = {}, + std::function<void(const Pattern &)> onFailure = {}, + std::function<LogicalResult(const Pattern &)> onSuccess = {}); /// Apply a cost model to the patterns within this applicator. void applyCostModel(CostModel model); diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp index ea43f8a147d479..fecfb030a77fbf 100644 --- a/mlir/lib/Rewrite/PatternApplicator.cpp +++ b/mlir/lib/Rewrite/PatternApplicator.cpp @@ -129,9 +129,9 @@ void PatternApplicator::walkAllPatterns( LogicalResult PatternApplicator::matchAndRewrite( Operation *op, PatternRewriter &rewriter, - function_ref<bool(const Pattern &)> canApply, - function_ref<void(const Pattern &)> onFailure, - function_ref<LogicalResult(const Pattern &)> onSuccess) { + std::function<bool(const Pattern &)> canApply, + std::function<void(const Pattern &)> onFailure, + std::function<LogicalResult(const Pattern &)> onSuccess) { // Before checking native patterns, first match against the bytecode. This // won't automatically perform any rewrites so there is no need to worry about // conflicts. diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index c1a261eab8487d..cd49bd121a62e5 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1856,7 +1856,8 @@ class OperationLegalizer { using LegalizationAction = ConversionTarget::LegalizationAction; OperationLegalizer(const ConversionTarget &targetInfo, - const FrozenRewritePatternSet &patterns); + const FrozenRewritePatternSet &patterns, + const ConversionConfig &config); /// Returns true if the given operation is known to be illegal on the target. bool isIllegal(Operation *op) const; @@ -1948,12 +1949,16 @@ class OperationLegalizer { /// The pattern applicator to use for conversions. PatternApplicator applicator; + + /// Dialect conversion configuration. + const ConversionConfig &config; }; } // namespace OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo, - const FrozenRewritePatternSet &patterns) - : target(targetInfo), applicator(patterns) { + const FrozenRewritePatternSet &patterns, + const ConversionConfig &config) + : target(targetInfo), applicator(patterns), config(config) { // The set of patterns that can be applied to illegal operations to transform // them into legal ones. DenseMap<OperationName, LegalizationPatterns> legalizerPatterns; @@ -2098,7 +2103,10 @@ OperationLegalizer::legalizeWithPattern(Operation *op, // Functor that returns if the given pattern may be applied. auto canApply = [&](const Pattern &pattern) { - return canApplyPattern(op, pattern, rewriter); + bool canApply = canApplyPattern(op, pattern, rewriter); + if (canApply && config.listener) + config.listener->notifyPatternBegin(pattern, op); + return canApply; }; // Functor that cleans up the rewriter state after a pattern failed to match. @@ -2115,6 +2123,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op, rewriterImpl.config.notifyCallback(diag); } }); + if (config.listener) + config.listener->notifyPatternEnd(pattern, failure()); rewriterImpl.resetState(curState); appliedPatterns.erase(&pattern); }; @@ -2127,6 +2137,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op, appliedPatterns.erase(&pattern); if (failed(result)) rewriterImpl.resetState(curState); + if (config.listener) + config.listener->notifyPatternEnd(pattern, result); return result; }; @@ -2502,7 +2514,8 @@ struct OperationConverter { const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode) - : opLegalizer(target, patterns), config(config), mode(mode) {} + : config(config), opLegalizer(target, patterns, this->config), + mode(mode) {} /// Converts the given operations to the conversion target. LogicalResult convertOperations(ArrayRef<Operation *> ops); @@ -2539,12 +2552,12 @@ struct OperationConverter { ConversionPatternRewriterImpl &rewriterImpl, const DenseMap<Value, SmallVector<Value>> &inverseMapping); - /// The legalizer to use when converting operations. - OperationLegalizer opLegalizer; - /// Dialect conversion configuration. ConversionConfig config; + /// The legalizer to use when converting operations. + OperationLegalizer opLegalizer; + /// The conversion mode to use when legalizing operations. OpConversionMode mode; }; diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 51d2f5e01b7235..3b42516e040013 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -562,30 +562,39 @@ bool GreedyPatternRewriteDriver::processWorklist() { // Try to match one of the patterns. The rewriter is automatically // notified of any necessary changes, so there is nothing else to do // here. + std::function<bool(const Pattern &)> canApply = nullptr; + std::function<void(const Pattern &)> onFailure = nullptr; + std::function<LogicalResult(const Pattern &)> onSuccess = nullptr; + bool debugBuild = false; #ifndef NDEBUG - auto canApply = [&](const Pattern &pattern) { - LLVM_DEBUG({ - logger.getOStream() << "\n"; - logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '" - << op->getName() << " -> ("; - llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream()); - logger.getOStream() << ")' {\n"; - logger.indent(); - }); - return true; - }; - auto onFailure = [&](const Pattern &pattern) { - LLVM_DEBUG(logResult("failure", "pattern failed to match")); - }; - auto onSuccess = [&](const Pattern &pattern) { - LLVM_DEBUG(logResult("success", "pattern applied successfully")); - return success(); - }; -#else - function_ref<bool(const Pattern &)> canApply = {}; - function_ref<void(const Pattern &)> onFailure = {}; - function_ref<LogicalResult(const Pattern &)> onSuccess = {}; -#endif + debugBuild = true; +#endif // NDEBUG + if (debugBuild || config.listener) { + canApply = [&](const Pattern &pattern) { + LLVM_DEBUG({ + logger.getOStream() << "\n"; + logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '" + << op->getName() << " -> ("; + llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream()); + logger.getOStream() << ")' {\n"; + logger.indent(); + }); + if (config.listener) + config.listener->notifyPatternBegin(pattern, op); + return true; + }; + onFailure = [&](const Pattern &pattern) { + LLVM_DEBUG(logResult("failure", "pattern failed to match")); + if (config.listener) + config.listener->notifyPatternEnd(pattern, failure()); + }; + onSuccess = [&](const Pattern &pattern) { + LLVM_DEBUG(logResult("success", "pattern applied successfully")); + if (config.listener) + config.listener->notifyPatternEnd(pattern, success()); + return success(); + }; + } #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS if (config.scope) { @@ -731,7 +740,7 @@ void GreedyPatternRewriteDriver::notifyMatchFailure( LLVM_DEBUG({ Diagnostic diag(loc, DiagnosticSeverity::Remark); reasonCallback(diag); - logger.startLine() << "** Failure : " << diag.str() << "\n"; + logger.startLine() << "** Match Failure : " << diag.str() << "\n"; }); if (config.listener) config.listener->notifyMatchFailure(loc, reasonCallback); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits