https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/81240
Add a new rewrite action for "operation movements". This action can roll back `moveOpBefore` and `moveOpAfter`. `RewriterBase::moveOpBefore` and `RewriterBase::moveOpAfter` are no longer virtual. (They no longer need to be overridden: the dialect conversion can gather all information required for rollbacks from listener notifications.) >From 7503c0cb484c54249ff66c5780197d46937c660d Mon Sep 17 00:00:00 2001 From: Matthias Springer <spring...@google.com> Date: Fri, 9 Feb 2024 09:58:46 +0000 Subject: [PATCH] [mlir][Transforms] Support `moveOpBefore`/`After` in dialect conversion Add a new rewrite action for "operation movements". This action can roll back `moveOpBefore` and `moveOpAfter`. `RewriterBase::moveOpBefore` and `RewriterBase::moveOpAfter` are no longer virtual. (They no longer need to be overridden: the dialect conversion can gather all information required for rollbacks from listener notifications.) --- mlir/include/mlir/IR/PatternMatch.h | 6 +- .../mlir/Transforms/DialectConversion.h | 5 -- .../Transforms/Utils/DialectConversion.cpp | 74 +++++++++++++++---- mlir/test/Transforms/test-legalizer.mlir | 14 ++++ mlir/test/lib/Dialect/Test/TestPatterns.cpp | 20 ++++- 5 files changed, 93 insertions(+), 26 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 78dcfe7f6fc3d2..b8aeea0d23475b 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -588,8 +588,7 @@ class RewriterBase : public OpBuilder { /// Unlink this operation from its current block and insert it right before /// `iterator` in the specified block.
- virtual void moveOpBefore(Operation *op, Block *block, - Block::iterator iterator); + void moveOpBefore(Operation *op, Block *block, Block::iterator iterator); /// Unlink this operation from its current block and insert it right after /// `existingOp` which may be in the same or another block in the same @@ -598,8 +597,7 @@ class RewriterBase : public OpBuilder { /// Unlink this operation from its current block and insert it right after /// `iterator` in the specified block. - virtual void moveOpAfter(Operation *op, Block *block, - Block::iterator iterator); + void moveOpAfter(Operation *op, Block *block, Block::iterator iterator); /// Unlink this block and insert it right before `existingBlock`. void moveBlockBefore(Block *block, Block *anotherBlock); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index f061d761ecefbb..c0c702a7d34821 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -738,11 +738,6 @@ class ConversionPatternRewriter final : public PatternRewriter { // Hide unsupported pattern rewriter API. 
using OpBuilder::setListener; - void moveOpBefore(Operation *op, Block *block, - Block::iterator iterator) override; - void moveOpAfter(Operation *op, Block *block, - Block::iterator iterator) override; - std::unique_ptr<detail::ConversionPatternRewriterImpl> impl; }; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 44c107c8733f3d..ffdb069f6e9b81 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -757,7 +757,8 @@ class RewriteAction { InlineBlock, MoveBlock, SplitBlock, - BlockTypeConversion + BlockTypeConversion, + MoveOperation }; virtual ~RewriteAction() = default; @@ -970,6 +971,54 @@ class BlockTypeConversionAction : public BlockAction { void rollback() override; }; + +/// An operation rewrite. +class OperationAction : public RewriteAction { +public: + /// Return the operation that this action operates on. + Operation *getOperation() const { return op; } + + static bool classof(const RewriteAction *action) { + return action->getKind() >= Kind::MoveOperation && + action->getKind() <= Kind::MoveOperation; + } + +protected: + OperationAction(Kind kind, ConversionPatternRewriterImpl &rewriterImpl, + Operation *op) + : RewriteAction(kind, rewriterImpl), op(op) {} + + // The operation that this action operates on. + Operation *op; +}; + +/// Rewrite action that represents the moving of an operation. +class MoveOperationAction : public OperationAction { +public: + MoveOperationAction(ConversionPatternRewriterImpl &rewriterImpl, + Operation *op, Block *block, Operation *insertBeforeOp) + : OperationAction(Kind::MoveOperation, rewriterImpl, op), block(block), + insertBeforeOp(insertBeforeOp) {} + + static bool classof(const RewriteAction *action) { + return action->getKind() == Kind::MoveOperation; + } + + void rollback() override { + // Move the operation back to its original position. + Block::iterator before = + insertBeforeOp ?
Block::iterator(insertBeforeOp) : block->end(); + block->getOperations().splice(before, op->getBlock()->getOperations(), op); + } + +private: + // The block in which this operation was previously contained. + Block *block; + + // The original successor of this operation before it was moved. "nullptr" if + // this operation was the last operation in its block. + Operation *insertBeforeOp; +}; } // namespace //===----------------------------------------------------------------------===// @@ -1468,12 +1517,19 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes( void ConversionPatternRewriterImpl::notifyOperationInserted( Operation *op, OpBuilder::InsertPoint previous) { - assert(!previous.isSet() && "expected newly created op"); LLVM_DEBUG({ logger.startLine() << "** Insert : '" << op->getName() << "'(" << op << ")\n"; }); - createdOps.push_back(op); + if (!previous.isSet()) { + // This is a newly created op. + createdOps.push_back(op); + return; + } + Operation *prevOp = previous.getPoint() == previous.getBlock()->end() + ?
nullptr + : &*previous.getPoint(); + appendRewriteAction<MoveOperationAction>(op, previous.getBlock(), prevOp); } void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, @@ -1712,18 +1768,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) { rootUpdates.erase(rootUpdates.begin() + updateIdx); } -void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block, - Block::iterator iterator) { - llvm_unreachable( - "moving single ops is not supported in a dialect conversion"); -} - -void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block, - Block::iterator iterator) { - llvm_unreachable( - "moving single ops is not supported in a dialect conversion"); -} - detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { return *impl; } diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index d8cf6e4719cede..84fcc18ab7d370 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -320,3 +320,17 @@ module { return } } + +// ----- + +// CHECK-LABEL: func @test_move_op_before_rollback() +func.func @test_move_op_before_rollback() { + // CHECK: "test.one_region_op"() + // CHECK: "test.hoist_me"() + "test.one_region_op"() ({ + // expected-remark @below{{'test.hoist_me' is not legalizable}} + %0 = "test.hoist_me"() : () -> (i32) + "test.valid"(%0) : (i32) -> () + }) : () -> () + "test.return"() : () -> () +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index d7e5d6db50c1fb..1c02232b8adbb1 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -773,6 +773,22 @@ struct TestUndoBlockArgReplace : public ConversionPattern { } }; +/// This pattern hoists "test.hoist_me" ops out of their parent op and then +/// fails the conversion. This is to test the rollback logic.
+struct TestUndoMoveOpBefore : public ConversionPattern { + TestUndoMoveOpBefore(MLIRContext *ctx) + : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.moveOpBefore(op, op->getParentOp()); + // Replace with an illegal op to ensure the conversion fails. + rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type()); + return success(); + } +}; + /// A rewrite pattern that tests the undo mechanism when erasing a block. struct TestUndoBlockErase : public ConversionPattern { TestUndoBlockErase(MLIRContext *ctx) @@ -1069,7 +1085,7 @@ struct TestLegalizePatternDriver TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, TestNonRootReplacement, TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, - TestCreateUnregisteredOp>(&getContext()); + TestCreateUnregisteredOp, TestUndoMoveOpBefore>(&getContext()); patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter); @@ -1079,7 +1095,7 @@ struct TestLegalizePatternDriver ConversionTarget target(getContext()); target.addLegalOp<ModuleOp>(); target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, - TerminatorOp>(); + TerminatorOp, OneRegionOp>(); target .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits