[Mlir-commits] [mlir] [mlir][IR] Trigger nested operation/block insertion notifications for clones (PR #66871)
Matthias Springer
llvmlistbot at llvm.org
Thu Feb 1 03:27:56 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/66871
>From c149370a584fb1ecffaa4997c57e0f7c595cedcc Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 1 Feb 2024 11:26:32 +0000
Subject: [PATCH 1/2] [mlir][IR] Notify about block insertion when cloning an
op
`OpBuilder::clone(Operation &)` should trigger not only `notifyOperationInserted` but also `notifyBlockInserted` (for all blocks contained in `op`).
---
mlir/lib/IR/Builders.cpp | 14 ++++++
.../test-strict-pattern-driver.mlir | 43 ++++++++++++++++---
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 29 ++++++++++++-
3 files changed, 78 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 7acef1073c6de..589d41de9b8bc 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -525,16 +525,30 @@ LogicalResult OpBuilder::tryFold(Operation *op,
Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
Operation *newOp = op.clone(mapper);
newOp = insert(newOp);
+
// The `insert` call above handles the notification for inserting `newOp`
// itself. But if `newOp` has any regions, we need to notify the listener
// about any ops that got inserted inside those regions as part of cloning.
if (listener) {
+ // Helper function that sends block insertion notifications for every block
+ // within the given op.
+ auto notifyBlockInsertions = [&](Operation *op) {
+ for (Region &r : op->getRegions())
+ for (Block &b : r.getBlocks())
+ listener->notifyBlockInserted(&b, /*previous=*/nullptr,
+ /*previousIt=*/{});
+ };
+ // The `insert` call above notifies about op insertion, but not about block
+ // insertion.
+ notifyBlockInsertions(newOp);
auto walkFn = [&](Operation *walkedOp) {
listener->notifyOperationInserted(walkedOp, /*previous=*/{});
+ notifyBlockInsertions(walkedOp);
};
for (Region &region : newOp->getRegions())
region.walk<WalkOrder::PreOrder>(walkFn);
}
+
return newOp;
}
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 6d7ccf161c35d..5d889979a1f92 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -272,19 +272,20 @@ func.func @test_inline_block_before() {
// -----
+// CHECK-AN: notifyBlockInserted into test.op_with_region: was unlinked
// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
// CHECK-AN: notifyOperationInserted: test.op_2, was last in block
// CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
// CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
// CHECK-AN: notifyOperationRemoved: test.split_block_here
// CHECK-AN-LABEL: func @test_split_block(
-// CHECK: "test.op_with_region"() ({
-// CHECK: test.op_1
-// CHECK: ^{{.*}}:
-// CHECK: test.new_op
-// CHECK: test.op_2
-// CHECK: test.op_3
-// CHECK: }) : () -> ()
+// CHECK-AN: "test.op_with_region"() ({
+// CHECK-AN: test.op_1
+// CHECK-AN: ^{{.*}}:
+// CHECK-AN: test.new_op
+// CHECK-AN: test.op_2
+// CHECK-AN: test.op_3
+// CHECK-AN: }) : () -> ()
func.func @test_split_block() {
"test.op_with_region"() ({
"test.op_1"() : () -> ()
@@ -294,3 +295,31 @@ func.func @test_split_block() {
}) : () -> ()
return
}
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.clone_me, was unlinked
+// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
+// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
+// CHECK-AN: notifyOperationInserted: test.foo, was unlinked
+// CHECK-AN: notifyOperationInserted: test.bar, was unlinked
+// CHECK-AN-LABEL: func @clone_op(
+// CHECK-AN: "test.clone_me"() ({
+// CHECK-AN: "test.foo"() : () -> ()
+// CHECK-AN: ^bb1: // no predecessors
+// CHECK-AN: "test.bar"() : () -> ()
+// CHECK-AN: }) {was_cloned} : () -> ()
+// CHECK-AN: "test.clone_me"() ({
+// CHECK-AN: "test.foo"() : () -> ()
+// CHECK-AN: ^bb1: // no predecessors
+// CHECK-AN: "test.bar"() : () -> ()
+// CHECK-AN: }) : () -> ()
+func.func @clone_op() {
+ "test.clone_me"() ({
+ ^bb0:
+ "test.foo"() : () -> ()
+ ^bb1:
+ "test.bar"() : () -> ()
+ }) : () -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 307ae58ba74c5..e3978d3789cf0 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -251,6 +251,22 @@ struct SplitBlockHere : public RewritePattern {
}
};
+/// This pattern clones "test.clone_me" ops.
+struct CloneOp : public RewritePattern {
+ CloneOp(MLIRContext *context)
+ : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // Do not clone already cloned ops to avoid going into an infinite loop.
+ if (op->hasAttr("was_cloned"))
+ return failure();
+ Operation *cloned = rewriter.clone(*op);
+ cloned->setAttr("was_cloned", rewriter.getUnitAttr());
+ return success();
+ }
+};
+
struct TestPatternDriver
: public PassWrapper<TestPatternDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -291,6 +307,16 @@ struct TestPatternDriver
};
struct DumpNotifications : public RewriterBase::Listener {
+ void notifyBlockInserted(Block *block, Region *previous,
+ Region::iterator previousIt) override {
+ llvm::outs() << "notifyBlockInserted into "
+ << block->getParentOp()->getName() << ": ";
+ if (previous == nullptr) {
+ llvm::outs() << "was unlinked\n";
+ } else {
+ llvm::outs() << "was linked\n";
+ }
+ }
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override {
llvm::outs() << "notifyOperationInserted: " << op->getName();
@@ -331,6 +357,7 @@ struct TestStrictPatternDriver
patterns.add<
// clang-format off
ChangeBlockOp,
+ CloneOp,
EraseOp,
ImplicitChangeOp,
InlineBlocksIntoParent,
@@ -347,7 +374,7 @@ struct TestStrictPatternDriver
opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
opName == "test.move_before_parent_op" ||
opName == "test.inline_blocks_into_parent" ||
- opName == "test.split_block_here") {
+ opName == "test.split_block_here" || opName == "test.clone_me") {
ops.push_back(op);
}
});
>From 3051c7d199c19204502572f7fa51e7994ee63027 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 1 Feb 2024 11:26:53 +0000
Subject: [PATCH 2/2] [mlir][IR] Send notifications for `cloneRegionBefore`
Similar to `OpBuilder::clone`, operation/block insertion notifications should be sent when cloning the contents of a region. E.g., this is to ensure that the newly created operations are put on the worklist of the greedy pattern rewriter driver.
Also move `cloneRegionBefore` from `RewriterBase` to `OpBuilder`. It only creates new IR, so it should be part of the builder API (like `clone(Operation &)`). The function does not have to be virtual. Now that notifications are properly sent, the override in the dialect conversion is no longer needed.
BEGIN_PUBLIC
No public commit message for presubmit.
END_PUBLIC
---
mlir/include/mlir/IR/Builders.h | 10 ++++
mlir/include/mlir/IR/PatternMatch.h | 10 ----
.../mlir/Transforms/DialectConversion.h | 8 ---
mlir/lib/IR/Builders.cpp | 52 +++++++++++++++----
mlir/lib/IR/PatternMatch.cpp | 18 -------
.../Transforms/Utils/DialectConversion.cpp | 17 ------
mlir/test/Transforms/test-legalizer-full.mlir | 4 +-
.../test-strict-pattern-driver.mlir | 31 +++++++++++
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 22 +++++++-
9 files changed, 107 insertions(+), 65 deletions(-)
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 8c25a1aa2fad1..c4e165d5a053d 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -583,6 +583,16 @@ class OpBuilder : public Builder {
return cast<OpT>(cloneWithoutRegions(*op.getOperation()));
}
+ /// Clone the blocks that belong to "region" before the given position in
+ /// another region "parent". The two regions must be different. The caller is
+ /// responsible for creating or updating the operation transferring flow of
+ /// control to the region and passing it the correct block arguments.
+ void cloneRegionBefore(Region &region, Region &parent,
+ Region::iterator before, IRMapping &mapping);
+ void cloneRegionBefore(Region &region, Region &parent,
+ Region::iterator before);
+ void cloneRegionBefore(Region &region, Block *before);
+
protected:
/// The optional listener for events of this builder.
Listener *listener;
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 72cbc85e9081d..61da27825e870 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -500,16 +500,6 @@ class RewriterBase : public OpBuilder {
Region::iterator before);
void inlineRegionBefore(Region &region, Block *before);
- /// Clone the blocks that belong to "region" before the given position in
- /// another region "parent". The two regions must be different. The caller is
- /// responsible for creating or updating the operation transferring flow of
- /// control to the region and passing it the correct block arguments.
- virtual void cloneRegionBefore(Region &region, Region &parent,
- Region::iterator before, IRMapping &mapping);
- void cloneRegionBefore(Region &region, Region &parent,
- Region::iterator before);
- void cloneRegionBefore(Region &region, Block *before);
-
/// This method replaces the uses of the results of `op` with the values in
/// `newValues` when the provided `functor` returns true for a specific use.
/// The number of values in `newValues` is required to match the number of
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index d9470de9ceb9f..51e3e413b516f 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -724,14 +724,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
ValueRange argValues = std::nullopt) override;
using PatternRewriter::inlineBlockBefore;
- /// PatternRewriter hook for cloning blocks of one region into another. The
- /// given region to clone *must* not have been modified as part of conversion
- /// yet, i.e. it must be within an operation that is either in the process of
- /// conversion, or has not yet been converted.
- void cloneRegionBefore(Region &region, Region &parent,
- Region::iterator before, IRMapping &mapping) override;
- using PatternRewriter::cloneRegionBefore;
-
/// PatternRewriter hook for inserting a new operation.
void notifyOperationInserted(Operation *op, InsertPoint previous) override;
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 589d41de9b8bc..a905aae86a426 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -522,6 +522,16 @@ LogicalResult OpBuilder::tryFold(Operation *op,
return success();
}
+/// Helper function that sends block insertion notifications for every block
+/// that is directly nested in the given op.
+static void notifyBlockInsertions(Operation *op,
+ OpBuilder::Listener *listener) {
+ for (Region &r : op->getRegions())
+ for (Block &b : r.getBlocks())
+ listener->notifyBlockInserted(&b, /*previous=*/nullptr,
+ /*previousIt=*/{});
+}
+
Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
Operation *newOp = op.clone(mapper);
newOp = insert(newOp);
@@ -530,20 +540,12 @@ Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
// itself. But if `newOp` has any regions, we need to notify the listener
// about any ops that got inserted inside those regions as part of cloning.
if (listener) {
- // Helper function that sends block insertion notifications for every block
- // within the given op.
- auto notifyBlockInsertions = [&](Operation *op) {
- for (Region &r : op->getRegions())
- for (Block &b : r.getBlocks())
- listener->notifyBlockInserted(&b, /*previous=*/nullptr,
- /*previousIt=*/{});
- };
// The `insert` call above notifies about op insertion, but not about block
// insertion.
- notifyBlockInsertions(newOp);
+ notifyBlockInsertions(newOp, listener);
auto walkFn = [&](Operation *walkedOp) {
listener->notifyOperationInserted(walkedOp, /*previous=*/{});
- notifyBlockInsertions(walkedOp);
+ notifyBlockInsertions(walkedOp, listener);
};
for (Region &region : newOp->getRegions())
region.walk<WalkOrder::PreOrder>(walkFn);
@@ -556,3 +558,33 @@ Operation *OpBuilder::clone(Operation &op) {
IRMapping mapper;
return clone(op, mapper);
}
+
+void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
+ Region::iterator before, IRMapping &mapping) {
+ region.cloneInto(&parent, before, mapping);
+
+ // Fast path: If no listener is attached, there is no more work to do.
+ if (!listener)
+ return;
+
+ // Notify about op/block insertion.
+ for (auto it = mapping.lookup(&region.front())->getIterator(); it != before;
+ ++it) {
+ listener->notifyBlockInserted(&*it, /*previous=*/nullptr,
+ /*previousIt=*/{});
+ it->walk<WalkOrder::PreOrder>([&](Operation *walkedOp) {
+ listener->notifyOperationInserted(walkedOp, /*previous=*/{});
+ notifyBlockInsertions(walkedOp, listener);
+ });
+ }
+}
+
+void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
+ Region::iterator before) {
+ IRMapping mapping;
+ cloneRegionBefore(region, parent, before, mapping);
+}
+
+void OpBuilder::cloneRegionBefore(Region &region, Block *before) {
+ cloneRegionBefore(region, *before->getParent(), before->getIterator());
+}
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 22b5ad749f0c6..9204733c99bab 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -384,24 +384,6 @@ void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
inlineRegionBefore(region, *before->getParent(), before->getIterator());
}
-/// Clone the blocks that belong to "region" before the given position in
-/// another region "parent". The two regions must be different. The caller is
-/// responsible for creating or updating the operation transferring flow of
-/// control to the region and passing it the correct block arguments.
-void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
- Region::iterator before,
- IRMapping &mapping) {
- region.cloneInto(&parent, before, mapping);
-}
-void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
- Region::iterator before) {
- IRMapping mapping;
- cloneRegionBefore(region, parent, before, mapping);
-}
-void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
- cloneRegionBefore(region, *before->getParent(), before->getIterator());
-}
-
void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
moveBlockBefore(block, anotherBlock->getParent(),
anotherBlock->getIterator());
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3928b98568bf3..346135fb44722 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1573,23 +1573,6 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
eraseBlock(source);
}
-void ConversionPatternRewriter::cloneRegionBefore(Region &region,
- Region &parent,
- Region::iterator before,
- IRMapping &mapping) {
- if (region.empty())
- return;
-
- PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
-
- for (Block &b : ForwardDominanceIterator<>::makeIterable(region)) {
- Block *cloned = mapping.lookup(&b);
- impl->notifyInsertedBlock(cloned, /*previous=*/nullptr, /*previousIt=*/{});
- cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
- [&](Operation *op) { notifyOperationInserted(op, /*previous=*/{}); });
- }
-}
-
void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
InsertPoint previous) {
assert(!previous.isSet() && "expected newly created op");
diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir
index ecb17d5f1b67d..74f312e8144a0 100644
--- a/mlir/test/Transforms/test-legalizer-full.mlir
+++ b/mlir/test/Transforms/test-legalizer-full.mlir
@@ -110,9 +110,11 @@ builtin.module {
// expected-error@+1 {{failed to legalize operation 'test.region'}}
"test.region"() ({
^bb1(%i0: i64):
- cf.br ^bb2(%i0 : i64)
+ cf.br ^bb3(%i0 : i64)
^bb2(%i1: i64):
"test.invalid"(%i1) : (i64) -> ()
+ ^bb3(%i2: i64):
+ cf.br ^bb2(%i2 : i64)
}) {legalizer.should_clone, legalizer.erase_old_blocks} : () -> ()
"test.return"() : () -> ()
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 5d889979a1f92..559561b34ceec 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -323,3 +323,34 @@ func.func @clone_op() {
}) : () -> ()
return
}
+
+
+// -----
+
+// CHECK-AN: notifyBlockInserted into func.func: was unlinked
+// CHECK-AN: notifyOperationInserted: test.op_1, was unlinked
+// CHECK-AN: notifyBlockInserted into func.func: was unlinked
+// CHECK-AN: notifyOperationInserted: test.op_2, was unlinked
+// CHECK-AN: notifyBlockInserted into test.op_2: was unlinked
+// CHECK-AN: notifyOperationInserted: test.op_3, was unlinked
+// CHECK-AN: notifyOperationInserted: test.op_4, was unlinked
+// CHECK-AN-LABEL: func @test_clone_region_before(
+// CHECK-AN: "test.op_1"() : () -> ()
+// CHECK-AN: ^{{.*}}:
+// CHECK-AN: "test.op_2"() ({
+// CHECK-AN: "test.op_3"() : () -> ()
+// CHECK-AN: }) : () -> ()
+// CHECK-AN: "test.op_4"() : () -> ()
+// CHECK-AN: ^{{.*}}:
+// CHECK-AN: "test.clone_region_before"() ({
+func.func @test_clone_region_before() {
+ "test.clone_region_before"() ({
+ "test.op_1"() : () -> ()
+ ^bb0:
+ "test.op_2"() ({
+ "test.op_3"() : () -> ()
+ }) : () -> ()
+ "test.op_4"() : () -> ()
+ }) : () -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index e3978d3789cf0..d7e5d6db50c1f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -267,6 +267,24 @@ struct CloneOp : public RewritePattern {
}
};
+/// This pattern clones regions of "test.clone_region_before" ops before the
+/// parent block.
+struct CloneRegionBeforeOp : public RewritePattern {
+ CloneRegionBeforeOp(MLIRContext *context)
+ : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // Do not clone already cloned ops to avoid going into an infinite loop.
+ if (op->hasAttr("was_cloned"))
+ return failure();
+ for (Region &r : op->getRegions())
+ rewriter.cloneRegionBefore(r, op->getBlock());
+ op->setAttr("was_cloned", rewriter.getUnitAttr());
+ return success();
+ }
+};
+
struct TestPatternDriver
: public PassWrapper<TestPatternDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -358,6 +376,7 @@ struct TestStrictPatternDriver
// clang-format off
ChangeBlockOp,
CloneOp,
+ CloneRegionBeforeOp,
EraseOp,
ImplicitChangeOp,
InlineBlocksIntoParent,
@@ -374,7 +393,8 @@ struct TestStrictPatternDriver
opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
opName == "test.move_before_parent_op" ||
opName == "test.inline_blocks_into_parent" ||
- opName == "test.split_block_here" || opName == "test.clone_me") {
+ opName == "test.split_block_here" || opName == "test.clone_me" ||
+ opName == "test.clone_region_before") {
ops.push_back(op);
}
});
More information about the Mlir-commits
mailing list