[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Simplify handling of erased IR (PR #83423)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Feb 29 04:56:28 PST 2024


llvmbot wrote:


@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

The dialect conversion uses a `SingleEraseRewriter` to ensure that an op/block is not erased twice. This can happen during the "commit" phase when an unresolved materialization is inserted into a block whose enclosing op is erased by the user. In that case, the unresolved materialization must not be erased a second time later in the "commit" phase.

This problem cannot happen during "rollback", so ops/blocks can be erased directly without using the rewriter. With this change, the `SingleEraseRewriter` is used only during "commit"/"cleanup". At that point, the dialect conversion is guaranteed to succeed and no rollback can happen. Therefore, it is not necessary to store the number of erased IR objects, because the rewriter will never be "reset" to a previous state. For the same reason, the set of erased pointers no longer needs to preserve insertion order, so it is changed from a `SetVector` to a `DenseSet`.
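The double-erase guard described above can be illustrated with a standalone C++ sketch. This is not the actual `SingleEraseRewriter` implementation: `Node`, `SingleEraseTracker`, `eraseOnce`, `wasErased`, and `simulateCommitCleanup` are hypothetical names, and erasing is modeled by recording a pointer instead of destroying IR. The idea it shows is the one from the description: remember the address of every erased object and turn any later request to erase the same object into a no-op.

```cpp
#include <cassert>
#include <unordered_set>

// Hypothetical stand-in for an MLIR operation or block.
struct Node {
  int id;
};

// Hypothetical tracker modeling the idea behind SingleEraseRewriter:
// remember every erased object so that a second erase request is a no-op.
class SingleEraseTracker {
public:
  // Returns true if `node` is erased now, false if it was erased before.
  bool eraseOnce(Node *node) {
    // insert().second is false when the pointer is already in the set.
    if (!erased.insert(node).second)
      return false;
    // A real rewriter would destroy the IR object here; this sketch only
    // records that it is gone.
    return true;
  }

  bool wasErased(Node *node) const { return erased.count(node) != 0; }

private:
  // Only membership queries are needed, so an unordered set suffices.
  std::unordered_set<void *> erased;
};

// Models the commit-phase scenario from the PR description. Returns true if
// the second erase of the materialization was skipped.
bool simulateCommitCleanup() {
  Node enclosingOp{1};
  Node materialization{2}; // conceptually nested inside enclosingOp
  SingleEraseTracker tracker;

  // The user erases the enclosing op; the nested materialization goes with
  // it, so both erasures are recorded.
  tracker.eraseOnce(&enclosingOp);
  tracker.eraseOnce(&materialization);

  // Later in the commit phase, cleanup asks to erase the materialization again.
  bool erasedAgain = tracker.eraseOnce(&materialization);
  return !erasedAgain && tracker.wasErased(&materialization);
}
```

Keying the set on `void *` lets a single set track both ops and blocks, as the `erased` member in the diff does.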


---
Full diff: https://github.com/llvm/llvm-project/pull/83423.diff


1 File Affected:

- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+8-14) 


``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index cac990d498d7d3..9f6468402686bd 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -153,9 +153,9 @@ namespace {
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
   RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
-                unsigned numErased, unsigned numReplacedOps)
+                unsigned numReplacedOps)
       : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
-        numErased(numErased), numReplacedOps(numReplacedOps) {}
+        numReplacedOps(numReplacedOps) {}
 
   /// The current number of rewrites performed.
   unsigned numRewrites;
@@ -163,9 +163,6 @@ struct RewriterState {
   /// The current number of ignored operations.
   unsigned numIgnoredOperations;
 
-  /// The current number of erased operations/blocks.
-  unsigned numErased;
-
   /// The current number of replaced ops that are scheduled for erasure.
   unsigned numReplacedOps;
 };
@@ -273,8 +270,9 @@ class CreateBlockRewrite : public BlockRewrite {
     auto &blockOps = block->getOperations();
     while (!blockOps.empty())
       blockOps.remove(blockOps.begin());
+    block->dropAllUses();
     if (block->getParent())
-      eraseBlock(block);
+      block->erase();
     else
       delete block;
   }
@@ -858,7 +856,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
     void notifyBlockErased(Block *block) override { erased.insert(block); }
 
     /// Pointers to all erased operations and blocks.
-    SetVector<void *> erased;
+    DenseSet<void *> erased;
   };
 
   //===--------------------------------------------------------------------===//
@@ -1044,7 +1042,7 @@ void CreateOperationRewrite::rollback() {
       region.getBlocks().remove(region.getBlocks().begin());
   }
   op->dropAllUses();
-  eraseOp(op);
+  op->erase();
 }
 
 void UnresolvedMaterializationRewrite::rollback() {
@@ -1052,7 +1050,7 @@ void UnresolvedMaterializationRewrite::rollback() {
     for (Value input : op->getOperands())
       rewriterImpl.mapping.erase(input);
   }
-  eraseOp(op);
+  op->erase();
 }
 
 void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
@@ -1069,8 +1067,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 // State Management
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
-  return RewriterState(rewrites.size(), ignoredOps.size(),
-                       eraseRewriter.erased.size(), replacedOps.size());
+  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1081,9 +1078,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
   while (ignoredOps.size() != state.numIgnoredOperations)
     ignoredOps.pop_back();
 
-  while (eraseRewriter.erased.size() != state.numErased)
-    eraseRewriter.erased.pop_back();
-
   while (replacedOps.size() != state.numReplacedOps)
     replacedOps.pop_back();
 }

``````````
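The `RewriterState` changes in the diff above can be modeled in a few lines. In this hypothetical sketch, `Snapshot` and `ConversionLog` are simplified stand-ins for `RewriterState` and `ConversionPatternRewriterImpl` (the ignored-ops counter is omitted for brevity), and the sketch only trims bookkeeping entries without undoing any IR changes. `resetState` pops entries until the recorded sizes are restored, mirroring the `pop_back` loops in the diff. The erased-object set has no counter in the snapshot because, after this change, it is only filled during commit/cleanup, when no further reset can occur.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical, simplified analogue of RewriterState after this change:
// only containers that can be rolled back are counted.
struct Snapshot {
  std::size_t numRewrites;
  std::size_t numReplacedOps;
};

// Hypothetical, simplified analogue of ConversionPatternRewriterImpl.
struct ConversionLog {
  std::vector<std::string> rewrites;
  std::vector<std::string> replacedOps;
  // Filled only during commit/cleanup, after which no reset can happen, so it
  // needs no entry in Snapshot and no ordering (hence an unordered set).
  std::unordered_set<const void *> erased;

  Snapshot getCurrentState() const {
    return Snapshot{rewrites.size(), replacedOps.size()};
  }

  // Drop every bookkeeping entry recorded after `state` was taken.
  void resetState(const Snapshot &state) {
    while (rewrites.size() != state.numRewrites)
      rewrites.pop_back();
    while (replacedOps.size() != state.numReplacedOps)
      replacedOps.pop_back();
  }
};
```

Because entries are only ever removed from the end, a size per container is enough to describe a snapshot; a container that is never trimmed, like `erased` here, simply has no counter.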

</details>


https://github.com/llvm/llvm-project/pull/83423

