[Mlir-commits] [mlir] [mlir][IR] Send missing notification when splitting a block (PR #79597)
Matthias Springer
llvmlistbot at llvm.org
Wed Jan 31 05:43:12 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/79597
>From 0282ce7c1246d50b1e0a2dd6fdfa032b02ec07e0 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 26 Jan 2024 14:02:54 +0000
Subject: [PATCH] [mlir][IR] Send missing notification when splitting a block
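When a block is split with `RewriterBase::splitBlock`, a `notifyBlockInserted` notification should be sent, followed by one `notifyOperationInserted` notification for each operation that is moved into the new block. This commit adds those notifications. When no listener is attached, the block is still split directly with `Block::splitBlock`. `ConversionPatternRewriter::splitBlock` now calls `Block::splitBlock` directly and continues to report the split through `notifySplitBlock`.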
Depends on #79593. Review only the top commit.
---
mlir/lib/IR/PatternMatch.cpp | 20 ++++++++++++++-
.../Transforms/Utils/DialectConversion.cpp | 2 +-
.../test-strict-pattern-driver.mlir | 25 +++++++++++++++++++
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 24 ++++++++++++++++--
4 files changed, 67 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 0bf79f5817b5a..22b5ad749f0c6 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -343,7 +343,25 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it.
Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
- return block->splitBlock(before);
+ // Fast path: with no listener there is nothing to notify; split directly.
+ if (!listener)
+ return block->splitBlock(before);
+
+ // Restore the caller's insertion point, which `createBlock` overwrites.
+ InsertionGuard g(*this);
+ Block *newBlock =
+ createBlock(block->getParent(), std::next(block->getIterator()));
+
+ // Splitting at `block->end()` moves nothing: return the empty new block.
+ if (before == block->end())
+ return newBlock;
+
+ // Move ops one at a time, back to front, so each move is reported to the
+ // listener; stop once the op at `before` has landed in `newBlock`.
+ while (before->getBlock() != newBlock)
+ moveOpBefore(&block->back(), newBlock, newBlock->begin());
+
+ return newBlock;
}
/// Move the blocks that belong to "region" before the given position in
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a79e9076fc28f..3928b98568bf3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1548,7 +1548,7 @@ void ConversionPatternRewriter::notifyBlockInserted(
Block *ConversionPatternRewriter::splitBlock(Block *block,
Block::iterator before) {
- auto *continuation = PatternRewriter::splitBlock(block, before);
+ // Split directly instead of via RewriterBase::splitBlock, which moves ops one
+ // at a time (one notification per op) when a listener is attached. The split
+ // is reported once through notifySplitBlock below.
+ auto *continuation = block->splitBlock(before);
impl->notifySplitBlock(block, continuation);
return continuation;
}
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 4785795121653..6d7ccf161c35d 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -269,3 +269,28 @@ func.func @test_inline_block_before() {
}) : () -> ()
return
}
+
+// -----
+
+// Splitting at "test.split_block_here" moves the trailing ops into a new block
+// back to front (one notifyOperationInserted each); the split op is then
+// replaced by "test.new_op", which must end up at the start of the new block.
+// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
+// CHECK-AN: notifyOperationInserted: test.op_2, was last in block
+// CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
+// CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
+// CHECK-AN: notifyOperationRemoved: test.split_block_here
+// CHECK-AN-LABEL: func @test_split_block(
+// CHECK: "test.op_with_region"() ({
+// CHECK: test.op_1
+// CHECK: ^{{.*}}:
+// CHECK: test.new_op
+// CHECK: test.op_2
+// CHECK: test.op_3
+// CHECK: }) : () -> ()
+func.func @test_split_block() {
+ "test.op_with_region"() ({
+ "test.op_1"() : () -> ()
+ "test.split_block_here"() : () -> ()
+ "test.op_2"() : () -> ()
+ "test.op_3"() : () -> ()
+ }) : () -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 9ce2ab8daa234..307ae58ba74c5 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -233,6 +233,24 @@ struct InlineBlocksIntoParent : public RewritePattern {
}
};
+/// This pattern splits blocks at "test.split_block_here" and replaces the op
+/// with a new op (to prevent an infinite loop of block splitting). The new op
+/// is created right before the split op, i.e. at the start of the new block.
+struct SplitBlockHere : public RewritePattern {
+ SplitBlockHere(MLIRContext *context)
+ : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // Move `op` and every op after it into a new block.
+ rewriter.splitBlock(op->getBlock(), op->getIterator());
+ // `op` now lives in the new block. Re-anchor the insertion point at `op`:
+ // an insertion point captured before the split (e.g. at `op`) could still
+ // refer to the old block while its iterator points into the new one.
+ rewriter.setInsertionPoint(op);
+ Operation *newOp = rewriter.create(
+ op->getLoc(),
+ OperationName("test.new_op", op->getContext()).getIdentifier(),
+ op->getOperands(), op->getResultTypes());
+ rewriter.replaceOp(op, newOp);
+ return success();
+ }
+};
+
struct TestPatternDriver
: public PassWrapper<TestPatternDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -318,7 +336,8 @@ struct TestStrictPatternDriver
InlineBlocksIntoParent,
InsertSameOp,
MoveBeforeParentOp,
- ReplaceWithNewOp
+ ReplaceWithNewOp,
+ SplitBlockHere
// clang-format on
>(ctx);
SmallVector<Operation *> ops;
@@ -327,7 +346,8 @@ struct TestStrictPatternDriver
if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
opName == "test.move_before_parent_op" ||
- opName == "test.inline_blocks_into_parent") {
+ opName == "test.inline_blocks_into_parent" ||
+ opName == "test.split_block_here") {
ops.push_back(op);
}
});
More information about the Mlir-commits
mailing list