[Mlir-commits] [mlir] [mlir][IR] Add `RewriterBase::moveBlockBefore` and fix bug in `moveOpBefore` (PR #79579)
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 26 03:38:07 PST 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/79579
This commit adds a new method to the rewriter API: `moveBlockBefore`. This method is used by `inlineRegionBefore` and is covered by the dialect conversion test cases.
Also fixes a bug in `moveOpBefore`, where the previous op location was not passed correctly. Adds a test case to `test-strict-pattern-driver.mlir`.
>From d6d6ca49430c8b5898222a102ed403328160a20b Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 26 Jan 2024 11:29:58 +0000
Subject: [PATCH] [mlir][IR] Add `RewriterBase::moveBlockBefore` and fix bug in
`moveOpBefore`
This commit adds a new method to the rewriter API: `moveBlockBefore`. This method is used by `inlineRegionBefore` and is covered by the dialect conversion test cases.
Also fixes a bug in `moveOpBefore`, where the previous op location was not passed correctly. Adds a test case to `test-strict-pattern-driver.mlir`.
---
mlir/include/mlir/IR/PatternMatch.h | 7 ++++
mlir/lib/IR/PatternMatch.cpp | 35 ++++++++++++-------
.../test-strict-pattern-driver.mlir | 20 +++++++++++
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 35 +++++++++++++++++--
4 files changed, 82 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 8eb129206b95ef6..72cbc85e9081d7e 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -614,6 +614,13 @@ class RewriterBase : public OpBuilder {
virtual void moveOpAfter(Operation *op, Block *block,
Block::iterator iterator);
+ /// Unlink this block and insert it right before `anotherBlock`.
+ void moveBlockBefore(Block *block, Block *anotherBlock);
+
+ /// Unlink this block and insert it right before the location that the given
+ /// iterator points to in the given region.
+ void moveBlockBefore(Block *block, Region *region, Region::iterator iterator);
+
/// This method is used to notify the rewriter that an in-place operation
/// modification is about to happen. A call to this function *must* be
/// followed by a call to either `finalizeOpModification` or
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 817bbb363e0d585..17a56982ddeafb5 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -350,11 +350,8 @@ void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
}
// Move blocks from the beginning of the region one-by-one.
- while (!region.empty()) {
- Block *block = &region.front();
- parent.getBlocks().splice(before, region.getBlocks(), block->getIterator());
- listener->notifyBlockInserted(block, &region, region.begin());
- }
+ while (!region.empty())
+ moveBlockBefore(&region.front(), &parent, before);
}
void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
inlineRegionBefore(region, *before->getParent(), before->getIterator());
@@ -378,6 +375,22 @@ void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
cloneRegionBefore(region, *before->getParent(), before->getIterator());
}
+void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
+ moveBlockBefore(block, anotherBlock->getParent(),
+ anotherBlock->getIterator());
+}
+
+void RewriterBase::moveBlockBefore(Block *block, Region *region,
+ Region::iterator iterator) {
+ Region *currentRegion = block->getParent();
+ Region::iterator nextIterator = std::next(block->getIterator());
+ region->getBlocks().splice(iterator, currentRegion->getBlocks(),
+ block->getIterator());
+ if (listener)
+ listener->notifyBlockInserted(block, /*previous=*/currentRegion,
+ /*previousIt=*/nextIterator);
+}
+
void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
}
@@ -385,11 +398,11 @@ void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
void RewriterBase::moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) {
Block *currentBlock = op->getBlock();
- Block::iterator currentIterator = op->getIterator();
+ Block::iterator nextIterator = std::next(op->getIterator());
op->moveBefore(block, iterator);
if (listener)
listener->notifyOperationInserted(
- op, /*previous=*/InsertPoint(currentBlock, currentIterator));
+ op, /*previous=*/InsertPoint(currentBlock, nextIterator));
}
void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
@@ -398,10 +411,6 @@ void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
void RewriterBase::moveOpAfter(Operation *op, Block *block,
Block::iterator iterator) {
- Block *currentBlock = op->getBlock();
- Block::iterator currentIterator = op->getIterator();
- op->moveAfter(block, iterator);
- if (listener)
- listener->notifyOperationInserted(
- op, /*previous=*/InsertPoint(currentBlock, currentIterator));
+ assert(iterator != block->end() && "cannot move after end of block");
+ moveOpBefore(op, block, std::next(iterator));
}
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index a5ab8f97c74ce33..2ebc66b4f26d0e8 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -24,6 +24,7 @@ func.func @test_erase() {
// -----
+// CHECK-EN: notifyOperationInserted: test.insert_same_op, was unlinked
// CHECK-EN-LABEL: func @test_insert_same_op
// CHECK-EN-SAME: {pattern_driver_all_erased = false, pattern_driver_changed = true}
// CHECK-EN: "test.insert_same_op"() {skip = true}
@@ -35,6 +36,7 @@ func.func @test_insert_same_op() {
// -----
+// CHECK-EN: notifyOperationInserted: test.new_op, was unlinked
// CHECK-EN-LABEL: func @test_replace_with_new_op
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN: %[[n:.*]] = "test.new_op"
@@ -49,6 +51,9 @@ func.func @test_replace_with_new_op() {
// -----
+// CHECK-EN: notifyOperationInserted: test.erase_op, was unlinked
+// CHECK-EN: notifyOperationRemoved: test.replace_with_new_op
+// CHECK-EN: notifyOperationRemoved: test.erase_op
// CHECK-EN-LABEL: func @test_replace_with_erase_op
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN-NOT: "test.replace_with_new_op"
@@ -229,3 +234,18 @@ func.func @test_remove_diamond(%c: i1) {
}) : () -> ()
return
}
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.move_before_parent_op, previous = test.dummy_terminator
+// CHECK-AN-LABEL: func @test_move_op_before(
+// CHECK-AN: test.move_before_parent_op
+// CHECK-AN: test.op_with_region
+// CHECK-AN: test.dummy_terminator
+func.func @test_move_op_before() {
+ "test.op_with_region"() ({
+ "test.move_before_parent_op"() : () -> ()
+ "test.dummy_terminator"() : () ->()
+ }) : () -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 89b9d1ce78a52b6..2165b388d559c5e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -198,6 +198,21 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
}
};
+/// This pattern moves "test.move_before_parent_op" before the parent op.
+struct MoveBeforeParentOp : public RewritePattern {
+ MoveBeforeParentOp(MLIRContext *context)
+ : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // Do not hoist past functions.
+ if (isa<FunctionOpInterface>(op->getParentOp()))
+ return failure();
+ rewriter.moveOpBefore(op, op->getParentOp());
+ return success();
+ }
+};
+
struct TestPatternDriver
: public PassWrapper<TestPatternDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -238,6 +253,20 @@ struct TestPatternDriver
};
struct DumpNotifications : public RewriterBase::Listener {
+ void notifyOperationInserted(Operation *op,
+ OpBuilder::InsertPoint previous) override {
+ llvm::outs() << "notifyOperationInserted: " << op->getName();
+ if (!previous.isSet()) {
+ llvm::outs() << ", was unlinked\n";
+ } else {
+ if (previous.getPoint() == previous.getBlock()->end()) {
+ llvm::outs() << ", was last in block\n";
+ } else {
+ llvm::outs() << ", previous = " << previous.getPoint()->getName()
+ << "\n";
+ }
+ }
+ }
void notifyOperationRemoved(Operation *op) override {
llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
}
@@ -267,14 +296,16 @@ struct TestStrictPatternDriver
ReplaceWithNewOp,
EraseOp,
ChangeBlockOp,
- ImplicitChangeOp
+ ImplicitChangeOp,
+ MoveBeforeParentOp
// clang-format on
>(ctx);
SmallVector<Operation *> ops;
getOperation()->walk([&](Operation *op) {
StringRef opName = op->getName().getStringRef();
if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
- opName == "test.replace_with_new_op" || opName == "test.erase_op") {
+ opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
+ opName == "test.move_before_parent_op") {
ops.push_back(op);
}
});
More information about the Mlir-commits
mailing list