[Mlir-commits] [mlir] [mlir][Transforms][NFC] Simplify handling of erased IR (PR #83423)

Matthias Springer llvmlistbot at llvm.org
Mon Mar 4 01:08:43 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/83423

From 1d7d17655a56a3e64ef02d938d9569505560a05a Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 4 Mar 2024 09:07:11 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Do not use SingleEraseRewriter during
 rollback

---
 .../Transforms/Utils/DialectConversion.cpp    | 22 +++++++------------
 1 file changed, 8 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 71aee4447f7921..9dc806730d01a1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -153,9 +153,9 @@ namespace {
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
   RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
-                unsigned numErased, unsigned numReplacedOps)
+                unsigned numReplacedOps)
       : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
-        numErased(numErased), numReplacedOps(numReplacedOps) {}
+        numReplacedOps(numReplacedOps) {}
 
   /// The current number of rewrites performed.
   unsigned numRewrites;
@@ -163,9 +163,6 @@ struct RewriterState {
   /// The current number of ignored operations.
   unsigned numIgnoredOperations;
 
-  /// The current number of erased operations/blocks.
-  unsigned numErased;
-
   /// The current number of replaced ops that are scheduled for erasure.
   unsigned numReplacedOps;
 };
@@ -274,8 +271,9 @@ class CreateBlockRewrite : public BlockRewrite {
     auto &blockOps = block->getOperations();
     while (!blockOps.empty())
       blockOps.remove(blockOps.begin());
+    block->dropAllUses();
     if (block->getParent())
-      eraseBlock(block);
+      block->erase();
     else
       delete block;
   }
@@ -905,7 +903,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
     void notifyBlockErased(Block *block) override { erased.insert(block); }
 
     /// Pointers to all erased operations and blocks.
-    SetVector<void *> erased;
+    DenseSet<void *> erased;
   };
 
   //===--------------------------------------------------------------------===//
@@ -1091,7 +1089,7 @@ void CreateOperationRewrite::rollback() {
       region.getBlocks().remove(region.getBlocks().begin());
   }
   op->dropAllUses();
-  eraseOp(op);
+  op->erase();
 }
 
 void UnresolvedMaterializationRewrite::rollback() {
@@ -1099,7 +1097,7 @@ void UnresolvedMaterializationRewrite::rollback() {
     for (Value input : op->getOperands())
       rewriterImpl.mapping.erase(input);
   }
-  eraseOp(op);
+  op->erase();
 }
 
 void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
@@ -1116,8 +1114,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 // State Management
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
-  return RewriterState(rewrites.size(), ignoredOps.size(),
-                       eraseRewriter.erased.size(), replacedOps.size());
+  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1128,9 +1125,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
   while (ignoredOps.size() != state.numIgnoredOperations)
     ignoredOps.pop_back();
 
-  while (eraseRewriter.erased.size() != state.numErased)
-    eraseRewriter.erased.pop_back();
-
   while (replacedOps.size() != state.numReplacedOps)
     replacedOps.pop_back();
 }



More information about the Mlir-commits mailing list