[Mlir-commits] [mlir] 310a278 - [mlir][Transforms][NFC] Simplify handling of erased IR (#83423)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 4 01:21:47 PST 2024


Author: Matthias Springer
Date: 2024-03-04T18:21:43+09:00
New Revision: 310a278812319a7cae86239c975738019b5d90a3

URL: https://github.com/llvm/llvm-project/commit/310a278812319a7cae86239c975738019b5d90a3
DIFF: https://github.com/llvm/llvm-project/commit/310a278812319a7cae86239c975738019b5d90a3.diff

LOG: [mlir][Transforms][NFC] Simplify handling of erased IR (#83423)

The dialect conversion uses a `SingleEraseRewriter` to ensure that an
op/block is not erased twice. This can happen during the "commit" phase
when an unresolved materialization is inserted into a block and the
enclosing op is erased by the user. In that case, the unresolved
materialization should not be erased a second time later in the "commit"
phase.

This problem cannot happen during "rollback", so ops/blocks can be erased
directly without using the rewriter. With this change, the
`SingleEraseRewriter` is used only during "commit"/"cleanup". At that
point, the dialect conversion is guaranteed to succeed and no rollback
can happen. Therefore, it is not necessary to store the number of erased
IR objects (because we will never "reset" the rewriter to a
previous state).

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 71aee4447f7921..9dc806730d01a1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -153,9 +153,9 @@ namespace {
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
   RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
-                unsigned numErased, unsigned numReplacedOps)
+                unsigned numReplacedOps)
       : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
-        numErased(numErased), numReplacedOps(numReplacedOps) {}
+        numReplacedOps(numReplacedOps) {}
 
   /// The current number of rewrites performed.
   unsigned numRewrites;
@@ -163,9 +163,6 @@ struct RewriterState {
   /// The current number of ignored operations.
   unsigned numIgnoredOperations;
 
-  /// The current number of erased operations/blocks.
-  unsigned numErased;
-
   /// The current number of replaced ops that are scheduled for erasure.
   unsigned numReplacedOps;
 };
@@ -274,8 +271,9 @@ class CreateBlockRewrite : public BlockRewrite {
     auto &blockOps = block->getOperations();
     while (!blockOps.empty())
       blockOps.remove(blockOps.begin());
+    block->dropAllUses();
     if (block->getParent())
-      eraseBlock(block);
+      block->erase();
     else
       delete block;
   }
@@ -905,7 +903,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
     void notifyBlockErased(Block *block) override { erased.insert(block); }
 
     /// Pointers to all erased operations and blocks.
-    SetVector<void *> erased;
+    DenseSet<void *> erased;
   };
 
   //===--------------------------------------------------------------------===//
@@ -1091,7 +1089,7 @@ void CreateOperationRewrite::rollback() {
       region.getBlocks().remove(region.getBlocks().begin());
   }
   op->dropAllUses();
-  eraseOp(op);
+  op->erase();
 }
 
 void UnresolvedMaterializationRewrite::rollback() {
@@ -1099,7 +1097,7 @@ void UnresolvedMaterializationRewrite::rollback() {
     for (Value input : op->getOperands())
       rewriterImpl.mapping.erase(input);
   }
-  eraseOp(op);
+  op->erase();
 }
 
 void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
@@ -1116,8 +1114,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 // State Management
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
-  return RewriterState(rewrites.size(), ignoredOps.size(),
-                       eraseRewriter.erased.size(), replacedOps.size());
+  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1128,9 +1125,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
   while (ignoredOps.size() != state.numIgnoredOperations)
     ignoredOps.pop_back();
 
-  while (eraseRewriter.erased.size() != state.numErased)
-    eraseRewriter.erased.pop_back();
-
   while (replacedOps.size() != state.numReplacedOps)
     replacedOps.pop_back();
 }


        


More information about the Mlir-commits mailing list