[Mlir-commits] [mlir] [mlir][IR] Adjust insertion block when splitting blocks / moving ops (PR #150819)
Matthias Springer
llvmlistbot at llvm.org
Sun Jul 27 02:18:04 PDT 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/150819
>From 98fdd7cce4126ce1149efd0addcc1c15f6171368 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 27 Jul 2025 08:35:51 +0000
Subject: [PATCH] [mlir][IR] Move insertion point when splitting blocks /
moving ops
---
mlir/include/mlir/IR/PatternMatch.h | 15 ++++++++++++++
.../GPU/Transforms/AllReduceLowering.cpp | 1 +
mlir/lib/IR/PatternMatch.cpp | 20 ++++++++++++++++---
3 files changed, 33 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index b5a93a0c5a898..fa87d6987c52c 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -576,24 +576,39 @@ class RewriterBase : public OpBuilder {
/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it.
+ ///
+ /// If the current insertion point is exactly at the split point `before`,
+ /// its insertion block is updated to the new block.
Block *splitBlock(Block *block, Block::iterator before);
/// Unlink this operation from its current block and insert it right before
/// `existingOp` which may be in the same or another block in the same
/// function.
+ ///
+ /// If the current insertion point is immediately before `op`, the insertion
+ /// block is updated to the block of `existingOp`.
void moveOpBefore(Operation *op, Operation *existingOp);
/// Unlink this operation from its current block and insert it right before
/// `iterator` in the specified block.
+ ///
+ /// If the current insertion point is immediately before `op`, the insertion
+ /// block is updated to `block`.
void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
/// Unlink this operation from its current block and insert it right after
/// `existingOp` which may be in the same or another block in the same
/// function.
+ ///
+ /// If the current insertion point is immediately before `op`, the insertion
+ /// block is updated to the block of `existingOp`.
void moveOpAfter(Operation *op, Operation *existingOp);
/// Unlink this operation from its current block and insert it right after
/// `iterator` in the specified block.
+ ///
+ /// If the current insertion point is immediately before `op`, the insertion
+ /// block is updated to `block`.
void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
/// Unlink this block and insert it right before `existingBlock`.
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index 8c449144af3a9..0155b478f7cfc 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -184,6 +184,7 @@ struct GpuAllReduceRewriter {
return [&body, this](Value lhs, Value rhs) -> Value {
Block *block = rewriter.getInsertionBlock();
Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
+ rewriter.setInsertionPointToEnd(block);
// Insert accumulator body between split block.
IRMapping mapping;
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 9332f55bd9393..016569718275e 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -9,6 +9,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/RegionKindInterface.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
using namespace mlir;
@@ -348,14 +349,21 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it.
Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
+ Block *newBlock;
+ auto adjustInsertionPoint = llvm::make_scope_exit([&]() {
+ // If the current insertion point is before the split point, adjust the
+ // insertion point to the new block.
+ if (getInsertionPoint() == before)
+ setInsertionPoint(newBlock, before);
+ });
+
// Fast path: If no listener is attached, split the block directly.
if (!listener)
- return block->splitBlock(before);
+ return newBlock = block->splitBlock(before);
// `createBlock` sets the insertion point at the beginning of the new block.
InsertionGuard g(*this);
- Block *newBlock =
- createBlock(block->getParent(), std::next(block->getIterator()));
+ newBlock = createBlock(block->getParent(), std::next(block->getIterator()));
// If `before` points to end of the block, no ops should be moved.
if (before == block->end())
@@ -413,6 +421,12 @@ void RewriterBase::moveOpBefore(Operation *op, Block *block,
Block *currentBlock = op->getBlock();
Block::iterator nextIterator = std::next(op->getIterator());
op->moveBefore(block, iterator);
+
+ // If the current insertion point is before the moved operation, we may have
+ // to adjust the insertion block.
+ if (getInsertionPoint() == op->getIterator())
+ setInsertionPoint(block, op->getIterator());
+
if (listener)
listener->notifyOperationInserted(
op, /*previous=*/InsertPoint(currentBlock, nextIterator));
More information about the Mlir-commits
mailing list