https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/84131
>From a65d640a0ca2c6810da0878ed42db39f27cebfe1 Mon Sep 17 00:00:00 2001 From: Matthias Springer <spring...@google.com> Date: Fri, 8 Mar 2024 07:19:33 +0000 Subject: [PATCH] [mlir][IR] Add listener notifications for pattern begin/end --- mlir/include/mlir/IR/PatternMatch.h | 30 ++++++++++++++--- .../Transforms/Utils/DialectConversion.cpp | 29 +++++++++++----- .../Utils/GreedyPatternRewriteDriver.cpp | 33 +++++++++++++------ 3 files changed, 69 insertions(+), 23 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index e3500b3f9446d8..49544c42790d4d 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder { /// Note: This notification is not triggered when unlinking an operation. virtual void notifyOperationErased(Operation *op) {} - /// Notify the listener that the pattern failed to match the given - /// operation, and provide a callback to populate a diagnostic with the - /// reason why the failure occurred. This method allows for derived - /// listeners to optionally hook into the reason why a rewrite failed, and - /// display it to users. + /// Notify the listener that the specified pattern is about to be applied + /// at the specified root operation. + virtual void notifyPatternBegin(const Pattern &pattern, Operation *op) {} + + /// Notify the listener that a pattern application finished with the + /// specified status. "success" indicates that the pattern was applied + /// successfully. "failure" indicates that the pattern could not be + /// applied. The pattern may have communicated the reason for the failure + /// with `notifyMatchFailure`. + virtual void notifyPatternEnd(const Pattern &pattern, + LogicalResult status) {} + + /// Notify the listener that the pattern failed to match, and provide a + /// callback to populate a diagnostic with the reason why the failure + /// occurred. This method allows for derived listeners to optionally hook + /// into the reason why a rewrite failed, and display it to users. virtual void notifyMatchFailure(Location loc, function_ref<void(Diagnostic &)> reasonCallback) {} @@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder { if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener)) rewriteListener->notifyOperationErased(op); } + void notifyPatternBegin(const Pattern &pattern, Operation *op) override { + if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener)) + rewriteListener->notifyPatternBegin(pattern, op); + } + void notifyPatternEnd(const Pattern &pattern, + LogicalResult status) override { + if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener)) + rewriteListener->notifyPatternEnd(pattern, status); + } void notifyMatchFailure( Location loc, function_ref<void(Diagnostic &)> reasonCallback) override { diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index c1a261eab8487d..cd49bd121a62e5 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1856,7 +1856,8 @@ class OperationLegalizer { using LegalizationAction = ConversionTarget::LegalizationAction; OperationLegalizer(const ConversionTarget &targetInfo, - const FrozenRewritePatternSet &patterns); + const FrozenRewritePatternSet &patterns, + const ConversionConfig &config); /// Returns true if the given operation is known to be illegal on the target. bool isIllegal(Operation *op) const; @@ -1948,12 +1949,16 @@ class OperationLegalizer { /// The pattern applicator to use for conversions. PatternApplicator applicator; + + /// Dialect conversion configuration. + const ConversionConfig &config; }; } // namespace OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo, - const FrozenRewritePatternSet &patterns) - : target(targetInfo), applicator(patterns) { + const FrozenRewritePatternSet &patterns, + const ConversionConfig &config) + : target(targetInfo), applicator(patterns), config(config) { // The set of patterns that can be applied to illegal operations to transform // them into legal ones. DenseMap<OperationName, LegalizationPatterns> legalizerPatterns; @@ -2098,7 +2103,10 @@ OperationLegalizer::legalizeWithPattern(Operation *op, // Functor that returns if the given pattern may be applied. auto canApply = [&](const Pattern &pattern) { - return canApplyPattern(op, pattern, rewriter); + bool canApply = canApplyPattern(op, pattern, rewriter); + if (canApply && config.listener) + config.listener->notifyPatternBegin(pattern, op); + return canApply; }; // Functor that cleans up the rewriter state after a pattern failed to match. @@ -2115,6 +2123,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op, rewriterImpl.config.notifyCallback(diag); } }); + if (config.listener) + config.listener->notifyPatternEnd(pattern, failure()); rewriterImpl.resetState(curState); appliedPatterns.erase(&pattern); }; @@ -2127,6 +2137,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op, appliedPatterns.erase(&pattern); if (failed(result)) rewriterImpl.resetState(curState); + if (config.listener) + config.listener->notifyPatternEnd(pattern, result); return result; }; @@ -2502,7 +2514,8 @@ struct OperationConverter { const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode) - : opLegalizer(target, patterns), config(config), mode(mode) {} + : config(config), opLegalizer(target, patterns, this->config), + mode(mode) {} /// Converts the given operations to the conversion target. LogicalResult convertOperations(ArrayRef<Operation *> ops); @@ -2539,12 +2552,12 @@ struct OperationConverter { ConversionPatternRewriterImpl &rewriterImpl, const DenseMap<Value, SmallVector<Value>> &inverseMapping); - /// The legalizer to use when converting operations. - OperationLegalizer opLegalizer; - /// Dialect conversion configuration. ConversionConfig config; + /// The legalizer to use when converting operations. + OperationLegalizer opLegalizer; + /// The conversion mode to use when legalizing operations. OpConversionMode mode; }; diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 51d2f5e01b7235..6cb5635e68c922 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -562,8 +562,7 @@ bool GreedyPatternRewriteDriver::processWorklist() { // Try to match one of the patterns. The rewriter is automatically // notified of any necessary changes, so there is nothing else to do // here. -#ifndef NDEBUG - auto canApply = [&](const Pattern &pattern) { + auto canApplyCallback = [&](const Pattern &pattern) { LLVM_DEBUG({ logger.getOStream() << "\n"; logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '" @@ -572,20 +571,34 @@ bool GreedyPatternRewriteDriver::processWorklist() { logger.getOStream() << ")' {\n"; logger.indent(); }); + if (config.listener) + config.listener->notifyPatternBegin(pattern, op); return true; }; - auto onFailure = [&](const Pattern &pattern) { + function_ref<bool(const Pattern &)> canApply = canApplyCallback; + auto onFailureCallback = [&](const Pattern &pattern) { LLVM_DEBUG(logResult("failure", "pattern failed to match")); + if (config.listener) + config.listener->notifyPatternEnd(pattern, failure()); }; - auto onSuccess = [&](const Pattern &pattern) { + function_ref<void(const Pattern &)> onFailure = onFailureCallback; + auto onSuccessCallback = [&](const Pattern &pattern) { LLVM_DEBUG(logResult("success", "pattern applied successfully")); + if (config.listener) + config.listener->notifyPatternEnd(pattern, success()); return success(); }; -#else - function_ref<bool(const Pattern &)> canApply = {}; - function_ref<void(const Pattern &)> onFailure = {}; - function_ref<LogicalResult(const Pattern &)> onSuccess = {}; -#endif + function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback; + +#ifdef NDEBUG + // Optimization: PatternApplicator callbacks are not needed when running in + // optimized mode and without a listener. + if (!config.listener) { + canApply = nullptr; + onFailure = nullptr; + onSuccess = nullptr; + } +#endif // NDEBUG #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS if (config.scope) { @@ -731,7 +744,7 @@ void GreedyPatternRewriteDriver::notifyMatchFailure( LLVM_DEBUG({ Diagnostic diag(loc, DiagnosticSeverity::Remark); reasonCallback(diag); - logger.startLine() << "** Failure : " << diag.str() << "\n"; + logger.startLine() << "** Match Failure : " << diag.str() << "\n"; }); if (config.listener) config.listener->notifyMatchFailure(loc, reasonCallback); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits