[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Simplify handling of erased IR (PR #83423)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Feb 29 04:55:59 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/83423

The dialect conversion uses a `SingleEraseRewriter` to ensure that an op/block is not erased twice. This can happen during the "commit" phase when an unresolved materialization is inserted into a block and the enclosing op is erased by the user. In that case, the unresolved materialization should not be erased a second time later in the "commit" phase.

This problem cannot happen during "rollback", so ops/blocks can be erased directly without using the rewriter. With this change, the `SingleEraseRewriter` is used only during "commit"/"cleanup". At that point, the dialect conversion is guaranteed to succeed and no rollback can happen. Therefore, it is not necessary to store the number of erased IR objects (because we will never "reset" the rewriter to a previous state).


>From 6b6c4b1a7dd4943bfe2d97245e8369b9ba63aa20 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 29 Feb 2024 12:48:28 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Do not use SingleEraseRewriter during
 rollback

---
 .../Transforms/Utils/DialectConversion.cpp    | 22 +++++++------------
 1 file changed, 8 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index cac990d498d7d3..9f6468402686bd 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -153,9 +153,9 @@ namespace {
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
   RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
-                unsigned numErased, unsigned numReplacedOps)
+                unsigned numReplacedOps)
       : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
-        numErased(numErased), numReplacedOps(numReplacedOps) {}
+        numReplacedOps(numReplacedOps) {}
 
   /// The current number of rewrites performed.
   unsigned numRewrites;
@@ -163,9 +163,6 @@ struct RewriterState {
   /// The current number of ignored operations.
   unsigned numIgnoredOperations;
 
-  /// The current number of erased operations/blocks.
-  unsigned numErased;
-
   /// The current number of replaced ops that are scheduled for erasure.
   unsigned numReplacedOps;
 };
@@ -273,8 +270,9 @@ class CreateBlockRewrite : public BlockRewrite {
     auto &blockOps = block->getOperations();
     while (!blockOps.empty())
       blockOps.remove(blockOps.begin());
+    block->dropAllUses();
     if (block->getParent())
-      eraseBlock(block);
+      block->erase();
     else
       delete block;
   }
@@ -858,7 +856,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
     void notifyBlockErased(Block *block) override { erased.insert(block); }
 
     /// Pointers to all erased operations and blocks.
-    SetVector<void *> erased;
+    DenseSet<void *> erased;
   };
 
   //===--------------------------------------------------------------------===//
@@ -1044,7 +1042,7 @@ void CreateOperationRewrite::rollback() {
       region.getBlocks().remove(region.getBlocks().begin());
   }
   op->dropAllUses();
-  eraseOp(op);
+  op->erase();
 }
 
 void UnresolvedMaterializationRewrite::rollback() {
@@ -1052,7 +1050,7 @@ void UnresolvedMaterializationRewrite::rollback() {
     for (Value input : op->getOperands())
       rewriterImpl.mapping.erase(input);
   }
-  eraseOp(op);
+  op->erase();
 }
 
 void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
@@ -1069,8 +1067,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 // State Management
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
-  return RewriterState(rewrites.size(), ignoredOps.size(),
-                       eraseRewriter.erased.size(), replacedOps.size());
+  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1081,9 +1078,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
   while (ignoredOps.size() != state.numIgnoredOperations)
     ignoredOps.pop_back();
 
-  while (eraseRewriter.erased.size() != state.numErased)
-    eraseRewriter.erased.pop_back();
-
   while (replacedOps.size() != state.numReplacedOps)
     replacedOps.pop_back();
 }



More information about the llvm-branch-commits mailing list