[Mlir-commits] [mlir] [mlir][IR] Adjust insertion block when splitting blocks / moving ops (PR #150819)

Matthias Springer llvmlistbot at llvm.org
Sun Jul 27 02:53:55 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/150819

>From 36065f5400a796ff0b96c14dc95b467e86b13690 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 27 Jul 2025 08:35:51 +0000
Subject: [PATCH] [mlir][IR] Move insertion point when splitting blocks /
 moving ops

---
 mlir/include/mlir/IR/Block.h                  |  4 +++
 mlir/include/mlir/IR/PatternMatch.h           | 15 ++++++++++
 .../GPU/Transforms/AllReduceLowering.cpp      |  1 +
 mlir/lib/IR/Block.cpp                         | 10 +++++++
 mlir/lib/IR/Dominance.cpp                     | 18 ++----------
 mlir/lib/IR/PatternMatch.cpp                  | 28 +++++++++++++++++--
 6 files changed, 57 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index e486bb627474d..416e8e510cd4b 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -152,6 +152,10 @@ class alignas(8) Block : public IRObjectWithUseList<BlockOperand>,
   Operation &back() { return operations.back(); }
   Operation &front() { return operations.front(); }
 
+  /// Return true if iterator `a` is strictly before `b`. Both iterators must
+  /// point into this block; `end()` compares after every operation.
+  bool isBeforeInBlock(iterator a, iterator b);
+
   /// Returns 'op' if 'op' lies in this block, or otherwise finds the
   /// ancestor operation of 'op' that lies in this block. Returns nullptr if
   /// the latter fails.
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index b5a93a0c5a898..fa87d6987c52c 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -576,24 +576,39 @@ class RewriterBase : public OpBuilder {
 
   /// Split the operations starting at "before" (inclusive) out of the given
   /// block into a new block, and return it.
+  ///
+  /// If the current insertion point is at or after the split point, the
+  /// insertion point is moved into the new block.
   Block *splitBlock(Block *block, Block::iterator before);
 
   /// Unlink this operation from its current block and insert it right before
   /// `existingOp` which may be in the same or another block in the same
   /// function.
+  ///
+  /// If the insertion point is immediately before the moved operation, the
+  /// insertion block is adjusted to the block of `existingOp`.
   void moveOpBefore(Operation *op, Operation *existingOp);
 
   /// Unlink this operation from its current block and insert it right before
   /// `iterator` in the specified block.
+  ///
+  /// If the insertion point is immediately before the moved operation, the
+  /// insertion block is adjusted to the specified block.
   void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
 
   /// Unlink this operation from its current block and insert it right after
   /// `existingOp` which may be in the same or another block in the same
   /// function.
+  ///
+  /// If the insertion point is immediately before the moved operation, the
+  /// insertion block is adjusted to the block of `existingOp`.
   void moveOpAfter(Operation *op, Operation *existingOp);
 
   /// Unlink this operation from its current block and insert it right after
   /// `iterator` in the specified block.
+  ///
+  /// If the insertion point is immediately before the moved operation, the
+  /// insertion block is adjusted to the specified block.
   void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
 
   /// Unlink this block and insert it right before `existingBlock`.
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index 8c449144af3a9..0155b478f7cfc 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -184,6 +184,7 @@ struct GpuAllReduceRewriter {
     return [&body, this](Value lhs, Value rhs) -> Value {
       Block *block = rewriter.getInsertionBlock();
       Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
+      rewriter.setInsertionPointToEnd(block); // IP was moved into `split`.
 
       // Insert accumulator body between split block.
       IRMapping mapping;
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 57825d9b42178..9dc486726cbaf 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -68,6 +68,16 @@ void Block::erase() {
   getParent()->getBlocks().erase(this);
 }
 
+bool Block::isBeforeInBlock(iterator a, iterator b) {
+  if (a == b)
+    return false; // Strict ordering: an iterator is not before itself.
+  if (a == end())
+    return false; // `end()` is after every operation.
+  if (b == end())
+    return true;
+  return a->isBeforeInBlock(&*b); // Both iterators point at ops here.
+}
+
 /// Returns 'op' if 'op' lies in this block, or otherwise finds the
 /// ancestor operation of 'op' that lies in this block. Returns nullptr if
 /// the latter fails.
diff --git a/mlir/lib/IR/Dominance.cpp b/mlir/lib/IR/Dominance.cpp
index 0e53b431b5d31..b256137360b8e 100644
--- a/mlir/lib/IR/Dominance.cpp
+++ b/mlir/lib/IR/Dominance.cpp
@@ -235,20 +235,6 @@ findAncestorIteratorInRegion(Region *r, Block *b, Block::iterator it) {
   return std::make_pair(op->getBlock(), op->getIterator());
 }
 
-/// Given two iterators into the same block, return "true" if `a` is before `b.
-/// Note: This is a variant of Operation::isBeforeInBlock that operates on
-/// block iterators instead of ops.
-static bool isBeforeInBlock(Block *block, Block::iterator a,
-                            Block::iterator b) {
-  if (a == b)
-    return false;
-  if (a == block->end())
-    return false;
-  if (b == block->end())
-    return true;
-  return a->isBeforeInBlock(&*b);
-}
-
 template <bool IsPostDom>
 bool DominanceInfoBase<IsPostDom>::properlyDominatesImpl(
     Block *aBlock, Block::iterator aIt, Block *bBlock, Block::iterator bIt,
@@ -290,9 +276,9 @@ bool DominanceInfoBase<IsPostDom>::properlyDominatesImpl(
     if (!hasSSADominance(aBlock))
       return true;
     if constexpr (IsPostDom) {
-      return isBeforeInBlock(aBlock, bIt, aIt);
+      return aBlock->isBeforeInBlock(bIt, aIt); // `bIt` must precede `aIt`.
     } else {
-      return isBeforeInBlock(aBlock, aIt, bIt);
+      return aBlock->isBeforeInBlock(aIt, bIt); // `aIt` must precede `bIt`.
     }
   }
 
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 9332f55bd9393..2cb45419f0306 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -9,6 +9,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Iterators.h"
 #include "mlir/IR/RegionKindInterface.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallPtrSet.h"
 
 using namespace mlir;
@@ -348,14 +349,29 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
 /// Split the operations starting at "before" (inclusive) out of the given
 /// block into a new block, and return it.
 Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
+  Block *newBlock = nullptr;
+
+  // If the current insertion point is at or after the split point, adjust the
+  // insertion point to the new block.
+  bool moveIpToNewBlock = getBlock() == block &&
+                          !block->isBeforeInBlock(getInsertionPoint(), before);
+  auto adjustInsertionPoint = llvm::make_scope_exit([&]() {
+    if (!moveIpToNewBlock)
+      return;
+    // An insertion point at `block->end()` goes to the end of `newBlock`.
+    if (getInsertionPoint() == block->end())
+      setInsertionPointToEnd(newBlock);
+    else
+      setInsertionPoint(newBlock, getInsertionPoint());
+  });
+
   // Fast path: If no listener is attached, split the block directly.
   if (!listener)
-    return block->splitBlock(before);
+    return newBlock = block->splitBlock(before);
 
   // `createBlock` sets the insertion point at the beginning of the new block.
   InsertionGuard g(*this);
-  Block *newBlock =
-      createBlock(block->getParent(), std::next(block->getIterator()));
+  newBlock = createBlock(block->getParent(), std::next(block->getIterator()));
 
   // If `before` points to end of the block, no ops should be moved.
   if (before == block->end())
@@ -413,6 +429,12 @@ void RewriterBase::moveOpBefore(Operation *op, Block *block,
   Block *currentBlock = op->getBlock();
   Block::iterator nextIterator = std::next(op->getIterator());
   op->moveBefore(block, iterator);
+
+  // An insertion point that pointed at `op` moved along with it; update the
+  // insertion block so that it matches the iterator's new block.
+  if (getInsertionPoint() == op->getIterator())
+    setInsertionPoint(block, op->getIterator());
+
   if (listener)
     listener->notifyOperationInserted(
         op, /*previous=*/InsertPoint(currentBlock, nextIterator));



More information about the Mlir-commits mailing list