[Mlir-commits] [mlir] [mlir][IR] Notify about block insertion when cloning an op (PR #80262)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 1 01:43:25 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
`OpBuilder::clone(Operation &)` should trigger not only `notifyOperationInserted` but also `notifyBlockInserted` (for all blocks contained in `op`).
---
Full diff: https://github.com/llvm/llvm-project/pull/80262.diff
3 Files Affected:
- (modified) mlir/lib/IR/Builders.cpp (+14)
- (modified) mlir/test/Transforms/test-strict-pattern-driver.mlir (+36-7)
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+28-1)
``````````diff
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 7acef1073c6de..589d41de9b8bc 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -525,16 +525,30 @@ LogicalResult OpBuilder::tryFold(Operation *op,
Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
Operation *newOp = op.clone(mapper);
newOp = insert(newOp);
+
// The `insert` call above handles the notification for inserting `newOp`
// itself. But if `newOp` has any regions, we need to notify the listener
// about any ops that got inserted inside those regions as part of cloning.
if (listener) {
+ // Notifies about every block directly nested in the regions of `parentOp`.
+ // Blocks of more deeply nested ops are reported by the walk below.
+ auto notifyBlockInsertions = [&](Operation *parentOp) {
+ for (Region &r : parentOp->getRegions())
+ for (Block &b : r.getBlocks())
+ listener->notifyBlockInserted(&b, /*previous=*/nullptr,
+ /*previousIt=*/{});
+ };
+ // The `insert` call above notifies about op insertion, but not about block
+ // insertion.
+ notifyBlockInsertions(newOp);
auto walkFn = [&](Operation *walkedOp) {
listener->notifyOperationInserted(walkedOp, /*previous=*/{});
+ notifyBlockInsertions(walkedOp);
};
for (Region &region : newOp->getRegions())
region.walk<WalkOrder::PreOrder>(walkFn);
}
+
return newOp;
}
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 6d7ccf161c35d..5d889979a1f92 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -272,19 +272,20 @@ func.func @test_inline_block_before() {
// -----
+// CHECK-AN: notifyBlockInserted into test.op_with_region: was unlinked
// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
// CHECK-AN: notifyOperationInserted: test.op_2, was last in block
// CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
// CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
// CHECK-AN: notifyOperationRemoved: test.split_block_here
// CHECK-AN-LABEL: func @test_split_block(
-// CHECK: "test.op_with_region"() ({
-// CHECK: test.op_1
-// CHECK: ^{{.*}}:
-// CHECK: test.new_op
-// CHECK: test.op_2
-// CHECK: test.op_3
-// CHECK: }) : () -> ()
+// CHECK-AN: "test.op_with_region"() ({
+// CHECK-AN: test.op_1
+// CHECK-AN: ^{{.*}}:
+// CHECK-AN: test.new_op
+// CHECK-AN: test.op_2
+// CHECK-AN: test.op_3
+// CHECK-AN: }) : () -> ()
func.func @test_split_block() {
"test.op_with_region"() ({
"test.op_1"() : () -> ()
@@ -294,3 +295,31 @@ func.func @test_split_block() {
}) : () -> ()
return
}
+
+// -----
+// Cloning an op notifies about the op, then each of its blocks, then nested ops.
+// CHECK-AN: notifyOperationInserted: test.clone_me, was unlinked
+// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
+// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
+// CHECK-AN: notifyOperationInserted: test.foo, was unlinked
+// CHECK-AN: notifyOperationInserted: test.bar, was unlinked
+// CHECK-AN-LABEL: func @clone_op(
+// CHECK-AN: "test.clone_me"() ({
+// CHECK-AN: "test.foo"() : () -> ()
+// CHECK-AN: ^bb1: // no predecessors
+// CHECK-AN: "test.bar"() : () -> ()
+// CHECK-AN: }) {was_cloned} : () -> ()
+// CHECK-AN: "test.clone_me"() ({
+// CHECK-AN: "test.foo"() : () -> ()
+// CHECK-AN: ^bb1: // no predecessors
+// CHECK-AN: "test.bar"() : () -> ()
+// CHECK-AN: }) : () -> ()
+func.func @clone_op() {
+ "test.clone_me"() ({
+ ^bb0:
+ "test.foo"() : () -> ()
+ ^bb1:
+ "test.bar"() : () -> ()
+ }) : () -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 307ae58ba74c5..e3978d3789cf0 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -251,6 +251,22 @@ struct SplitBlockHere : public RewritePattern {
}
};
+/// Duplicates each "test.clone_me" op and tags the copy with `was_cloned`.
+struct CloneOp : public RewritePattern {
+ CloneOp(MLIRContext *context)
+ : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // Copies carry the tag; refusing to clone them prevents an endless loop.
+ if (op->hasAttr("was_cloned"))
+ return failure();
+ Operation *copy = rewriter.clone(*op);
+ copy->setAttr("was_cloned", rewriter.getUnitAttr());
+ return success();
+ }
+};
+
struct TestPatternDriver
: public PassWrapper<TestPatternDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -291,6 +307,16 @@ struct TestPatternDriver
};
struct DumpNotifications : public RewriterBase::Listener {
+ /// Logs the name of the op that owns `block`, followed by whether the
+ /// block was previously linked into a region (`previous` non-null) or was
+ /// unlinked (`previous` null). `previousIt` is not used.
+ void notifyBlockInserted(Block *block, Region *previous,
+ Region::iterator previousIt) override {
+ const char *origin = previous ? "was linked\n" : "was unlinked\n";
+ llvm::outs() << "notifyBlockInserted into "
+ << block->getParentOp()->getName() << ": "
+ << origin;
+ }
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override {
llvm::outs() << "notifyOperationInserted: " << op->getName();
@@ -331,6 +357,7 @@ struct TestStrictPatternDriver
patterns.add<
// clang-format off
ChangeBlockOp,
+ CloneOp,
EraseOp,
ImplicitChangeOp,
InlineBlocksIntoParent,
@@ -347,7 +374,7 @@ struct TestStrictPatternDriver
opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
opName == "test.move_before_parent_op" ||
opName == "test.inline_blocks_into_parent" ||
- opName == "test.split_block_here") {
+ opName == "test.split_block_here" || opName == "test.clone_me") {
ops.push_back(op);
}
});
``````````
</details>
https://github.com/llvm/llvm-project/pull/80262
More information about the Mlir-commits
mailing list