[Mlir-commits] [mlir] [mlir][IR] Send missing notification when splitting a block (PR #79597)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 26 06:05:23 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

When a block is split with `RewriterBase::splitBlock`, a `notifyBlockInserted` notification should be sent for the new block, followed by one `notifyOperationInserted` notification for each operation that is moved into the new block. This commit adds those notifications.

Depends on #<!-- -->79593. Review only the top commit.


---
Full diff: https://github.com/llvm/llvm-project/pull/79597.diff


8 Files Affected:

- (modified) mlir/include/mlir/IR/Block.h (+4) 
- (modified) mlir/include/mlir/IR/PatternMatch.h (+7) 
- (modified) mlir/lib/IR/Block.cpp (+7-2) 
- (modified) mlir/lib/IR/PatternMatch.cpp (+50-15) 
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+1-1) 
- (modified) mlir/test/Dialect/Affine/simplify-structures.mlir (-2) 
- (modified) mlir/test/Transforms/test-strict-pattern-driver.mlir (+65) 
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+75-2) 


``````````diff
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 4139dcaeea81bb9..c14e9aad8f6d1e2 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -67,6 +67,10 @@ class Block : public IRObjectWithUseList<BlockOperand>,
   /// specific block.
   void moveBefore(Block *block);
 
+  /// Unlink this block from its current region and insert it right before the
+  /// block that the given iterator points to in the given region.
+  void moveBefore(Region *region, llvm::iplist<Block>::iterator iterator);
+
   /// Unlink this Block from its parent region and delete it.
   void erase();
 
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 8eb129206b95ef6..72cbc85e9081d7e 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -614,6 +614,13 @@ class RewriterBase : public OpBuilder {
   virtual void moveOpAfter(Operation *op, Block *block,
                            Block::iterator iterator);
 
+  /// Unlink this block and insert it right before `anotherBlock`.
+  void moveBlockBefore(Block *block, Block *anotherBlock);
+
+  /// Unlink this block and insert it right before the location that the given
+  /// iterator points to in the given region.
+  void moveBlockBefore(Block *block, Region *region, Region::iterator iterator);
+
   /// This method is used to notify the rewriter that an in-place operation
   /// modification is about to happen. A call to this function *must* be
   /// followed by a call to either `finalizeOpModification` or
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 82ea303cf0171f3..65099f8ff15a6f7 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -52,8 +52,13 @@ void Block::insertAfter(Block *block) {
 /// specific block.
 void Block::moveBefore(Block *block) {
   assert(block->getParent() && "cannot insert before a block without a parent");
-  block->getParent()->getBlocks().splice(
-      block->getIterator(), getParent()->getBlocks(), getIterator());
+  moveBefore(block->getParent(), block->getIterator());
+}
+
+/// Unlink this block from its current region and insert it right before the
+/// block that the given iterator points to in the given region.
+void Block::moveBefore(Region *region, llvm::iplist<Block>::iterator iterator) {
+  region->getBlocks().splice(iterator, getParent()->getBlocks(), getIterator());
 }
 
 /// Unlink this Block from its parent Region and delete it.
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 817bbb363e0d585..22b5ad749f0c6a4 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -317,7 +317,16 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
 
   // Move operations from the source block to the dest block and erase the
   // source block.
-  dest->getOperations().splice(before, source->getOperations());
+  if (!listener) {
+    // Fast path: If no listener is attached, move all operations at once.
+    dest->getOperations().splice(before, source->getOperations());
+  } else {
+    while (!source->empty())
+      moveOpBefore(&source->front(), dest, before);
+  }
+
+  // Erase the source block.
+  assert(source->empty() && "expected 'source' to be empty");
   eraseBlock(source);
 }
 
@@ -334,7 +343,25 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
 /// Split the operations starting at "before" (inclusive) out of the given
 /// block into a new block, and return it.
 Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
-  return block->splitBlock(before);
+  // Fast path: If no listener is attached, split the block directly.
+  if (!listener)
+    return block->splitBlock(before);
+
+  // `createBlock` sets the insertion point at the beginning of the new block.
+  InsertionGuard g(*this);
+  Block *newBlock =
+      createBlock(block->getParent(), std::next(block->getIterator()));
+
+  // If `before` points to end of the block, no ops should be moved.
+  if (before == block->end())
+    return newBlock;
+
+  // Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
+  // Stop when the operation pointed to by `before` has been moved.
+  while (before->getBlock() != newBlock)
+    moveOpBefore(&block->back(), newBlock, newBlock->begin());
+
+  return newBlock;
 }
 
 /// Move the blocks that belong to "region" before the given position in
@@ -350,11 +377,8 @@ void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
   }
 
   // Move blocks from the beginning of the region one-by-one.
-  while (!region.empty()) {
-    Block *block = &region.front();
-    parent.getBlocks().splice(before, region.getBlocks(), block->getIterator());
-    listener->notifyBlockInserted(block, &region, region.begin());
-  }
+  while (!region.empty())
+    moveBlockBefore(&region.front(), &parent, before);
 }
 void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
   inlineRegionBefore(region, *before->getParent(), before->getIterator());
@@ -378,6 +402,21 @@ void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
   cloneRegionBefore(region, *before->getParent(), before->getIterator());
 }
 
+void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
+  moveBlockBefore(block, anotherBlock->getParent(),
+                  anotherBlock->getIterator());
+}
+
+void RewriterBase::moveBlockBefore(Block *block, Region *region,
+                                   Region::iterator iterator) {
+  Region *currentRegion = block->getParent();
+  Region::iterator nextIterator = std::next(block->getIterator());
+  block->moveBefore(region, iterator);
+  if (listener)
+    listener->notifyBlockInserted(block, /*previous=*/currentRegion,
+                                  /*previousIt=*/nextIterator);
+}
+
 void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
   moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
 }
@@ -385,11 +424,11 @@ void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
 void RewriterBase::moveOpBefore(Operation *op, Block *block,
                                 Block::iterator iterator) {
   Block *currentBlock = op->getBlock();
-  Block::iterator currentIterator = op->getIterator();
+  Block::iterator nextIterator = std::next(op->getIterator());
   op->moveBefore(block, iterator);
   if (listener)
     listener->notifyOperationInserted(
-        op, /*previous=*/InsertPoint(currentBlock, currentIterator));
+        op, /*previous=*/InsertPoint(currentBlock, nextIterator));
 }
 
 void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
@@ -398,10 +437,6 @@ void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
 
 void RewriterBase::moveOpAfter(Operation *op, Block *block,
                                Block::iterator iterator) {
-  Block *currentBlock = op->getBlock();
-  Block::iterator currentIterator = op->getIterator();
-  op->moveAfter(block, iterator);
-  if (listener)
-    listener->notifyOperationInserted(
-        op, /*previous=*/InsertPoint(currentBlock, currentIterator));
+  assert(iterator != block->end() && "cannot move after end of block");
+  moveOpBefore(op, block, std::next(iterator));
 }
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a79e9076fc28faf..3928b98568bf3c4 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1548,7 +1548,7 @@ void ConversionPatternRewriter::notifyBlockInserted(
 
 Block *ConversionPatternRewriter::splitBlock(Block *block,
                                              Block::iterator before) {
-  auto *continuation = PatternRewriter::splitBlock(block, before);
+  auto *continuation = block->splitBlock(before);
   impl->notifySplitBlock(block, continuation);
   return continuation;
 }
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 2c693ea1551c013..92d3d86bc93068f 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -411,8 +411,6 @@ func.func @test_trivially_false_returning_two_results(%arg0: index) -> (index, i
   // CHECK: %[[c13:.*]] = arith.constant 13 : index
   %c7 = arith.constant 7 : index
   %c13 = arith.constant 13 : index
-  // CHECK: %[[c2:.*]] = arith.constant 2 : index
-  // CHECK: %[[c3:.*]] = arith.constant 3 : index
   %res:2 = affine.if affine_set<(d0, d1) : (5 >= 0, -2 >= 0)> (%c7, %c13) -> (index, index) {
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index a5ab8f97c74ce33..6d7ccf161c35dea 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -24,6 +24,7 @@ func.func @test_erase() {
 
 // -----
 
+// CHECK-EN: notifyOperationInserted: test.insert_same_op, was unlinked
 // CHECK-EN-LABEL: func @test_insert_same_op
 //  CHECK-EN-SAME:     {pattern_driver_all_erased = false, pattern_driver_changed = true}
 //       CHECK-EN:   "test.insert_same_op"() {skip = true}
@@ -35,6 +36,7 @@ func.func @test_insert_same_op() {
 
 // -----
 
+// CHECK-EN: notifyOperationInserted: test.new_op, was unlinked
 // CHECK-EN-LABEL: func @test_replace_with_new_op
 //  CHECK-EN-SAME:     {pattern_driver_all_erased = true, pattern_driver_changed = true}
 //       CHECK-EN:   %[[n:.*]] = "test.new_op"
@@ -49,6 +51,9 @@ func.func @test_replace_with_new_op() {
 
 // -----
 
+// CHECK-EN: notifyOperationInserted: test.erase_op, was unlinked
+// CHECK-EN: notifyOperationRemoved: test.replace_with_new_op
+// CHECK-EN: notifyOperationRemoved: test.erase_op
 // CHECK-EN-LABEL: func @test_replace_with_erase_op
 //  CHECK-EN-SAME:     {pattern_driver_all_erased = true, pattern_driver_changed = true}
 //   CHECK-EN-NOT:   "test.replace_with_new_op"
@@ -229,3 +234,63 @@ func.func @test_remove_diamond(%c: i1) {
   }) : () -> ()
   return
 }
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.move_before_parent_op, previous = test.dummy_terminator
+// CHECK-AN-LABEL: func @test_move_op_before(
+//       CHECK-AN:   test.move_before_parent_op
+//       CHECK-AN:   test.op_with_region
+//       CHECK-AN:     test.dummy_terminator
+func.func @test_move_op_before() {
+  "test.op_with_region"() ({
+    "test.move_before_parent_op"() : () -> ()
+    "test.dummy_terminator"() : () ->()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.op_1, previous = test.op_2
+// CHECK-AN: notifyOperationInserted: test.op_2, previous = test.op_3
+// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
+// CHECK-AN-LABEL: func @test_inline_block_before(
+//       CHECK-AN:   test.op_1
+//       CHECK-AN:   test.op_2
+//       CHECK-AN:   test.op_3
+//       CHECK-AN:   test.inline_blocks_into_parent
+//       CHECK-AN:   return
+func.func @test_inline_block_before() {
+  "test.inline_blocks_into_parent"() ({
+    "test.op_1"() : () -> ()
+    "test.op_2"() : () -> ()
+    "test.op_3"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
+// CHECK-AN: notifyOperationInserted: test.op_2, was last in block
+// CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
+// CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
+// CHECK-AN: notifyOperationRemoved: test.split_block_here
+// CHECK-AN-LABEL: func @test_split_block(
+//          CHECK:   "test.op_with_region"() ({
+//          CHECK:     test.op_1
+//          CHECK:   ^{{.*}}:
+//          CHECK:     test.new_op
+//          CHECK:     test.op_2
+//          CHECK:     test.op_3
+//          CHECK:   }) : () -> ()
+func.func @test_split_block() {
+  "test.op_with_region"() ({
+    "test.op_1"() : () -> ()
+    "test.split_block_here"() : () -> ()
+    "test.op_2"() : () -> ()
+    "test.op_3"() : () -> ()
+  }) : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 89b9d1ce78a52b6..c84fa0ede687423 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -198,6 +198,59 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
   }
 };
 
+/// This pattern moves "test.move_before_parent_op" before the parent op.
+struct MoveBeforeParentOp : public RewritePattern {
+  MoveBeforeParentOp(MLIRContext *context)
+      : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Do not hoist past functions.
+    if (isa<FunctionOpInterface>(op->getParentOp()))
+      return failure();
+    rewriter.moveOpBefore(op, op->getParentOp());
+    return success();
+  }
+};
+
+/// This pattern inlines blocks that are nested in
+/// "test.inline_blocks_into_parent" into the parent block.
+struct InlineBlocksIntoParent : public RewritePattern {
+  InlineBlocksIntoParent(MLIRContext *context)
+      : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
+                       context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    bool changed = false;
+    for (Region &r : op->getRegions()) {
+      while (!r.empty()) {
+        rewriter.inlineBlockBefore(&r.front(), op);
+        changed = true;
+      }
+    }
+    return success(changed);
+  }
+};
+
+/// This pattern splits blocks at "test.split_block_here" and replaces the op
+/// with a new op (to prevent an infinite loop of block splitting).
+struct SplitBlockHere : public RewritePattern {
+  SplitBlockHere(MLIRContext *context)
+      : RewritePattern("test.split_block_here", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.splitBlock(op->getBlock(), op->getIterator());
+    Operation *newOp = rewriter.create(
+        op->getLoc(),
+        OperationName("test.new_op", op->getContext()).getIdentifier(),
+        op->getOperands(), op->getResultTypes());
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+
 struct TestPatternDriver
     : public PassWrapper<TestPatternDriver, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -238,6 +291,20 @@ struct TestPatternDriver
 };
 
 struct DumpNotifications : public RewriterBase::Listener {
+  void notifyOperationInserted(Operation *op,
+                               OpBuilder::InsertPoint previous) override {
+    llvm::outs() << "notifyOperationInserted: " << op->getName();
+    if (!previous.isSet()) {
+      llvm::outs() << ", was unlinked\n";
+    } else {
+      if (previous.getPoint() == previous.getBlock()->end()) {
+        llvm::outs() << ", was last in block\n";
+      } else {
+        llvm::outs() << ", previous = " << previous.getPoint()->getName()
+                     << "\n";
+      }
+    }
+  }
   void notifyOperationRemoved(Operation *op) override {
     llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
   }
@@ -267,14 +334,20 @@ struct TestStrictPatternDriver
         ReplaceWithNewOp,
         EraseOp,
         ChangeBlockOp,
-        ImplicitChangeOp
+        ImplicitChangeOp,
+        MoveBeforeParentOp,
+        InlineBlocksIntoParent,
+        SplitBlockHere
         // clang-format on
         >(ctx);
     SmallVector<Operation *> ops;
     getOperation()->walk([&](Operation *op) {
       StringRef opName = op->getName().getStringRef();
       if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
-          opName == "test.replace_with_new_op" || opName == "test.erase_op") {
+          opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
+          opName == "test.move_before_parent_op" ||
+          opName == "test.inline_blocks_into_parent" ||
+          opName == "test.split_block_here") {
         ops.push_back(op);
       }
     });

``````````

</details>


https://github.com/llvm/llvm-project/pull/79597


More information about the Mlir-commits mailing list