llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit is a refactoring of the dialect conversion. The dialect conversion 
maintains a list of "IR rewrites" that can be committed (upon success) or 
rolled back (upon failure).

Until now, the dialect conversion kept track of "op creation" in separate 
internal data structures. This commit turns "op creation" into an `IRRewrite` 
that can be committed and rolled back just like any other rewrite. This commit 
simplifies the internal state of the dialect conversion.

---
Full diff: https://github.com/llvm/llvm-project/pull/81759.diff


1 File Affected:

- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+66-38) 


``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a07c8a56822de5..2bb56d5df1ebd3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -152,17 +152,12 @@ namespace {
 /// This class contains a snapshot of the current conversion rewriter state.
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
-  RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
-                unsigned numRewrites, unsigned numIgnoredOperations,
-                unsigned numErased)
-      : numCreatedOps(numCreatedOps),
-        numUnresolvedMaterializations(numUnresolvedMaterializations),
+  RewriterState(unsigned numUnresolvedMaterializations, unsigned numRewrites,
+                unsigned numIgnoredOperations, unsigned numErased)
+      : numUnresolvedMaterializations(numUnresolvedMaterializations),
         numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
         numErased(numErased) {}
 
-  /// The current number of created operations.
-  unsigned numCreatedOps;
-
   /// The current number of unresolved materializations.
   unsigned numUnresolvedMaterializations;
 
@@ -299,7 +294,8 @@ class IRRewrite {
     ReplaceBlockArg,
     MoveOperation,
     ModifyOperation,
-    ReplaceOperation
+    ReplaceOperation,
+    CreateOperation
   };
 
   virtual ~IRRewrite() = default;
@@ -372,7 +368,11 @@ class CreateBlockRewrite : public BlockRewrite {
     auto &blockOps = block->getOperations();
     while (!blockOps.empty())
       blockOps.remove(blockOps.begin());
-    eraseBlock(block);
+    if (block->getParent()) {
+      eraseBlock(block);
+    } else {
+      delete block;
+    }
   }
 };
 
@@ -602,7 +602,7 @@ class OperationRewrite : public IRRewrite {
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() >= Kind::MoveOperation &&
-           rewrite->getKind() <= Kind::ReplaceOperation;
+           rewrite->getKind() <= Kind::CreateOperation;
   }
 
 protected:
@@ -708,6 +708,19 @@ class ReplaceOperationRewrite : public OperationRewrite {
   /// 1->N conversion of some kind.
   bool changedResults;
 };
+
+class CreateOperationRewrite : public OperationRewrite {
+public:
+  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                         Operation *op)
+      : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::CreateOperation;
+  }
+
+  void rollback() override;
+};
 } // namespace
 
 /// Return "true" if there is an operation rewrite that matches the specified
@@ -925,9 +938,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
 
-  /// Ordered vector of all of the newly created operations during conversion.
-  SmallVector<Operation *> createdOps;
-
   /// Ordered vector of all unresolved type conversion materializations during
   /// conversion.
   SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
@@ -1110,7 +1120,18 @@ void ReplaceOperationRewrite::rollback() {
 
 void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
 
+void CreateOperationRewrite::rollback() {
+  for (Region &region : op->getRegions()) {
+    while (!region.getBlocks().empty())
+      region.getBlocks().remove(region.getBlocks().begin());
+  }
+  op->dropAllUses();
+  eraseOp(op);
+}
+
 void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
+  // if (erasedIR.erasedOps.contains(op)) return;
+
   for (Region &region : op->getRegions()) {
     for (Block &block : region.getBlocks()) {
       while (!block.getOperations().empty())
@@ -1127,8 +1148,6 @@ void ConversionPatternRewriterImpl::discardRewrites() {
   // Remove any newly created ops.
   for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
     detachNestedAndErase(materialization.getOp());
-  for (auto *op : llvm::reverse(createdOps))
-    detachNestedAndErase(op);
 }
 
 void ConversionPatternRewriterImpl::applyRewrites() {
@@ -1148,9 +1167,8 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 // State Management
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
-  return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
-                       rewrites.size(), ignoredOps.size(),
-                       eraseRewriter.erased.size());
+  return RewriterState(unresolvedMaterializations.size(), rewrites.size(),
+                       ignoredOps.size(), eraseRewriter.erased.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1171,12 +1189,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
     detachNestedAndErase(op);
   }
 
-  // Pop all of the newly created operations.
-  while (createdOps.size() != state.numCreatedOps) {
-    detachNestedAndErase(createdOps.back());
-    createdOps.pop_back();
-  }
-
   // Pop all of the recorded ignored operations that are no longer valid.
   while (ignoredOps.size() != state.numIgnoredOperations)
     ignoredOps.pop_back();
@@ -1460,7 +1472,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
   });
   if (!previous.isSet()) {
     // This is a newly created op.
-    createdOps.push_back(op);
+    appendRewrite<CreateOperationRewrite>(op);
     return;
   }
   Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
@@ -1961,13 +1973,16 @@ OperationLegalizer::legalizeWithFold(Operation *op,
   rewriter.replaceOp(op, replacementValues);
 
   // Recursively legalize any new constant operations.
-  for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
+  for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
        i != e; ++i) {
-    Operation *cstOp = rewriterImpl.createdOps[i];
-    if (failed(legalize(cstOp, rewriter))) {
+    auto *createOp =
+        dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
+    if (!createOp)
+      continue;
+    if (failed(legalize(createOp->getOperation(), rewriter))) {
       LLVM_DEBUG(logFailure(rewriterImpl.logger,
                             "failed to legalize generated constant '{0}'",
-                            cstOp->getName()));
+                            createOp->getOperation()->getName()));
       rewriterImpl.resetState(curState);
       return failure();
     }
@@ -2112,9 +2127,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
     // blocks in regions created by this pattern will already be legalized later
     // on. If we haven't built the set yet, build it now.
     if (operationsToIgnore.empty()) {
-      auto createdOps = ArrayRef<Operation *>(impl.createdOps)
-                            .drop_front(state.numCreatedOps);
-      operationsToIgnore.insert(createdOps.begin(), createdOps.end());
+      for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
+           ++i) {
+        auto *createOp =
+            dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
+        if (!createOp)
+          continue;
+        operationsToIgnore.insert(createOp->getOperation());
+      }
     }
 
     // If this operation should be considered for re-legalization, try it.
@@ -2132,8 +2152,11 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
     RewriterState &state, RewriterState &newState) {
-  for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
-    Operation *op = impl.createdOps[i];
+  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
+    auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
+    if (!createOp)
+      continue;
+    Operation *op = createOp->getOperation();
     if (failed(legalize(op, rewriter))) {
       LLVM_DEBUG(logFailure(impl.logger,
                             "failed to legalize generated operation '{0}'({1})",
@@ -2563,10 +2586,15 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
     });
     return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
   };
-  for (auto &r : rewriterImpl.rewrites)
-    if (auto *rewrite = dyn_cast<BlockTypeConversionRewrite>(r.get()))
-      if (failed(rewrite->materializeLiveConversions(findLiveUser)))
+  // Note: `rewrites` may be reallocated as the loop is running.
+  for (int64_t i = 0; i < rewriterImpl.rewrites.size(); ++i) {
+    auto &rewrite = rewriterImpl.rewrites[i];
+    if (auto *blockTypeConversionRewrite =
+            dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
+      if (failed(blockTypeConversionRewrite->materializeLiveConversions(
+              findLiveUser)))
         return failure();
+  }
   return success();
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/81759
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to