https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/82474
>From 33f2ae9da319110ca8d2581ec6d66d2db83201cb Mon Sep 17 00:00:00 2001 From: Matthias Springer <spring...@google.com> Date: Wed, 21 Feb 2024 08:41:44 +0000 Subject: [PATCH] [mlir][Transforms] Support rolling back properties in dialect conversion The dialect conversion rolls back inplace op modifications upon failure. Rolling back modifications of op properties was not supported before this commit. --- .../Transforms/Utils/DialectConversion.cpp | 28 ++++++++++++++++++- mlir/test/Transforms/test-legalizer.mlir | 12 ++++++++ mlir/test/lib/Dialect/Test/TestPatterns.cpp | 18 +++++++++++- 3 files changed, 56 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 673bd0383809cb..5be3e6b90b5d08 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1002,12 +1002,31 @@ class ModifyOperationRewrite : public OperationRewrite { : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op), loc(op->getLoc()), attrs(op->getAttrDictionary()), operands(op->operand_begin(), op->operand_end()), - successors(op->successor_begin(), op->successor_end()) {} + successors(op->successor_begin(), op->successor_end()) { + if (OpaqueProperties prop = op->getPropertiesStorage()) { + // Make a copy of the properties. + propertiesStorage = operator new(op->getPropertiesStorageSize()); + OpaqueProperties propCopy(propertiesStorage); + op->getName().copyOpProperties(propCopy, prop); + } + } static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::ModifyOperation; } + ~ModifyOperationRewrite() override { + assert(!propertiesStorage && + "rewrite was neither committed nor rolled back"); + } + + void commit() override { + if (propertiesStorage) { + operator delete(propertiesStorage); + propertiesStorage = nullptr; + } + } + /// Discard the transaction state and reset the state of the original /// operation. 
void rollback() override { @@ -1016,6 +1035,12 @@ class ModifyOperationRewrite : public OperationRewrite { op->setOperands(operands); for (const auto &it : llvm::enumerate(successors)) op->setSuccessor(it.value(), it.index()); + if (propertiesStorage) { + OpaqueProperties prop(propertiesStorage); + op->copyProperties(prop); + operator delete(propertiesStorage); + propertiesStorage = nullptr; + } } private: @@ -1023,6 +1048,7 @@ class ModifyOperationRewrite : public OperationRewrite { DictionaryAttr attrs; SmallVector<Value, 8> operands; SmallVector<Block *, 2> successors; + void *propertiesStorage = nullptr; }; } // namespace diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 84fcc18ab7d370..62d776cd7573ee 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -334,3 +334,15 @@ func.func @test_move_op_before_rollback() { }) : () -> () "test.return"() : () -> () } + +// ----- + +// CHECK-LABEL: func @test_properties_rollback() +func.func @test_properties_rollback() { + // CHECK: test.with_properties <{a = 32 : i64, + // expected-remark @below{{op 'test.with_properties' is not legalizable}} + test.with_properties + <{a = 32 : i64, array = array<i64: 1, 2, 3, 4>, b = "foo"}> + {modify_inplace} + "test.return"() : () -> () +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 1c02232b8adbb1..57e846294f8b9f 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -806,6 +806,21 @@ struct TestUndoBlockErase : public ConversionPattern { } }; +/// A pattern that modifies a property in-place, but keeps the op illegal. 
+struct TestUndoPropertiesModification : public ConversionPattern { /* Pairs with the test_properties_rollback lit test: rewrites test.with_properties ops tagged with modify_inplace. */ + TestUndoPropertiesModification(MLIRContext *ctx) + : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {} + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const final { + if (!op->hasAttr("modify_inplace")) /* Only ops carrying the modify_inplace marker attribute are rewritten. */ + return failure(); + rewriter.modifyOpInPlace( /* Change property a to 42 via the rewriter, so it counts as an in-place modification the conversion can roll back. */ + op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); }); + return success(); /* The op is still illegal afterwards; the lit test expects the change to be undone, leaving a = 32. */ + } +}; + //===----------------------------------------------------------------------===// // Type-Conversion Rewrite Testing @@ -1085,7 +1100,8 @@ struct TestLegalizePatternDriver TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, TestNonRootReplacement, TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, - TestCreateUnregisteredOp, TestUndoMoveOpBefore>(&getContext()); + TestCreateUnregisteredOp, TestUndoMoveOpBefore, + TestUndoPropertiesModification>(&getContext()); patterns.add<TestDropOpSignatureConversion>(&getContext(), converter); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits