[Mlir-commits] [mlir] b840d29 - [mlir][IR] Send notifications for `cloneRegionBefore` (#66871)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 2 01:06:14 PST 2024
Author: Matthias Springer
Date: 2024-02-02T10:06:10+01:00
New Revision: b840d2968391dd610b792a65133a1edc1bcc397c
URL: https://github.com/llvm/llvm-project/commit/b840d2968391dd610b792a65133a1edc1bcc397c
DIFF: https://github.com/llvm/llvm-project/commit/b840d2968391dd610b792a65133a1edc1bcc397c.diff
LOG: [mlir][IR] Send notifications for `cloneRegionBefore` (#66871)
Similar to `OpBuilder::clone`, operation/block insertion notifications
should be sent when cloning the contents of a region. E.g., this is to
ensure that the newly created operations are put on the worklist of the
greedy pattern rewriter driver.
Also move `cloneRegionBefore` from `RewriterBase` to `OpBuilder`. It
only creates new IR, so it should be part of the builder API (like
`clone(Operation &)`). The function does not have to be virtual. Now
that notifications are properly sent, the override in the dialect
conversion is no longer needed.
Added:
Modified:
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/IR/Builders.cpp
mlir/lib/IR/PatternMatch.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/Transforms/test-legalizer-full.mlir
mlir/test/Transforms/test-strict-pattern-driver.mlir
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 4fc29c65f2e68..2fe1495b2b593 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -583,6 +583,16 @@ class OpBuilder : public Builder {
return cast<OpT>(cloneWithoutRegions(*op.getOperation()));
}
+ /// Clone the blocks that belong to "region" before the given position in
+ /// another region "parent". The two regions must be different. The caller is
+ /// responsible for creating or updating the operation transferring flow of
+ /// control to the region and passing it the correct block arguments.
+ void cloneRegionBefore(Region &region, Region &parent,
+ Region::iterator before, IRMapping &mapping);
+ void cloneRegionBefore(Region &region, Region &parent,
+ Region::iterator before);
+ void cloneRegionBefore(Region &region, Block *before);
+
protected:
/// The optional listener for events of this builder.
Listener *listener;
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 72cbc85e9081d..61da27825e870 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -500,16 +500,6 @@ class RewriterBase : public OpBuilder {
Region::iterator before);
void inlineRegionBefore(Region &region, Block *before);
- /// Clone the blocks that belong to "region" before the given position in
- /// another region "parent". The two regions must be different. The caller is
- /// responsible for creating or updating the operation transferring flow of
- /// control to the region and passing it the correct block arguments.
- virtual void cloneRegionBefore(Region &region, Region &parent,
- Region::iterator before, IRMapping &mapping);
- void cloneRegionBefore(Region &region, Region &parent,
- Region::iterator before);
- void cloneRegionBefore(Region &region, Block *before);
-
/// This method replaces the uses of the results of `op` with the values in
/// `newValues` when the provided `functor` returns true for a specific use.
/// The number of values in `newValues` is required to match the number of
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index d9470de9ceb9f..51e3e413b516f 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -724,14 +724,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
ValueRange argValues = std::nullopt) override;
using PatternRewriter::inlineBlockBefore;
- /// PatternRewriter hook for cloning blocks of one region into another. The
- /// given region to clone *must* not have been modified as part of conversion
- /// yet, i.e. it must be within an operation that is either in the process of
- /// conversion, or has not yet been converted.
- void cloneRegionBefore(Region &region, Region &parent,
- Region::iterator before, IRMapping &mapping) override;
- using PatternRewriter::cloneRegionBefore;
-
/// PatternRewriter hook for inserting a new operation.
void notifyOperationInserted(Operation *op, InsertPoint previous) override;
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index e7725a1d9fd2a..2e42c4e870716 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -522,6 +522,16 @@ LogicalResult OpBuilder::tryFold(Operation *op,
return success();
}
+/// Helper function that sends block insertion notifications for every block
+/// that is directly nested in the given op.
+static void notifyBlockInsertions(Operation *op,
+ OpBuilder::Listener *listener) {
+ for (Region &r : op->getRegions())
+ for (Block &b : r.getBlocks())
+ listener->notifyBlockInserted(&b, /*previous=*/nullptr,
+ /*previousIt=*/{});
+}
+
Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
Operation *newOp = op.clone(mapper);
newOp = insert(newOp);
@@ -530,20 +540,12 @@ Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
// itself. But if `newOp` has any regions, we need to notify the listener
// about any ops that got inserted inside those regions as part of cloning.
if (listener) {
- // Helper function that sends block insertion notifications for every block
- // within the given op.
- auto notifyBlockInsertions = [&](Operation *op) {
- for (Region &r : op->getRegions())
- for (Block &b : r.getBlocks())
- listener->notifyBlockInserted(&b, /*previous=*/nullptr,
- /*previousIt=*/{});
- };
// The `insert` call above notifies about op insertion, but not about block
// insertion.
- notifyBlockInsertions(newOp);
+ notifyBlockInsertions(newOp, listener);
auto walkFn = [&](Operation *walkedOp) {
listener->notifyOperationInserted(walkedOp, /*previous=*/{});
- notifyBlockInsertions(walkedOp);
+ notifyBlockInsertions(walkedOp, listener);
};
for (Region &region : newOp->getRegions())
region.walk<WalkOrder::PreOrder>(walkFn);
@@ -556,3 +558,39 @@ Operation *OpBuilder::clone(Operation &op) {
IRMapping mapper;
return clone(op, mapper);
}
+
+void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
+ Region::iterator before, IRMapping &mapping) {
+ region.cloneInto(&parent, before, mapping);
+
+ // Fast path: If no listener is attached, there is no more work to do.
+ if (!listener)
+ return;
+
+ // An empty region clones no blocks, so there is nothing to notify about.
+ // (`region.front()` below must not be called on an empty region.)
+ if (region.empty())
+ return;
+
+ // Notify about op/block insertion. The cloned blocks form a contiguous
+ // range that starts at the clone of the entry block and ends at `before`.
+ for (auto it = mapping.lookup(&region.front())->getIterator(); it != before;
+ ++it) {
+ listener->notifyBlockInserted(&*it, /*previous=*/nullptr,
+ /*previousIt=*/{});
+ it->walk<WalkOrder::PreOrder>([&](Operation *walkedOp) {
+ listener->notifyOperationInserted(walkedOp, /*previous=*/{});
+ notifyBlockInsertions(walkedOp, listener);
+ });
+ }
+}
+
+void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
+ Region::iterator before) {
+ IRMapping mapping;
+ cloneRegionBefore(region, parent, before, mapping);
+}
+
+void OpBuilder::cloneRegionBefore(Region &region, Block *before) {
+ cloneRegionBefore(region, *before->getParent(), before->getIterator());
+}
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 22b5ad749f0c6..9204733c99bab 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -384,24 +384,6 @@ void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
inlineRegionBefore(region, *before->getParent(), before->getIterator());
}
-/// Clone the blocks that belong to "region" before the given position in
-/// another region "parent". The two regions must be different. The caller is
-/// responsible for creating or updating the operation transferring flow of
-/// control to the region and passing it the correct block arguments.
-void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
- Region::iterator before,
- IRMapping &mapping) {
- region.cloneInto(&parent, before, mapping);
-}
-void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
- Region::iterator before) {
- IRMapping mapping;
- cloneRegionBefore(region, parent, before, mapping);
-}
-void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
- cloneRegionBefore(region, *before->getParent(), before->getIterator());
-}
-
void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
moveBlockBefore(block, anotherBlock->getParent(),
anotherBlock->getIterator());
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3928b98568bf3..346135fb44722 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1573,23 +1573,6 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
eraseBlock(source);
}
-void ConversionPatternRewriter::cloneRegionBefore(Region &region,
- Region &parent,
- Region::iterator before,
- IRMapping &mapping) {
- if (region.empty())
- return;
-
- PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
-
- for (Block &b : ForwardDominanceIterator<>::makeIterable(region)) {
- Block *cloned = mapping.lookup(&b);
- impl->notifyInsertedBlock(cloned, /*previous=*/nullptr, /*previousIt=*/{});
- cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
- [&](Operation *op) { notifyOperationInserted(op, /*previous=*/{}); });
- }
-}
-
void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
InsertPoint previous) {
assert(!previous.isSet() && "expected newly created op");
diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir
index ecb17d5f1b67d..74f312e8144a0 100644
--- a/mlir/test/Transforms/test-legalizer-full.mlir
+++ b/mlir/test/Transforms/test-legalizer-full.mlir
@@ -110,9 +110,11 @@ builtin.module {
// expected-error@+1 {{failed to legalize operation 'test.region'}}
"test.region"() ({
^bb1(%i0: i64):
- cf.br ^bb2(%i0 : i64)
+ cf.br ^bb3(%i0 : i64)
^bb2(%i1: i64):
"test.invalid"(%i1) : (i64) -> ()
+ ^bb3(%i2: i64):
+ cf.br ^bb2(%i2 : i64)
}) {legalizer.should_clone, legalizer.erase_old_blocks} : () -> ()
"test.return"() : () -> ()
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 5d889979a1f92..559561b34ceec 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -323,3 +323,34 @@ func.func @clone_op() {
}) : () -> ()
return
}
+
+// -----
+
+// cloneRegionBefore must notify about every cloned block and (nested) op.
+// CHECK-AN: notifyBlockInserted into func.func: was unlinked
+// CHECK-AN: notifyOperationInserted: test.op_1, was unlinked
+// CHECK-AN: notifyBlockInserted into func.func: was unlinked
+// CHECK-AN: notifyOperationInserted: test.op_2, was unlinked
+// CHECK-AN: notifyBlockInserted into test.op_2: was unlinked
+// CHECK-AN: notifyOperationInserted: test.op_3, was unlinked
+// CHECK-AN: notifyOperationInserted: test.op_4, was unlinked
+// CHECK-AN-LABEL: func @test_clone_region_before(
+// CHECK-AN: "test.op_1"() : () -> ()
+// CHECK-AN: ^{{.*}}:
+// CHECK-AN: "test.op_2"() ({
+// CHECK-AN: "test.op_3"() : () -> ()
+// CHECK-AN: }) : () -> ()
+// CHECK-AN: "test.op_4"() : () -> ()
+// CHECK-AN: ^{{.*}}:
+// CHECK-AN: "test.clone_region_before"() ({
+func.func @test_clone_region_before() {
+ "test.clone_region_before"() ({
+ "test.op_1"() : () -> ()
+ ^bb0:
+ "test.op_2"() ({
+ "test.op_3"() : () -> ()
+ }) : () -> ()
+ "test.op_4"() : () -> ()
+ }) : () -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index e3978d3789cf0..d7e5d6db50c1f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -267,6 +267,27 @@ struct CloneOp : public RewritePattern {
}
};
+/// This pattern clones regions of "test.clone_region_before" ops before the
+/// parent block.
+struct CloneRegionBeforeOp : public RewritePattern {
+ CloneRegionBeforeOp(MLIRContext *context)
+ : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // Do not clone already cloned ops to avoid going into an infinite loop.
+ if (op->hasAttr("was_cloned"))
+ return failure();
+ for (Region &r : op->getRegions())
+ rewriter.cloneRegionBefore(r, op->getBlock());
+ // In-place op modifications must go through the rewriter so that listeners
+ // (e.g., the greedy pattern rewrite driver) are notified.
+ rewriter.modifyOpInPlace(
+ op, [&]() { op->setAttr("was_cloned", rewriter.getUnitAttr()); });
+ return success();
+ }
+};
+
struct TestPatternDriver
: public PassWrapper<TestPatternDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -358,6 +379,7 @@ struct TestStrictPatternDriver
// clang-format off
ChangeBlockOp,
CloneOp,
+ CloneRegionBeforeOp,
EraseOp,
ImplicitChangeOp,
InlineBlocksIntoParent,
@@ -374,7 +396,8 @@ struct TestStrictPatternDriver
opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
opName == "test.move_before_parent_op" ||
opName == "test.inline_blocks_into_parent" ||
- opName == "test.split_block_here" || opName == "test.clone_me") {
+ opName == "test.split_block_here" || opName == "test.clone_me" ||
+ opName == "test.clone_region_before") {
ops.push_back(op);
}
});
More information about the Mlir-commits mailing list