llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

The dialect conversion uses a `SingleEraseRewriter` to ensure that an op/block 
is not erased twice. This can happen during the "commit" phase when an 
unresolved materialization is inserted into a block and the enclosing op is 
erased by the user. In that case, the unresolved materialization should not be 
erased a second time later in the "commit" phase.

This problem cannot happen during "rollback", so ops/blocks can be erased 
directly without using the rewriter. With this change, the 
`SingleEraseRewriter` is used only during "commit"/"cleanup". At that point, 
the dialect conversion is guaranteed to succeed and no rollback can happen. 
Therefore, it is not necessary to store the number of erased IR objects 
(because we will never "reset" the rewriter to a previous state).


---
Full diff: https://github.com/llvm/llvm-project/pull/83423.diff


1 Files Affected:

- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+8-14) 


``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index cac990d498d7d3..9f6468402686bd 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -153,9 +153,9 @@ namespace {
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
   RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
-                unsigned numErased, unsigned numReplacedOps)
+                unsigned numReplacedOps)
       : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
-        numErased(numErased), numReplacedOps(numReplacedOps) {}
+        numReplacedOps(numReplacedOps) {}
 
   /// The current number of rewrites performed.
   unsigned numRewrites;
@@ -163,9 +163,6 @@ struct RewriterState {
   /// The current number of ignored operations.
   unsigned numIgnoredOperations;
 
-  /// The current number of erased operations/blocks.
-  unsigned numErased;
-
   /// The current number of replaced ops that are scheduled for erasure.
   unsigned numReplacedOps;
 };
@@ -273,8 +270,9 @@ class CreateBlockRewrite : public BlockRewrite {
     auto &blockOps = block->getOperations();
     while (!blockOps.empty())
       blockOps.remove(blockOps.begin());
+    block->dropAllUses();
     if (block->getParent())
-      eraseBlock(block);
+      block->erase();
     else
       delete block;
   }
@@ -858,7 +856,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
     void notifyBlockErased(Block *block) override { erased.insert(block); }
 
     /// Pointers to all erased operations and blocks.
-    SetVector<void *> erased;
+    DenseSet<void *> erased;
   };
 
   
//===--------------------------------------------------------------------===//
@@ -1044,7 +1042,7 @@ void CreateOperationRewrite::rollback() {
       region.getBlocks().remove(region.getBlocks().begin());
   }
   op->dropAllUses();
-  eraseOp(op);
+  op->erase();
 }
 
 void UnresolvedMaterializationRewrite::rollback() {
@@ -1052,7 +1050,7 @@ void UnresolvedMaterializationRewrite::rollback() {
     for (Value input : op->getOperands())
       rewriterImpl.mapping.erase(input);
   }
-  eraseOp(op);
+  op->erase();
 }
 
 void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
@@ -1069,8 +1067,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 // State Management
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
-  return RewriterState(rewrites.size(), ignoredOps.size(),
-                       eraseRewriter.erased.size(), replacedOps.size());
+  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1081,9 +1078,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
   while (ignoredOps.size() != state.numIgnoredOperations)
     ignoredOps.pop_back();
 
-  while (eraseRewriter.erased.size() != state.numErased)
-    eraseRewriter.erased.pop_back();
-
   while (replacedOps.size() != state.numReplacedOps)
     replacedOps.pop_back();
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/83423
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to