https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/81245
>From 4bb65218698f0104775f3eea05817c6d5228a9e7 Mon Sep 17 00:00:00 2001 From: Matthias Springer <spring...@google.com> Date: Wed, 14 Feb 2024 16:17:03 +0000 Subject: [PATCH] [mlir][Transforms][NFC] Turn in-place op modifications into `RewriteAction`s This commit simplifies the internal state of the dialect conversion. A separate field for the previous state of in-place op modifications is no longer needed. BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC --- .../mlir/Transforms/DialectConversion.h | 4 +- .../Transforms/Utils/DialectConversion.cpp | 139 +++++++++--------- 2 files changed, 68 insertions(+), 75 deletions(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 15fa39bde104b9..0d7722aa07ee38 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -744,8 +744,8 @@ class ConversionPatternRewriter final : public PatternRewriter { /// PatternRewriter hook for updating the given operation in-place. /// Note: These methods only track updates to the given operation itself, - /// and not nested regions. Updates to regions will still require - /// notification through other more specific hooks above. + /// and not nested regions. Updates to regions will still require notification + /// through other more specific hooks above. void startOpModification(Operation *op) override; /// PatternRewriter hook for updating the given operation in-place. diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 84597fb7986b07..5206a65608ba14 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -154,14 +154,12 @@ namespace { struct RewriterState { RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations, unsigned numReplacements, unsigned numArgReplacements, - unsigned numRewrites, unsigned numIgnoredOperations, - unsigned numRootUpdates) + unsigned numRewrites, unsigned numIgnoredOperations) : numCreatedOps(numCreatedOps), numUnresolvedMaterializations(numUnresolvedMaterializations), numReplacements(numReplacements), numArgReplacements(numArgReplacements), numRewrites(numRewrites), - numIgnoredOperations(numIgnoredOperations), - numRootUpdates(numRootUpdates) {} + numIgnoredOperations(numIgnoredOperations) {} /// The current number of created operations. unsigned numCreatedOps; @@ -180,44 +178,6 @@ struct RewriterState { /// The current number of ignored operations. unsigned numIgnoredOperations; - - /// The current number of operations that were updated in place. - unsigned numRootUpdates; -}; - -//===----------------------------------------------------------------------===// -// OperationTransactionState - -/// The state of an operation that was updated by a pattern in-place. This -/// contains all of the necessary information to reconstruct an operation that -/// was updated in place. -class OperationTransactionState { -public: - OperationTransactionState() = default; - OperationTransactionState(Operation *op) - : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()), - operands(op->operand_begin(), op->operand_end()), - successors(op->successor_begin(), op->successor_end()) {} - - /// Discard the transaction state and reset the state of the original - /// operation. - void resetOperation() const { - op->setLoc(loc); - op->setAttrs(attrs); - op->setOperands(operands); - for (const auto &it : llvm::enumerate(successors)) - op->setSuccessor(it.value(), it.index()); - } - - /// Return the original operation of this state. - Operation *getOperation() const { return op; } - -private: - Operation *op; - LocationAttr loc; - DictionaryAttr attrs; - SmallVector<Value, 8> operands; - SmallVector<Block *, 2> successors; }; //===----------------------------------------------------------------------===// @@ -761,7 +721,8 @@ class IRRewrite { MoveBlock, SplitBlock, BlockTypeConversion, - MoveOperation + MoveOperation, + ModifyOperation }; virtual ~IRRewrite() = default; @@ -992,7 +953,7 @@ class OperationRewrite : public IRRewrite { static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() >= Kind::MoveOperation && - rewrite->getKind() <= Kind::MoveOperation; + rewrite->getKind() <= Kind::ModifyOperation; } protected: @@ -1031,8 +992,50 @@ class MoveOperationRewrite : public OperationRewrite { // this operation was the only operation in the region. Operation *insertBeforeOp; }; + +/// In-place modification of an op. This rewrite is immediately reflected in +/// the IR. The previous state of the operation is stored in this object. +class ModifyOperationRewrite : public OperationRewrite { +public: + ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, + Operation *op) + : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op), + loc(op->getLoc()), attrs(op->getAttrDictionary()), + operands(op->operand_begin(), op->operand_end()), + successors(op->successor_begin(), op->successor_end()) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::ModifyOperation; + } + + /// Discard the transaction state and reset the state of the original + /// operation. + void rollback() override { + op->setLoc(loc); + op->setAttrs(attrs); + op->setOperands(operands); + for (const auto &it : llvm::enumerate(successors)) + op->setSuccessor(it.value(), it.index()); + } + +private: + LocationAttr loc; + DictionaryAttr attrs; + SmallVector<Value, 8> operands; + SmallVector<Block *, 2> successors; +}; } // namespace +/// Return "true" if there is an operation rewrite that matches the specified +/// rewrite type and operation among the given rewrites. +template <typename RewriteTy, typename R> +static bool hasRewrite(R &&rewrites, Operation *op) { + return any_of(std::move(rewrites), [&](auto &rewrite) { + auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get()); + return rewriteTy && rewriteTy->getOperation() == op; + }); +} + //===----------------------------------------------------------------------===// // ConversionPatternRewriterImpl //===----------------------------------------------------------------------===// @@ -1184,9 +1187,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// operation was ignored. SetVector<Operation *> ignoredOps; - /// A transaction state for each of operations that were updated in-place. - SmallVector<OperationTransactionState, 4> rootUpdates; - /// A vector of indices into `replacements` of operations that were replaced /// with values with different result types than the original operation, e.g. /// 1->N conversion of some kind. @@ -1238,10 +1238,6 @@ static void detachNestedAndErase(Operation *op) { } void ConversionPatternRewriterImpl::discardRewrites() { - // Reset any operations that were updated in place. - for (auto &state : rootUpdates) - state.resetOperation(); - undoRewrites(); // Remove any newly created ops. @@ -1316,15 +1312,10 @@ void ConversionPatternRewriterImpl::applyRewrites() { RewriterState ConversionPatternRewriterImpl::getCurrentState() { return RewriterState(createdOps.size(), unresolvedMaterializations.size(), replacements.size(), argReplacements.size(), - rewrites.size(), ignoredOps.size(), rootUpdates.size()); + rewrites.size(), ignoredOps.size()); } void ConversionPatternRewriterImpl::resetState(RewriterState state) { - // Reset any operations that were updated in place. - for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i) - rootUpdates[i].resetOperation(); - rootUpdates.resize(state.numRootUpdates); - // Reset any replaced arguments. for (BlockArgument replacedArg : llvm::drop_begin(argReplacements, state.numArgReplacements)) @@ -1750,7 +1741,7 @@ void ConversionPatternRewriter::startOpModification(Operation *op) { #ifndef NDEBUG impl->pendingRootUpdates.insert(op); #endif - impl->rootUpdates.emplace_back(op); + impl->appendRewrite<ModifyOperationRewrite>(op); } void ConversionPatternRewriter::finalizeOpModification(Operation *op) { @@ -1769,13 +1760,15 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) { "operation did not have a pending in-place update"); #endif // Erase the last update for this operation. - auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; }; - auto &rootUpdates = impl->rootUpdates; - auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp); - assert(it != rootUpdates.rend() && "no root update started on op"); - (*it).resetOperation(); - int updateIdx = std::prev(rootUpdates.rend()) - it; - rootUpdates.erase(rootUpdates.begin() + updateIdx); + auto it = llvm::find_if( + llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) { + auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get()); + return modifyRewrite && modifyRewrite->getOperation() == op; + }); + assert(it != impl->rewrites.rend() && "no root update started on op"); + (*it)->rollback(); + int updateIdx = std::prev(impl->rewrites.rend()) - it; + impl->rewrites.erase(impl->rewrites.begin() + updateIdx); } detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { @@ -2118,7 +2111,6 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, #ifndef NDEBUG assert(impl.pendingRootUpdates.empty() && "dangling root updates"); -#endif // Check that the root was either replaced or updated in place. auto replacedRoot = [&] { @@ -2127,14 +2119,12 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, [op](auto &it) { return it.first == op; }); }; auto updatedRootInPlace = [&] { - return llvm::any_of( - llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates), - [op](auto &state) { return state.getOperation() == op; }); + return hasRewrite<ModifyOperationRewrite>( + llvm::drop_begin(impl.rewrites, curState.numRewrites), op); }; - (void)replacedRoot; - (void)updatedRootInPlace; assert((replacedRoot() || updatedRootInPlace()) && "expected pattern to replace the root operation"); +#endif // NDEBUG // Legalize each of the actions registered during application. RewriterState newState = impl.getCurrentState(); @@ -2221,8 +2211,11 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations( LogicalResult OperationLegalizer::legalizePatternRootUpdates( ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, RewriterState &state, RewriterState &newState) { - for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) { - Operation *op = impl.rootUpdates[i].getOperation(); + for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { + auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get()); + if (!rewrite) + continue; + Operation *op = rewrite->getOperation(); if (failed(legalize(op, rewriter))) { LLVM_DEBUG(logFailure( impl.logger, "failed to legalize operation updated in-place '{0}'", _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits