[Mlir-commits] [mlir] [mlir][IR] Set insertion point when erasing an operation (PR #146955)

Matthias Springer llvmlistbot at llvm.org
Sun Jul 27 00:56:44 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/146955

>From 68af84183c349db15b749021b2f600c1d642e2d5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 3 Jul 2025 19:54:19 +0000
Subject: [PATCH 1/2] [mlir][IR][WIP] Set insertion point when erasing an
 operation

---
 mlir/lib/IR/PatternMatch.cpp | 38 ++++++++++++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 5c98417c874d3..5a08c6534b6b6 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -150,12 +150,45 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
   eraseOp(op);
 }
 
+/// Returns the given block iterator if it lies within the block `b`.
+/// Otherwise, otherwise finds the ancestor of the given block iterator that
+/// lies within `b`. Returns and "empty" iterator if the latter fails.
+///
+/// Note: This is a variant of Block::findAncestorOpInBlock that operates on
+/// block iterators instead of ops.
+static std::pair<Block *, Block::iterator>
+findAncestorIteratorInBlock(Block *b, Block *itBlock, Block::iterator it) {
+  // Case 1: The iterator lies within the block.
+  if (itBlock == b)
+    return std::make_pair(itBlock, it);
+
+  // Otherwise: Find ancestor iterator. Bail if we run out of parent ops.
+  Operation *parentOp = itBlock->getParentOp();
+  if (!parentOp)
+    return std::make_pair(static_cast<Block *>(nullptr), Block::iterator());
+  Operation *op = b->findAncestorOpInBlock(*parentOp);
+  if (!op)
+    return std::make_pair(static_cast<Block *>(nullptr), Block::iterator());
+  return std::make_pair(op->getBlock(), op->getIterator());
+}
+
 /// This method erases an operation that is known to have no uses. The uses of
 /// the given operation *must* be known to be dead.
 void RewriterBase::eraseOp(Operation *op) {
   assert(op->use_empty() && "expected 'op' to have no uses");
   auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
 
+  // If the current insertion point is before/within the erased operation, we
+  // need to adjust the insertion point to be after the operation.
+  if (getInsertionBlock()) {
+    Block *insertionBlock;
+    Block::iterator insertionPoint;
+    std::tie(insertionBlock, insertionPoint) = findAncestorIteratorInBlock(
+        op->getBlock(), getInsertionBlock(), getInsertionPoint());
+    if (insertionBlock && insertionPoint == op->getIterator())
+      setInsertionPointAfter(op);
+  }
+
   // Fast path: If no listener is attached, the op can be dropped in one go.
   if (!rewriteListener) {
     op->erase();
@@ -320,6 +353,11 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
       moveOpBefore(&source->front(), dest, before);
   }
 
+  // If the current insertion point is within the source block, adjust the
+  // insertion point to the destination block.
+  if (getInsertionBlock() == source)
+    setInsertionPoint(dest, getInsertionPoint());
+
   // Erase the source block.
   assert(source->empty() && "expected 'source' to be empty");
   eraseBlock(source);

>From 1d70ab4e41faf4a3dd9c6a8d5a4a8ca98794f7cc Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 4 Jul 2025 09:48:51 +0000
Subject: [PATCH 2/2] address comments

---
 mlir/include/mlir/IR/PatternMatch.h           | 14 ++++++++
 mlir/lib/IR/PatternMatch.cpp                  | 36 +++----------------
 .../Transforms/Utils/DialectConversion.cpp    | 23 ++++++++++++
 3 files changed, 41 insertions(+), 32 deletions(-)

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index b3608b4394f45..b5a93a0c5a898 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -525,6 +525,11 @@ class RewriterBase : public OpBuilder {
   }
 
   /// This method erases an operation that is known to have no uses.
+  ///
+  /// If the current insertion point is before the erased operation, it is
+  /// adjusted to the following operation (or the end of the block). If the
+  /// current insertion point is within the erased operation, the insertion
+  /// point is left in an invalid state.
   virtual void eraseOp(Operation *op);
 
   /// This method erases all operations in a block.
@@ -539,6 +544,9 @@ class RewriterBase : public OpBuilder {
   /// somewhere in the middle (or beginning) of the dest block, the source block
   /// must have no successors. Otherwise, the resulting IR would have
   /// unreachable operations.
+  ///
+  /// If the insertion point is within the source block, it is adjusted to the
+  /// destination block.
   virtual void inlineBlockBefore(Block *source, Block *dest,
                                  Block::iterator before,
                                  ValueRange argValues = {});
@@ -549,6 +557,9 @@ class RewriterBase : public OpBuilder {
   ///
   /// The source block must have no successors. Otherwise, the resulting IR
   /// would have unreachable operations.
+  ///
+  /// If the insertion point is within the source block, it is adjusted to the
+  /// destination block.
   void inlineBlockBefore(Block *source, Operation *op,
                          ValueRange argValues = {});
 
@@ -558,6 +569,9 @@ class RewriterBase : public OpBuilder {
   ///
   /// The dest block must have no successors. Otherwise, the resulting IR would
   /// have unreachable operation.
+  ///
+  /// If the insertion point is within the source block, it is adjusted to the
+  /// destination block.
   void mergeBlocks(Block *source, Block *dest, ValueRange argValues = {});
 
   /// Split the operations starting at "before" (inclusive) out of the given
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 5a08c6534b6b6..9332f55bd9393 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -150,44 +150,16 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
   eraseOp(op);
 }
 
-/// Returns the given block iterator if it lies within the block `b`.
-/// Otherwise, otherwise finds the ancestor of the given block iterator that
-/// lies within `b`. Returns and "empty" iterator if the latter fails.
-///
-/// Note: This is a variant of Block::findAncestorOpInBlock that operates on
-/// block iterators instead of ops.
-static std::pair<Block *, Block::iterator>
-findAncestorIteratorInBlock(Block *b, Block *itBlock, Block::iterator it) {
-  // Case 1: The iterator lies within the block.
-  if (itBlock == b)
-    return std::make_pair(itBlock, it);
-
-  // Otherwise: Find ancestor iterator. Bail if we run out of parent ops.
-  Operation *parentOp = itBlock->getParentOp();
-  if (!parentOp)
-    return std::make_pair(static_cast<Block *>(nullptr), Block::iterator());
-  Operation *op = b->findAncestorOpInBlock(*parentOp);
-  if (!op)
-    return std::make_pair(static_cast<Block *>(nullptr), Block::iterator());
-  return std::make_pair(op->getBlock(), op->getIterator());
-}
-
 /// This method erases an operation that is known to have no uses. The uses of
 /// the given operation *must* be known to be dead.
 void RewriterBase::eraseOp(Operation *op) {
   assert(op->use_empty() && "expected 'op' to have no uses");
   auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
 
-  // If the current insertion point is before/within the erased operation, we
-  // need to adjust the insertion point to be after the operation.
-  if (getInsertionBlock()) {
-    Block *insertionBlock;
-    Block::iterator insertionPoint;
-    std::tie(insertionBlock, insertionPoint) = findAncestorIteratorInBlock(
-        op->getBlock(), getInsertionBlock(), getInsertionPoint());
-    if (insertionBlock && insertionPoint == op->getIterator())
-      setInsertionPointAfter(op);
-  }
+  // If the current insertion point is before the erased operation, we adjust
+  // the insertion point to be after the operation.
+  if (getInsertionPoint() == op->getIterator())
+    setInsertionPointAfter(op);
 
   // Fast path: If no listener is attached, the op can be dropped in one go.
   if (!rewriteListener) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index df255cfcf3ec1..b8c40e34c91a7 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1758,6 +1758,12 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
     impl->logger.startLine()
         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
   });
+
+  // If the current insertion point is before the erased operation, we adjust
+  // the insertion point to be after the operation.
+  if (getInsertionPoint() == op->getIterator())
+    setInsertionPointAfter(op);
+
   SmallVector<SmallVector<Value>> newVals =
       llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
         return v ? SmallVector<Value>{v} : SmallVector<Value>();
@@ -1773,6 +1779,12 @@ void ConversionPatternRewriter::replaceOpWithMultiple(
     impl->logger.startLine()
         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
   });
+
+  // If the current insertion point is before the erased operation, we adjust
+  // the insertion point to be after the operation.
+  if (getInsertionPoint() == op->getIterator())
+    setInsertionPointAfter(op);
+
   impl->replaceOp(op, std::move(newValues));
 }
 
@@ -1781,6 +1793,12 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
     impl->logger.startLine()
         << "** Erase   : '" << op->getName() << "'(" << op << ")\n";
   });
+
+  // If the current insertion point is before the erased operation, we adjust
+  // the insertion point to be after the operation.
+  if (getInsertionPoint() == op->getIterator())
+    setInsertionPointAfter(op);
+
   SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
   impl->replaceOp(op, std::move(nullRepls));
 }
@@ -1887,6 +1905,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
       moveOpBefore(&source->front(), dest, before);
   }
 
+  // If the current insertion point is within the source block, adjust the
+  // insertion point to the destination block.
+  if (getInsertionBlock() == source)
+    setInsertionPoint(dest, getInsertionPoint());
+
   // Erase the source block.
   eraseBlock(source);
 }



More information about the Mlir-commits mailing list