llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) <details> <summary>Changes</summary> This commit simplifies the internal state of the dialect conversion. A separate field for the previous state of in-place op modifications is no longer needed. --- Full diff: https://github.com/llvm/llvm-project/pull/81245.diff 1 Files Affected: - (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+58-70) ``````````diff diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index ffdb069f6e9b8..d0114a148cd37 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -154,15 +154,13 @@ namespace { struct RewriterState { RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations, unsigned numReplacements, unsigned numArgReplacements, - unsigned numRewriteActions, unsigned numIgnoredOperations, - unsigned numRootUpdates) + unsigned numRewriteActions, unsigned numIgnoredOperations) : numCreatedOps(numCreatedOps), numUnresolvedMaterializations(numUnresolvedMaterializations), numReplacements(numReplacements), numArgReplacements(numArgReplacements), numRewriteActions(numRewriteActions), - numIgnoredOperations(numIgnoredOperations), - numRootUpdates(numRootUpdates) {} + numIgnoredOperations(numIgnoredOperations) {} /// The current number of created operations. unsigned numCreatedOps; @@ -181,44 +179,6 @@ struct RewriterState { /// The current number of ignored operations. unsigned numIgnoredOperations; - - /// The current number of operations that were updated in place. - unsigned numRootUpdates; -}; - -//===----------------------------------------------------------------------===// -// OperationTransactionState - -/// The state of an operation that was updated by a pattern in-place. This -/// contains all of the necessary information to reconstruct an operation that -/// was updated in place. -class OperationTransactionState { -public: - OperationTransactionState() = default; - OperationTransactionState(Operation *op) - : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()), - operands(op->operand_begin(), op->operand_end()), - successors(op->successor_begin(), op->successor_end()) {} - - /// Discard the transaction state and reset the state of the original - /// operation. - void resetOperation() const { - op->setLoc(loc); - op->setAttrs(attrs); - op->setOperands(operands); - for (const auto &it : llvm::enumerate(successors)) - op->setSuccessor(it.value(), it.index()); - } - - /// Return the original operation of this state. - Operation *getOperation() const { return op; } - -private: - Operation *op; - LocationAttr loc; - DictionaryAttr attrs; - SmallVector<Value, 8> operands; - SmallVector<Block *, 2> successors; }; //===----------------------------------------------------------------------===// @@ -758,7 +718,8 @@ class RewriteAction { MoveBlock, SplitBlock, BlockTypeConversion, - MoveOperation + MoveOperation, + ModifyOperation }; virtual ~RewriteAction() = default; @@ -980,7 +941,7 @@ class OperationAction : public RewriteAction { static bool classof(const RewriteAction *action) { return action->getKind() >= Kind::MoveOperation && - action->getKind() <= Kind::MoveOperation; + action->getKind() <= Kind::ModifyOperation; } protected: @@ -1019,6 +980,34 @@ class MoveOperationAction : public OperationAction { // this operation was the only operation in the region. Operation *insertBeforeOp; }; + +/// Rewrite action that represents the in-place modification of an operation. +/// The previous state of the operation is stored in this action. +class ModifyOperationAction : public OperationAction { +public: + ModifyOperationAction(ConversionPatternRewriterImpl &rewriterImpl, + Operation *op) + : OperationAction(Kind::ModifyOperation, rewriterImpl, op), + loc(op->getLoc()), attrs(op->getAttrDictionary()), + operands(op->operand_begin(), op->operand_end()), + successors(op->successor_begin(), op->successor_end()) {} + + /// Discard the transaction state and reset the state of the original + /// operation. + void rollback() override { + op->setLoc(loc); + op->setAttrs(attrs); + op->setOperands(operands); + for (const auto &it : llvm::enumerate(successors)) + op->setSuccessor(it.value(), it.index()); + } + +private: + LocationAttr loc; + DictionaryAttr attrs; + SmallVector<Value, 8> operands; + SmallVector<Block *, 2> successors; +}; } // namespace //===----------------------------------------------------------------------===// @@ -1172,9 +1161,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// operation was ignored. SetVector<Operation *> ignoredOps; - /// A transaction state for each of operations that were updated in-place. - SmallVector<OperationTransactionState, 4> rootUpdates; - /// A vector of indices into `replacements` of operations that were replaced /// with values with different result types than the original operation, e.g. /// 1->N conversion of some kind. @@ -1226,10 +1212,6 @@ static void detachNestedAndErase(Operation *op) { } void ConversionPatternRewriterImpl::discardRewrites() { - // Reset any operations that were updated in place. - for (auto &state : rootUpdates) - state.resetOperation(); - undoRewriteActions(); // Remove any newly created ops. @@ -1304,16 +1286,10 @@ void ConversionPatternRewriterImpl::applyRewrites() { RewriterState ConversionPatternRewriterImpl::getCurrentState() { return RewriterState(createdOps.size(), unresolvedMaterializations.size(), replacements.size(), argReplacements.size(), - rewriteActions.size(), ignoredOps.size(), - rootUpdates.size()); + rewriteActions.size(), ignoredOps.size()); } void ConversionPatternRewriterImpl::resetState(RewriterState state) { - // Reset any operations that were updated in place. - for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i) - rootUpdates[i].resetOperation(); - rootUpdates.resize(state.numRootUpdates); - // Reset any replaced arguments. for (BlockArgument replacedArg : llvm::drop_begin(argReplacements, state.numArgReplacements)) @@ -1740,7 +1716,7 @@ void ConversionPatternRewriter::startOpModification(Operation *op) { #ifndef NDEBUG impl->pendingRootUpdates.insert(op); #endif - impl->rootUpdates.emplace_back(op); + impl->appendRewriteAction<ModifyOperationAction>(op); } void ConversionPatternRewriter::finalizeOpModification(Operation *op) { @@ -1759,13 +1735,17 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) { "operation did not have a pending in-place update"); #endif // Erase the last update for this operation. - auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; }; - auto &rootUpdates = impl->rootUpdates; - auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp); - assert(it != rootUpdates.rend() && "no root update started on op"); - (*it).resetOperation(); - int updateIdx = std::prev(rootUpdates.rend()) - it; - rootUpdates.erase(rootUpdates.begin() + updateIdx); + auto it = + llvm::find_if(llvm::reverse(impl->rewriteActions), + [&](std::unique_ptr<RewriteAction> &action) { + auto *modifyAction = + dynamic_cast<ModifyOperationAction *>(action.get()); + return modifyAction && modifyAction->getOperation() == op; + }); + assert(it != impl->rewriteActions.rend() && "no root update started on op"); + (*it)->rollback(); + int updateIdx = std::prev(impl->rewriteActions.rend()) - it; + impl->rewriteActions.erase(impl->rewriteActions.begin() + updateIdx); } detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { @@ -2118,8 +2098,11 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, }; auto updatedRootInPlace = [&] { return llvm::any_of( - llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates), - [op](auto &state) { return state.getOperation() == op; }); + llvm::drop_begin(impl.rewriteActions, curState.numRewriteActions), + [op](auto &action) { + auto *modifyAction = dyn_cast<ModifyOperationAction>(action.get()); + return modifyAction && modifyAction->getOperation() == op; + }); }; (void)replacedRoot; (void)updatedRootInPlace; @@ -2213,8 +2196,13 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations( LogicalResult OperationLegalizer::legalizePatternRootUpdates( ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, RewriterState &state, RewriterState &newState) { - for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) { - Operation *op = impl.rootUpdates[i].getOperation(); + for (int i = state.numRewriteActions, e = newState.numRewriteActions; i != e; + ++i) { + auto *action = + dyn_cast<ModifyOperationAction>(impl.rewriteActions[i].get()); + if (!action) + continue; + Operation *op = action->getOperation(); if (failed(legalize(op, rewriter))) { LLVM_DEBUG(logFailure( impl.logger, "failed to legalize operation updated in-place '{0}'", `````````` </details> https://github.com/llvm/llvm-project/pull/81245 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits