[Mlir-commits] [mlir] [mlir][IR] Send missing notification when splitting a block (PR #79597)

Matthias Springer llvmlistbot at llvm.org
Wed Jan 31 05:43:12 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/79597

>From 0282ce7c1246d50b1e0a2dd6fdfa032b02ec07e0 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 26 Jan 2024 14:02:54 +0000
Subject: [PATCH] [mlir][IR] Send missing notification when splitting a block

When a block is split with `RewriterBase::splitBlock`, a `notifyBlockInserted` notification should be sent, followed by one `notifyOperationInserted` notification for each operation that is moved into the new block. This commit adds those notifications.

Depends on #79593. Review only the top commit.
---
 mlir/lib/IR/PatternMatch.cpp                  | 20 ++++++++++++++-
 .../Transforms/Utils/DialectConversion.cpp    |  2 +-
 .../test-strict-pattern-driver.mlir           | 25 +++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   | 24 ++++++++++++++++--
 4 files changed, 67 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 0bf79f5817b5a..22b5ad749f0c6 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -343,7 +343,25 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
 /// Split the operations starting at "before" (inclusive) out of the given
 /// block into a new block, and return it.
 Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
-  return block->splitBlock(before);
+  // Fast path: with no listener attached, there is nothing to notify.
+  if (!listener)
+    return block->splitBlock(before);
+
+  // `createBlock` moves the insertion point; the guard restores it on return.
+  InsertionGuard g(*this);
+  Block *newBlock =
+      createBlock(block->getParent(), std::next(block->getIterator()));
+
+  // If `before` is the end iterator, there are no ops to move.
+  if (before == block->end())
+    return newBlock;
+
+  // Move ops one at a time, back to front, so the listener is notified about
+  // each move. Stop once the op pointed to by `before` is in `newBlock`.
+  while (before->getBlock() != newBlock)
+    moveOpBefore(&block->back(), newBlock, newBlock->begin());
+
+  return newBlock;
 }
 
 /// Move the blocks that belong to "region" before the given position in
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a79e9076fc28f..3928b98568bf3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1548,7 +1548,7 @@ void ConversionPatternRewriter::notifyBlockInserted(
 
 Block *ConversionPatternRewriter::splitBlock(Block *block,
                                              Block::iterator before) {
-  auto *continuation = PatternRewriter::splitBlock(block, before);
+  auto *continuation = block->splitBlock(before); // Skip per-op notifications.
   impl->notifySplitBlock(block, continuation);
   return continuation;
 }
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 4785795121653..6d7ccf161c35d 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -269,3 +269,28 @@ func.func @test_inline_block_before() {
   }) : () -> ()
   return
 }
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
+// CHECK-AN: notifyOperationInserted: test.op_2, was last in block
+// CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
+// CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
+// CHECK-AN: notifyOperationRemoved: test.split_block_here
+// CHECK-AN-LABEL: func @test_split_block(
+//          CHECK:   "test.op_with_region"() ({
+//          CHECK:     test.op_1
+//          CHECK:   ^{{.*}}:
+//          CHECK:     test.new_op
+//          CHECK:     test.op_2
+//          CHECK:     test.op_3
+//          CHECK:   }) : () -> ()
+func.func @test_split_block() {
+  "test.op_with_region"() ({
+    "test.op_1"() : () -> ()
+    "test.split_block_here"() : () -> ()  // This op and all later ops move.
+    "test.op_2"() : () -> ()
+    "test.op_3"() : () -> ()
+  }) : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 9ce2ab8daa234..307ae58ba74c5 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -233,6 +233,28 @@ struct InlineBlocksIntoParent : public RewritePattern {
   }
 };
 
+/// This pattern splits the block at "test.split_block_here" (the op and all
+/// ops after it are moved into a new block) and replaces the op with
+/// "test.new_op" (to prevent an infinite loop of block splitting).
+struct SplitBlockHere : public RewritePattern {
+  SplitBlockHere(MLIRContext *context)
+      : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.splitBlock(op->getBlock(), op->getIterator());
+    // `op` now lives in the new block. An insertion point set before the split
+    // (e.g. at `op`) may still name the old block, so re-anchor it at `op`.
+    rewriter.setInsertionPoint(op);
+    Operation *newOp = rewriter.create(
+        op->getLoc(),
+        OperationName("test.new_op", op->getContext()).getIdentifier(),
+        op->getOperands(), op->getResultTypes());
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+
 struct TestPatternDriver
     : public PassWrapper<TestPatternDriver, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -318,7 +336,8 @@ struct TestStrictPatternDriver
         InlineBlocksIntoParent,
         InsertSameOp,
         MoveBeforeParentOp,
-        ReplaceWithNewOp
+        ReplaceWithNewOp,
+        SplitBlockHere
         // clang-format on
         >(ctx);
     SmallVector<Operation *> ops;
@@ -327,7 +346,8 @@ struct TestStrictPatternDriver
       if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
           opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
           opName == "test.move_before_parent_op" ||
-          opName == "test.inline_blocks_into_parent") {
+          opName == "test.inline_blocks_into_parent" ||
+          opName == "test.split_block_here") {
         ops.push_back(op);
       }
     });



More information about the Mlir-commits mailing list