[Mlir-commits] [mlir] da784a2 - [mlir][IR] Add `RewriterBase::moveBlockBefore` and fix bug in `moveOpBefore` (#79579)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 31 02:25:15 PST 2024
Author: Matthias Springer
Date: 2024-01-31T11:25:11+01:00
New Revision: da784a25557e29996bd33638d51d569ddf989faf
URL: https://github.com/llvm/llvm-project/commit/da784a25557e29996bd33638d51d569ddf989faf
DIFF: https://github.com/llvm/llvm-project/commit/da784a25557e29996bd33638d51d569ddf989faf.diff
LOG: [mlir][IR] Add `RewriterBase::moveBlockBefore` and fix bug in `moveOpBefore` (#79579)
This commit adds a new method to the rewriter API: `moveBlockBefore`.
This method is used by `inlineRegionBefore` and is covered by the dialect
conversion test cases.
Also fixes a bug in `moveOpBefore`, where the previous op location was
not passed correctly to the listener: it pointed at the moved op itself
instead of the op that originally followed it. Adds a test case to
`test-strict-pattern-driver.mlir`.
Added:
Modified:
mlir/include/mlir/IR/Block.h
mlir/include/mlir/IR/PatternMatch.h
mlir/lib/IR/Block.cpp
mlir/lib/IR/PatternMatch.cpp
mlir/test/Transforms/test-strict-pattern-driver.mlir
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 4139dcaeea81b..c14e9aad8f6d1 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -67,6 +67,10 @@ class Block : public IRObjectWithUseList<BlockOperand>,
/// specific block.
void moveBefore(Block *block);
+ /// Unlink this block from its current region and insert it right before the
+ /// block that the given iterator points to in the given region.
+ void moveBefore(Region *region, llvm::iplist<Block>::iterator iterator);
+
/// Unlink this Block from its parent region and delete it.
void erase();
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 8eb129206b95e..72cbc85e9081d 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -614,6 +614,13 @@ class RewriterBase : public OpBuilder {
virtual void moveOpAfter(Operation *op, Block *block,
Block::iterator iterator);
+ /// Unlink `block` and insert it right before `anotherBlock`.
+ void moveBlockBefore(Block *block, Block *anotherBlock);
+
+ /// Unlink `block` and insert it right before the location that the given
+ /// iterator points to in the given region.
+ void moveBlockBefore(Block *block, Region *region, Region::iterator iterator);
+
/// This method is used to notify the rewriter that an in-place operation
/// modification is about to happen. A call to this function *must* be
/// followed by a call to either `finalizeOpModification` or
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 82ea303cf0171..65099f8ff15a6 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -52,8 +52,13 @@ void Block::insertAfter(Block *block) {
/// specific block.
void Block::moveBefore(Block *block) {
assert(block->getParent() && "cannot insert before a block without a parent");
- block->getParent()->getBlocks().splice(
- block->getIterator(), getParent()->getBlocks(), getIterator());
+ moveBefore(block->getParent(), block->getIterator());
+}
+
+/// Unlink this block from its current region and insert it right before the
+/// block that the given iterator points to in the given region.
+void Block::moveBefore(Region *region, llvm::iplist<Block>::iterator iterator) {
+ region->getBlocks().splice(iterator, getParent()->getBlocks(), getIterator());
}
/// Unlink this Block from its parent Region and delete it.
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 817bbb363e0d5..ee285f094cdd1 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -350,11 +350,8 @@ void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
}
// Move blocks from the beginning of the region one-by-one.
- while (!region.empty()) {
- Block *block = &region.front();
- parent.getBlocks().splice(before, region.getBlocks(), block->getIterator());
- listener->notifyBlockInserted(block, &region, region.begin());
- }
+ while (!region.empty())
+ moveBlockBefore(&region.front(), &parent, before);
}
void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
inlineRegionBefore(region, *before->getParent(), before->getIterator());
@@ -378,6 +375,21 @@ void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
cloneRegionBefore(region, *before->getParent(), before->getIterator());
}
+void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
+ moveBlockBefore(block, anotherBlock->getParent(),
+ anotherBlock->getIterator());
+}
+
+void RewriterBase::moveBlockBefore(Block *block, Region *region,
+ Region::iterator iterator) {
+ Region *currentRegion = block->getParent();
+ Region::iterator nextIterator = std::next(block->getIterator());
+ block->moveBefore(region, iterator);
+ if (listener)
+ listener->notifyBlockInserted(block, /*previous=*/currentRegion,
+ /*previousIt=*/nextIterator);
+}
+
void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
}
@@ -385,11 +397,11 @@ void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
void RewriterBase::moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) {
Block *currentBlock = op->getBlock();
- Block::iterator currentIterator = op->getIterator();
+ Block::iterator nextIterator = std::next(op->getIterator());
op->moveBefore(block, iterator);
if (listener)
listener->notifyOperationInserted(
- op, /*previous=*/InsertPoint(currentBlock, currentIterator));
+ op, /*previous=*/InsertPoint(currentBlock, nextIterator));
}
void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
@@ -398,10 +410,6 @@ void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
void RewriterBase::moveOpAfter(Operation *op, Block *block,
Block::iterator iterator) {
- Block *currentBlock = op->getBlock();
- Block::iterator currentIterator = op->getIterator();
- op->moveAfter(block, iterator);
- if (listener)
- listener->notifyOperationInserted(
- op, /*previous=*/InsertPoint(currentBlock, currentIterator));
+ assert(iterator != block->end() && "cannot move after end of block");
+ moveOpBefore(op, block, std::next(iterator));
}
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index a5ab8f97c74ce..2ebc66b4f26d0 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -24,6 +24,7 @@ func.func @test_erase() {
// -----
+// CHECK-EN: notifyOperationInserted: test.insert_same_op, was unlinked
// CHECK-EN-LABEL: func @test_insert_same_op
// CHECK-EN-SAME: {pattern_driver_all_erased = false, pattern_driver_changed = true}
// CHECK-EN: "test.insert_same_op"() {skip = true}
@@ -35,6 +36,7 @@ func.func @test_insert_same_op() {
// -----
+// CHECK-EN: notifyOperationInserted: test.new_op, was unlinked
// CHECK-EN-LABEL: func @test_replace_with_new_op
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN: %[[n:.*]] = "test.new_op"
@@ -49,6 +51,9 @@ func.func @test_replace_with_new_op() {
// -----
+// CHECK-EN: notifyOperationInserted: test.erase_op, was unlinked
+// CHECK-EN: notifyOperationRemoved: test.replace_with_new_op
+// CHECK-EN: notifyOperationRemoved: test.erase_op
// CHECK-EN-LABEL: func @test_replace_with_erase_op
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
// CHECK-EN-NOT: "test.replace_with_new_op"
@@ -229,3 +234,18 @@ func.func @test_remove_diamond(%c: i1) {
}) : () -> ()
return
}
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.move_before_parent_op, previous = test.dummy_terminator
+// CHECK-AN-LABEL: func @test_move_op_before(
+// CHECK-AN: test.move_before_parent_op
+// CHECK-AN: test.op_with_region
+// CHECK-AN: test.dummy_terminator
+func.func @test_move_op_before() {
+ "test.op_with_region"() ({
+ "test.move_before_parent_op"() : () -> ()
+ "test.dummy_terminator"() : () -> ()
+ }) : () -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 89b9d1ce78a52..2165b388d559c 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -198,6 +198,21 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
}
};
+/// Moves "test.move_before_parent_op" before its non-function parent op.
+struct MoveBeforeParentOp : public RewritePattern {
+ MoveBeforeParentOp(MLIRContext *context)
+ : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // Stop once the parent is a function: never hoist out of a function body.
+ if (isa<FunctionOpInterface>(op->getParentOp()))
+ return failure();
+ rewriter.moveOpBefore(op, op->getParentOp());
+ return success();
+ }
+};
+
struct TestPatternDriver
: public PassWrapper<TestPatternDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -238,6 +253,20 @@ struct TestPatternDriver
};
struct DumpNotifications : public RewriterBase::Listener {
+ void notifyOperationInserted(Operation *op,
+ OpBuilder::InsertPoint previous) override {
+ llvm::outs() << "notifyOperationInserted: " << op->getName();
+ if (!previous.isSet()) {
+ llvm::outs() << ", was unlinked\n";
+ } else {
+ if (previous.getPoint() == previous.getBlock()->end()) {
+ llvm::outs() << ", was last in block\n";
+ } else {
+ llvm::outs() << ", previous = " << previous.getPoint()->getName()
+ << "\n";
+ }
+ }
+ }
void notifyOperationRemoved(Operation *op) override {
llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
}
@@ -267,14 +296,16 @@ struct TestStrictPatternDriver
ReplaceWithNewOp,
EraseOp,
ChangeBlockOp,
- ImplicitChangeOp
+ ImplicitChangeOp,
+ MoveBeforeParentOp
// clang-format on
>(ctx);
SmallVector<Operation *> ops;
getOperation()->walk([&](Operation *op) {
StringRef opName = op->getName().getStringRef();
if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
- opName == "test.replace_with_new_op" || opName == "test.erase_op") {
+ opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
+ opName == "test.move_before_parent_op") {
ops.push_back(op);
}
});
More information about the Mlir-commits
mailing list