[Mlir-commits] [mlir] b840d29 - [mlir][IR] Send notifications for `cloneRegionBefore` (#66871)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 2 01:06:14 PST 2024


Author: Matthias Springer
Date: 2024-02-02T10:06:10+01:00
New Revision: b840d2968391dd610b792a65133a1edc1bcc397c

URL: https://github.com/llvm/llvm-project/commit/b840d2968391dd610b792a65133a1edc1bcc397c
DIFF: https://github.com/llvm/llvm-project/commit/b840d2968391dd610b792a65133a1edc1bcc397c.diff

LOG: [mlir][IR] Send notifications for `cloneRegionBefore` (#66871)

Similar to `OpBuilder::clone`, operation/block insertion notifications
should be sent when cloning the contents of a region. E.g., this is to
ensure that the newly created operations are put on the worklist of the
greedy pattern rewriter driver.

Also move `cloneRegionBefore` from `RewriterBase` to `OpBuilder`. It
only creates new IR, so it should be part of the builder API (like
`clone(Operation &)`). The function does not have to be virtual. Now
that notifications are properly sent, the override in the dialect
conversion is no longer needed.

Added: 
    

Modified: 
    mlir/include/mlir/IR/Builders.h
    mlir/include/mlir/IR/PatternMatch.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/IR/Builders.cpp
    mlir/lib/IR/PatternMatch.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/Transforms/test-legalizer-full.mlir
    mlir/test/Transforms/test-strict-pattern-driver.mlir
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 4fc29c65f2e68..2fe1495b2b593 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -583,6 +583,16 @@ class OpBuilder : public Builder {
     return cast<OpT>(cloneWithoutRegions(*op.getOperation()));
   }
 
+  /// Clone the blocks that belong to "region" before the given position in
+  /// another region "parent". The two regions must be different. The caller is
+  /// responsible for creating or updating the operation transferring flow of
+  /// control to the region and passing it the correct block arguments.
+  void cloneRegionBefore(Region &region, Region &parent,
+                         Region::iterator before, IRMapping &mapping);
+  void cloneRegionBefore(Region &region, Region &parent,
+                         Region::iterator before);
+  void cloneRegionBefore(Region &region, Block *before);
+
 protected:
   /// The optional listener for events of this builder.
   Listener *listener;

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 72cbc85e9081d..61da27825e870 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -500,16 +500,6 @@ class RewriterBase : public OpBuilder {
                           Region::iterator before);
   void inlineRegionBefore(Region &region, Block *before);
 
-  /// Clone the blocks that belong to "region" before the given position in
-  /// another region "parent". The two regions must be different. The caller is
-  /// responsible for creating or updating the operation transferring flow of
-  /// control to the region and passing it the correct block arguments.
-  virtual void cloneRegionBefore(Region &region, Region &parent,
-                                 Region::iterator before, IRMapping &mapping);
-  void cloneRegionBefore(Region &region, Region &parent,
-                         Region::iterator before);
-  void cloneRegionBefore(Region &region, Block *before);
-
   /// This method replaces the uses of the results of `op` with the values in
   /// `newValues` when the provided `functor` returns true for a specific use.
   /// The number of values in `newValues` is required to match the number of

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index d9470de9ceb9f..51e3e413b516f 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -724,14 +724,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
                          ValueRange argValues = std::nullopt) override;
   using PatternRewriter::inlineBlockBefore;
 
-  /// PatternRewriter hook for cloning blocks of one region into another. The
-  /// given region to clone *must* not have been modified as part of conversion
-  /// yet, i.e. it must be within an operation that is either in the process of
-  /// conversion, or has not yet been converted.
-  void cloneRegionBefore(Region &region, Region &parent,
-                         Region::iterator before, IRMapping &mapping) override;
-  using PatternRewriter::cloneRegionBefore;
-
   /// PatternRewriter hook for inserting a new operation.
   void notifyOperationInserted(Operation *op, InsertPoint previous) override;
 

diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index e7725a1d9fd2a..2e42c4e870716 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -522,6 +522,16 @@ LogicalResult OpBuilder::tryFold(Operation *op,
   return success();
 }
 
+/// Helper function that sends block insertion notifications for every block
+/// that is directly nested in the given op.
+static void notifyBlockInsertions(Operation *op,
+                                  OpBuilder::Listener *listener) {
+  for (Region &r : op->getRegions())
+    for (Block &b : r.getBlocks())
+      listener->notifyBlockInserted(&b, /*previous=*/nullptr,
+                                    /*previousIt=*/{});
+}
+
 Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
   Operation *newOp = op.clone(mapper);
   newOp = insert(newOp);
@@ -530,20 +540,12 @@ Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
   // itself. But if `newOp` has any regions, we need to notify the listener
   // about any ops that got inserted inside those regions as part of cloning.
   if (listener) {
-    // Helper function that sends block insertion notifications for every block
-    // within the given op.
-    auto notifyBlockInsertions = [&](Operation *op) {
-      for (Region &r : op->getRegions())
-        for (Block &b : r.getBlocks())
-          listener->notifyBlockInserted(&b, /*previous=*/nullptr,
-                                        /*previousIt=*/{});
-    };
     // The `insert` call above notifies about op insertion, but not about block
     // insertion.
-    notifyBlockInsertions(newOp);
+    notifyBlockInsertions(newOp, listener);
     auto walkFn = [&](Operation *walkedOp) {
       listener->notifyOperationInserted(walkedOp, /*previous=*/{});
-      notifyBlockInsertions(walkedOp);
+      notifyBlockInsertions(walkedOp, listener);
     };
     for (Region &region : newOp->getRegions())
       region.walk<WalkOrder::PreOrder>(walkFn);
@@ -556,3 +558,33 @@ Operation *OpBuilder::clone(Operation &op) {
   IRMapping mapper;
   return clone(op, mapper);
 }
+
+void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
+                                  Region::iterator before, IRMapping &mapping) {
+  region.cloneInto(&parent, before, mapping);
+
+  // Fast path: If no listener is attached, there is no more work to do.
+  if (!listener)
+    return;
+
+  // Notify about op/block insertion.
+  for (auto it = mapping.lookup(&region.front())->getIterator(); it != before;
+       ++it) {
+    listener->notifyBlockInserted(&*it, /*previous=*/nullptr,
+                                  /*previousIt=*/{});
+    it->walk<WalkOrder::PreOrder>([&](Operation *walkedOp) {
+      listener->notifyOperationInserted(walkedOp, /*previous=*/{});
+      notifyBlockInsertions(walkedOp, listener);
+    });
+  }
+}
+
+void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
+                                  Region::iterator before) {
+  IRMapping mapping;
+  cloneRegionBefore(region, parent, before, mapping);
+}
+
+void OpBuilder::cloneRegionBefore(Region &region, Block *before) {
+  cloneRegionBefore(region, *before->getParent(), before->getIterator());
+}

diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 22b5ad749f0c6..9204733c99bab 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -384,24 +384,6 @@ void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
   inlineRegionBefore(region, *before->getParent(), before->getIterator());
 }
 
-/// Clone the blocks that belong to "region" before the given position in
-/// another region "parent". The two regions must be different. The caller is
-/// responsible for creating or updating the operation transferring flow of
-/// control to the region and passing it the correct block arguments.
-void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
-                                     Region::iterator before,
-                                     IRMapping &mapping) {
-  region.cloneInto(&parent, before, mapping);
-}
-void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
-                                     Region::iterator before) {
-  IRMapping mapping;
-  cloneRegionBefore(region, parent, before, mapping);
-}
-void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
-  cloneRegionBefore(region, *before->getParent(), before->getIterator());
-}
-
 void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
   moveBlockBefore(block, anotherBlock->getParent(),
                   anotherBlock->getIterator());

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3928b98568bf3..346135fb44722 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1573,23 +1573,6 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
   eraseBlock(source);
 }
 
-void ConversionPatternRewriter::cloneRegionBefore(Region &region,
-                                                  Region &parent,
-                                                  Region::iterator before,
-                                                  IRMapping &mapping) {
-  if (region.empty())
-    return;
-
-  PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
-
-  for (Block &b : ForwardDominanceIterator<>::makeIterable(region)) {
-    Block *cloned = mapping.lookup(&b);
-    impl->notifyInsertedBlock(cloned, /*previous=*/nullptr, /*previousIt=*/{});
-    cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
-        [&](Operation *op) { notifyOperationInserted(op, /*previous=*/{}); });
-  }
-}
-
 void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
                                                         InsertPoint previous) {
   assert(!previous.isSet() && "expected newly created op");

diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir
index ecb17d5f1b67d..74f312e8144a0 100644
--- a/mlir/test/Transforms/test-legalizer-full.mlir
+++ b/mlir/test/Transforms/test-legalizer-full.mlir
@@ -110,9 +110,11 @@ builtin.module {
     // expected-error@+1 {{failed to legalize operation 'test.region'}}
     "test.region"() ({
       ^bb1(%i0: i64):
-        cf.br ^bb2(%i0 : i64)
+        cf.br ^bb3(%i0 : i64)
       ^bb2(%i1: i64):
         "test.invalid"(%i1) : (i64) -> ()
+      ^bb3(%i2: i64):
+        cf.br ^bb2(%i2 : i64)
     }) {legalizer.should_clone, legalizer.erase_old_blocks} : () -> ()
 
     "test.return"() : () -> ()

diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 5d889979a1f92..559561b34ceec 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -323,3 +323,34 @@ func.func @clone_op() {
   }) : () -> ()
   return
 }
+
+
+// -----
+
+// CHECK-AN: notifyBlockInserted into func.func: was unlinked
+// CHECK-AN: notifyOperationInserted: test.op_1, was unlinked
+// CHECK-AN: notifyBlockInserted into func.func: was unlinked
+// CHECK-AN: notifyOperationInserted: test.op_2, was unlinked
+// CHECK-AN: notifyBlockInserted into test.op_2: was unlinked
+// CHECK-AN: notifyOperationInserted: test.op_3, was unlinked
+// CHECK-AN: notifyOperationInserted: test.op_4, was unlinked
+// CHECK-AN-LABEL: func @test_clone_region_before(
+// CHECK-AN:         "test.op_1"() : () -> ()
+// CHECK-AN:       ^{{.*}}:
+// CHECK-AN:         "test.op_2"() ({
+// CHECK-AN:           "test.op_3"() : () -> ()
+// CHECK-AN:         }) : () -> ()
+// CHECK-AN:         "test.op_4"() : () -> ()
+// CHECK-AN:       ^{{.*}}:
+// CHECK-AN:         "test.clone_region_before"() ({
+func.func @test_clone_region_before() {
+  "test.clone_region_before"() ({
+    "test.op_1"() : () -> ()
+  ^bb0:
+    "test.op_2"() ({
+      "test.op_3"() : () -> ()
+    }) : () -> ()
+    "test.op_4"() : () -> ()
+  }) : () -> ()
+  return
+}

diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index e3978d3789cf0..d7e5d6db50c1f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -267,6 +267,24 @@ struct CloneOp : public RewritePattern {
   }
 };
 
+/// This pattern clones regions of "test.clone_region_before" ops before the
+/// parent block.
+struct CloneRegionBeforeOp : public RewritePattern {
+  CloneRegionBeforeOp(MLIRContext *context)
+      : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Do not clone already cloned ops to avoid going into an infinite loop.
+    if (op->hasAttr("was_cloned"))
+      return failure();
+    for (Region &r : op->getRegions())
+      rewriter.cloneRegionBefore(r, op->getBlock());
+    op->setAttr("was_cloned", rewriter.getUnitAttr());
+    return success();
+  }
+};
+
 struct TestPatternDriver
     : public PassWrapper<TestPatternDriver, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -358,6 +376,7 @@ struct TestStrictPatternDriver
         // clang-format off
         ChangeBlockOp,
         CloneOp,
+        CloneRegionBeforeOp,
         EraseOp,
         ImplicitChangeOp,
         InlineBlocksIntoParent,
@@ -374,7 +393,8 @@ struct TestStrictPatternDriver
           opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
           opName == "test.move_before_parent_op" ||
           opName == "test.inline_blocks_into_parent" ||
-          opName == "test.split_block_here" || opName == "test.clone_me") {
+          opName == "test.split_block_here" || opName == "test.clone_me" ||
+          opName == "test.clone_region_before") {
         ops.push_back(op);
       }
     });


        


More information about the Mlir-commits mailing list