[Mlir-commits] [mlir] [mlir][IR] Adjust insertion block when splitting blocks / moving ops (PR #150819)
Matthias Springer
llvmlistbot at llvm.org
Sun Jul 27 02:59:30 PDT 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/150819
>From b7bd595d950ba52d85882c7c844cb7f446169b06 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 27 Jul 2025 08:35:51 +0000
Subject: [PATCH] [mlir][IR] Move insertion point when splitting blocks /
moving ops
---
mlir/include/mlir/IR/Block.h | 4 +++
mlir/include/mlir/IR/PatternMatch.h | 15 ++++++++++
.../GPU/Transforms/AllReduceLowering.cpp | 1 +
mlir/lib/IR/Block.cpp | 10 +++++++
mlir/lib/IR/Dominance.cpp | 18 ++----------
mlir/lib/IR/PatternMatch.cpp | 28 +++++++++++++++++--
6 files changed, 57 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index e486bb627474d..416e8e510cd4b 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -152,6 +152,10 @@ class alignas(8) Block : public IRObjectWithUseList<BlockOperand>,
Operation &back() { return operations.back(); }
Operation &front() { return operations.front(); }
+ /// Return true if the iterator `a` is strictly before `b`. Equal iterators
+ /// are not ordered, and `end()` compares after every operation in the
+ /// block. Both iterators must point into this block.
+ bool isBeforeInBlock(iterator a, iterator b);
+
/// Returns 'op' if 'op' lies in this block, or otherwise finds the
/// ancestor operation of 'op' that lies in this block. Returns nullptr if
/// the latter fails.
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index b5a93a0c5a898..fa87d6987c52c 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -576,24 +576,39 @@ class RewriterBase : public OpBuilder {
/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it.
+ ///
+ /// If the current insertion point is in `block` at or after `before`
+ /// (including the end of `block`), it is moved into the new block; an
+ /// insertion point at the end of `block` becomes the end of the new block.
+ /// Callers that want to keep inserting into `block` must reset the
+ /// insertion point after the split.
Block *splitBlock(Block *block, Block::iterator before);
/// Unlink this operation from its current block and insert it right before
/// `existingOp` which may be in the same or another block in the same
/// function.
+ ///
+ /// If the insertion point points at `op` (i.e., new operations would be
+ /// inserted right before `op`), the insertion block is updated to the block
+ /// of `existingOp` so that the insertion point keeps following `op`.
void moveOpBefore(Operation *op, Operation *existingOp);
/// Unlink this operation from its current block and insert it right before
/// `iterator` in the specified block.
+ ///
+ /// If the insertion point points at `op` (i.e., new operations would be
+ /// inserted right before `op`), the insertion block is updated to `block`
+ /// so that the insertion point keeps following `op`.
void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
/// Unlink this operation from its current block and insert it right after
/// `existingOp` which may be in the same or another block in the same
/// function.
+ ///
+ /// If the insertion point points at `op` (i.e., new operations would be
+ /// inserted right before `op`), the insertion block is updated to the block
+ /// of `existingOp` so that the insertion point keeps following `op`.
void moveOpAfter(Operation *op, Operation *existingOp);
/// Unlink this operation from its current block and insert it right after
/// `iterator` in the specified block.
+ ///
+ /// If the insertion point points at `op` (i.e., new operations would be
+ /// inserted right before `op`), the insertion block is updated to `block`
+ /// so that the insertion point keeps following `op`.
void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
/// Unlink this block and insert it right before `existingBlock`.
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index 8c449144af3a9..0155b478f7cfc 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -184,6 +184,7 @@ struct GpuAllReduceRewriter {
return [&body, this](Value lhs, Value rhs) -> Value {
Block *block = rewriter.getInsertionBlock();
Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
+ rewriter.setInsertionPointToEnd(block);
// Insert accumulator body between split block.
IRMapping mapping;
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 57825d9b42178..9dc486726cbaf 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -68,6 +68,16 @@ void Block::erase() {
getParent()->getBlocks().erase(this);
}
+/// Return true if `a` is strictly before `b` in this block. Equal iterators
+/// are not ordered, and `end()` is treated as coming after every operation.
+/// When both iterators refer to operations, the comparison is delegated to
+/// Operation::isBeforeInBlock.
+bool Block::isBeforeInBlock(iterator a, iterator b) {
+ if (a == b)
+ return false;
+ // `end()` is never before another position.
+ if (a == end())
+ return false;
+ // Every operation is before `end()`.
+ if (b == end())
+ return true;
+ return a->isBeforeInBlock(&*b);
+}
+
/// Returns 'op' if 'op' lies in this block, or otherwise finds the
/// ancestor operation of 'op' that lies in this block. Returns nullptr if
/// the latter fails.
diff --git a/mlir/lib/IR/Dominance.cpp b/mlir/lib/IR/Dominance.cpp
index 0e53b431b5d31..b256137360b8e 100644
--- a/mlir/lib/IR/Dominance.cpp
+++ b/mlir/lib/IR/Dominance.cpp
@@ -235,20 +235,6 @@ findAncestorIteratorInRegion(Region *r, Block *b, Block::iterator it) {
return std::make_pair(op->getBlock(), op->getIterator());
}
-/// Given two iterators into the same block, return "true" if `a` is before `b.
-/// Note: This is a variant of Operation::isBeforeInBlock that operates on
-/// block iterators instead of ops.
-static bool isBeforeInBlock(Block *block, Block::iterator a,
- Block::iterator b) {
- if (a == b)
- return false;
- if (a == block->end())
- return false;
- if (b == block->end())
- return true;
- return a->isBeforeInBlock(&*b);
-}
-
template <bool IsPostDom>
bool DominanceInfoBase<IsPostDom>::properlyDominatesImpl(
Block *aBlock, Block::iterator aIt, Block *bBlock, Block::iterator bIt,
@@ -290,9 +276,9 @@ bool DominanceInfoBase<IsPostDom>::properlyDominatesImpl(
if (!hasSSADominance(aBlock))
return true;
if constexpr (IsPostDom) {
- return isBeforeInBlock(aBlock, bIt, aIt);
+ return aBlock->isBeforeInBlock(bIt, aIt);
} else {
- return isBeforeInBlock(aBlock, aIt, bIt);
+ return aBlock->isBeforeInBlock(aIt, bIt);
}
}
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 9332f55bd9393..2cb45419f0306 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -9,6 +9,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/RegionKindInterface.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
using namespace mlir;
@@ -348,14 +349,29 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it.
Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
+ Block *newBlock;
+
+ // If the current insertion point is at or after the split point, adjust the
+ // insertion point to the new block.
+ bool moveIpToNewBlock = getBlock() == block &&
+ !block->isBeforeInBlock(getInsertionPoint(), before);
+ auto adjustInsertionPoint = llvm::make_scope_exit([&]() {
+ if (getInsertionPoint() == block->end()) {
+ // If the insertion point is at the end of the block, move it to the end
+ // of the new block.
+ setInsertionPointToEnd(newBlock);
+ } else if (moveIpToNewBlock) {
+ setInsertionPoint(newBlock, getInsertionPoint());
+ }
+ });
+
// Fast path: If no listener is attached, split the block directly.
if (!listener)
- return block->splitBlock(before);
+ return newBlock = block->splitBlock(before);
// `createBlock` sets the insertion point at the beginning of the new block.
InsertionGuard g(*this);
- Block *newBlock =
- createBlock(block->getParent(), std::next(block->getIterator()));
+ newBlock = createBlock(block->getParent(), std::next(block->getIterator()));
// If `before` points to end of the block, no ops should be moved.
if (before == block->end())
@@ -413,6 +429,12 @@ void RewriterBase::moveOpBefore(Operation *op, Block *block,
Block *currentBlock = op->getBlock();
Block::iterator nextIterator = std::next(op->getIterator());
op->moveBefore(block, iterator);
+
+ // If the current insertion point is before the moved operation, we may have
+ // to adjust the insertion block.
+ if (getInsertionPoint() == op->getIterator())
+ setInsertionPoint(block, op->getIterator());
+
if (listener)
listener->notifyOperationInserted(
op, /*previous=*/InsertPoint(currentBlock, nextIterator));
More information about the Mlir-commits
mailing list