[Mlir-commits] [mlir] [mlir][IR] Trigger nested operation/block insertion notifications for clones (PR #66871)

Matthias Springer llvmlistbot at llvm.org
Thu Feb 1 03:27:56 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/66871

>From c149370a584fb1ecffaa4997c57e0f7c595cedcc Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 1 Feb 2024 11:26:32 +0000
Subject: [PATCH 1/2] [mlir][IR] Notify about block insertion when cloning an
 op

`OpBuilder::clone(Operation &)` should trigger not only `notifyOperationInserted` but also `notifyBlockInserted` (for all blocks contained in `op`).
---
 mlir/lib/IR/Builders.cpp                      | 14 ++++++
 .../test-strict-pattern-driver.mlir           | 43 ++++++++++++++++---
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   | 29 ++++++++++++-
 3 files changed, 78 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 7acef1073c6de..589d41de9b8bc 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -525,16 +525,30 @@ LogicalResult OpBuilder::tryFold(Operation *op,
 Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
   Operation *newOp = op.clone(mapper);
   newOp = insert(newOp);
+
   // The `insert` call above handles the notification for inserting `newOp`
   // itself. But if `newOp` has any regions, we need to notify the listener
   // about any ops that got inserted inside those regions as part of cloning.
   if (listener) {
+    // Helper function that sends block insertion notifications for every block
+    // within the given op.
+    auto notifyBlockInsertions = [&](Operation *op) {
+      for (Region &r : op->getRegions())
+        for (Block &b : r.getBlocks())
+          listener->notifyBlockInserted(&b, /*previous=*/nullptr,
+                                        /*previousIt=*/{});
+    };
+    // The `insert` call above notifies about op insertion, but not about block
+    // insertion.
+    notifyBlockInsertions(newOp);
     auto walkFn = [&](Operation *walkedOp) {
       listener->notifyOperationInserted(walkedOp, /*previous=*/{});
+      notifyBlockInsertions(walkedOp);
     };
     for (Region &region : newOp->getRegions())
       region.walk<WalkOrder::PreOrder>(walkFn);
   }
+
   return newOp;
 }
 
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 6d7ccf161c35d..5d889979a1f92 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -272,19 +272,20 @@ func.func @test_inline_block_before() {
 
 // -----
 
+// CHECK-AN: notifyBlockInserted into test.op_with_region: was unlinked
 // CHECK-AN: notifyOperationInserted: test.op_3, was last in block
 // CHECK-AN: notifyOperationInserted: test.op_2, was last in block
 // CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
 // CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
 // CHECK-AN: notifyOperationRemoved: test.split_block_here
 // CHECK-AN-LABEL: func @test_split_block(
-//          CHECK:   "test.op_with_region"() ({
-//          CHECK:     test.op_1
-//          CHECK:   ^{{.*}}:
-//          CHECK:     test.new_op
-//          CHECK:     test.op_2
-//          CHECK:     test.op_3
-//          CHECK:   }) : () -> ()
+//       CHECK-AN:   "test.op_with_region"() ({
+//       CHECK-AN:     test.op_1
+//       CHECK-AN:   ^{{.*}}:
+//       CHECK-AN:     test.new_op
+//       CHECK-AN:     test.op_2
+//       CHECK-AN:     test.op_3
+//       CHECK-AN:   }) : () -> ()
 func.func @test_split_block() {
   "test.op_with_region"() ({
     "test.op_1"() : () -> ()
@@ -294,3 +295,31 @@ func.func @test_split_block() {
   }) : () -> ()
   return
 }
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.clone_me, was unlinked
+// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
+// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
+// CHECK-AN: notifyOperationInserted: test.foo, was unlinked
+// CHECK-AN: notifyOperationInserted: test.bar, was unlinked
+// CHECK-AN-LABEL: func @clone_op(
+// CHECK-AN:         "test.clone_me"() ({
+// CHECK-AN:           "test.foo"() : () -> ()
+// CHECK-AN:         ^bb1:  // no predecessors
+// CHECK-AN:           "test.bar"() : () -> ()
+// CHECK-AN:         }) {was_cloned} : () -> ()
+// CHECK-AN:         "test.clone_me"() ({
+// CHECK-AN:           "test.foo"() : () -> ()
+// CHECK-AN:         ^bb1:  // no predecessors
+// CHECK-AN:           "test.bar"() : () -> ()
+// CHECK-AN:         }) : () -> ()
+func.func @clone_op() {
+  "test.clone_me"() ({
+  ^bb0:
+    "test.foo"() : () -> ()
+  ^bb1:
+    "test.bar"() : () -> ()
+  }) : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 307ae58ba74c5..e3978d3789cf0 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -251,6 +251,22 @@ struct SplitBlockHere : public RewritePattern {
   }
 };
 
+/// Clones each "test.clone_me" op; the clone is tagged with "was_cloned".
+struct CloneOp : public RewritePattern {
+  CloneOp(MLIRContext *context)
+      : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Skip clones (tagged below); re-cloning them would never terminate.
+    if (op->hasAttr("was_cloned"))
+      return failure();
+    Operation *cloned = rewriter.clone(*op);
+    cloned->setAttr("was_cloned", rewriter.getUnitAttr());
+    return success();
+  }
+};
+
 struct TestPatternDriver
     : public PassWrapper<TestPatternDriver, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -291,6 +307,16 @@ struct TestPatternDriver
 };
 
 struct DumpNotifications : public RewriterBase::Listener {
+  void notifyBlockInserted(Block *block, Region *previous, // Log to stdout.
+                           Region::iterator previousIt) override {
+    llvm::outs() << "notifyBlockInserted into "
+                 << block->getParentOp()->getName() << ": ";
+    if (previous == nullptr) { // Block had no previous parent region.
+      llvm::outs() << "was unlinked\n";
+    } else {
+      llvm::outs() << "was linked\n";
+    }
+  }
   void notifyOperationInserted(Operation *op,
                                OpBuilder::InsertPoint previous) override {
     llvm::outs() << "notifyOperationInserted: " << op->getName();
@@ -331,6 +357,7 @@ struct TestStrictPatternDriver
     patterns.add<
         // clang-format off
         ChangeBlockOp,
+        CloneOp,
         EraseOp,
         ImplicitChangeOp,
         InlineBlocksIntoParent,
@@ -347,7 +374,7 @@ struct TestStrictPatternDriver
           opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
           opName == "test.move_before_parent_op" ||
           opName == "test.inline_blocks_into_parent" ||
-          opName == "test.split_block_here") {
+          opName == "test.split_block_here" || opName == "test.clone_me") {
         ops.push_back(op);
       }
     });

>From 3051c7d199c19204502572f7fa51e7994ee63027 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 1 Feb 2024 11:26:53 +0000
Subject: [PATCH 2/2] [mlir][IR] Send notifications for `cloneRegionBefore`

Similar to `OpBuilder::clone`, operation/block insertion notifications should be sent when cloning the contents of a region. For example, this ensures that the newly created operations are put on the worklist of the greedy pattern rewrite driver.

Also move `cloneRegionBefore` from `RewriterBase` to `OpBuilder`. It only creates new IR, so it belongs in the builder API (like `clone(Operation &)`), and it no longer needs to be virtual. Now that notifications are properly sent, the `ConversionPatternRewriter::cloneRegionBefore` override in the dialect conversion is no longer needed.

BEGIN_PUBLIC
No public commit message for presubmit.
END_PUBLIC
---
 mlir/include/mlir/IR/Builders.h               | 10 ++++
 mlir/include/mlir/IR/PatternMatch.h           | 10 ----
 .../mlir/Transforms/DialectConversion.h       |  8 ---
 mlir/lib/IR/Builders.cpp                      | 52 +++++++++++++++----
 mlir/lib/IR/PatternMatch.cpp                  | 18 -------
 .../Transforms/Utils/DialectConversion.cpp    | 17 ------
 mlir/test/Transforms/test-legalizer-full.mlir |  4 +-
 .../test-strict-pattern-driver.mlir           | 31 +++++++++++
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   | 22 +++++++-
 9 files changed, 107 insertions(+), 65 deletions(-)

diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 8c25a1aa2fad1..c4e165d5a053d 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -583,6 +583,16 @@ class OpBuilder : public Builder {
     return cast<OpT>(cloneWithoutRegions(*op.getOperation()));
   }
 
+  /// Clone the blocks of "region" before `before` in another region "parent"
+  /// (the regions must differ) and notify the listener, if any, about every
+  /// cloned block and op. The caller must create/update the op that transfers
+  /// control flow to the region and passes it the correct block arguments.
+  void cloneRegionBefore(Region &region, Region &parent,
+                         Region::iterator before, IRMapping &mapping);
+  void cloneRegionBefore(Region &region, Region &parent,
+                         Region::iterator before);
+  void cloneRegionBefore(Region &region, Block *before);
+
 protected:
   /// The optional listener for events of this builder.
   Listener *listener;
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 72cbc85e9081d..61da27825e870 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -500,16 +500,6 @@ class RewriterBase : public OpBuilder {
                           Region::iterator before);
   void inlineRegionBefore(Region &region, Block *before);
 
-  /// Clone the blocks that belong to "region" before the given position in
-  /// another region "parent". The two regions must be different. The caller is
-  /// responsible for creating or updating the operation transferring flow of
-  /// control to the region and passing it the correct block arguments.
-  virtual void cloneRegionBefore(Region &region, Region &parent,
-                                 Region::iterator before, IRMapping &mapping);
-  void cloneRegionBefore(Region &region, Region &parent,
-                         Region::iterator before);
-  void cloneRegionBefore(Region &region, Block *before);
-
   /// This method replaces the uses of the results of `op` with the values in
   /// `newValues` when the provided `functor` returns true for a specific use.
   /// The number of values in `newValues` is required to match the number of
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index d9470de9ceb9f..51e3e413b516f 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -724,14 +724,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
                          ValueRange argValues = std::nullopt) override;
   using PatternRewriter::inlineBlockBefore;
 
-  /// PatternRewriter hook for cloning blocks of one region into another. The
-  /// given region to clone *must* not have been modified as part of conversion
-  /// yet, i.e. it must be within an operation that is either in the process of
-  /// conversion, or has not yet been converted.
-  void cloneRegionBefore(Region &region, Region &parent,
-                         Region::iterator before, IRMapping &mapping) override;
-  using PatternRewriter::cloneRegionBefore;
-
   /// PatternRewriter hook for inserting a new operation.
   void notifyOperationInserted(Operation *op, InsertPoint previous) override;
 
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 589d41de9b8bc..a905aae86a426 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -522,6 +522,16 @@ LogicalResult OpBuilder::tryFold(Operation *op,
   return success();
 }
 
+/// Notify the (non-null) `listener` that each block directly nested in `op`'s
+/// regions was inserted with no previous region. Does not recurse into ops.
+static void notifyBlockInsertions(Operation *op,
+                                  OpBuilder::Listener *listener) {
+  for (Region &r : op->getRegions())
+    for (Block &b : r.getBlocks())
+      listener->notifyBlockInserted(&b, /*previous=*/nullptr,
+                                    /*previousIt=*/{});
+}
+
 Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
   Operation *newOp = op.clone(mapper);
   newOp = insert(newOp);
@@ -530,20 +540,12 @@ Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
   // itself. But if `newOp` has any regions, we need to notify the listener
   // about any ops that got inserted inside those regions as part of cloning.
   if (listener) {
-    // Helper function that sends block insertion notifications for every block
-    // within the given op.
-    auto notifyBlockInsertions = [&](Operation *op) {
-      for (Region &r : op->getRegions())
-        for (Block &b : r.getBlocks())
-          listener->notifyBlockInserted(&b, /*previous=*/nullptr,
-                                        /*previousIt=*/{});
-    };
     // The `insert` call above notifies about op insertion, but not about block
     // insertion.
-    notifyBlockInsertions(newOp);
+    notifyBlockInsertions(newOp, listener);
     auto walkFn = [&](Operation *walkedOp) {
       listener->notifyOperationInserted(walkedOp, /*previous=*/{});
-      notifyBlockInsertions(walkedOp);
+      notifyBlockInsertions(walkedOp, listener);
     };
     for (Region &region : newOp->getRegions())
       region.walk<WalkOrder::PreOrder>(walkFn);
@@ -556,3 +558,33 @@ Operation *OpBuilder::clone(Operation &op) {
   IRMapping mapper;
   return clone(op, mapper);
 }
+
+void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
+                                  Region::iterator before, IRMapping &mapping) {
+  region.cloneInto(&parent, before, mapping);
+
+  // Nothing to notify without a listener or when `region` has no blocks (an
+  // empty region has no `front()` to look up in `mapping` below).
+  if (!listener || region.empty())
+    return;
+  // The clones of `region`'s blocks occupy [clone of front(), before).
+  for (auto it = mapping.lookup(&region.front())->getIterator(); it != before;
+       ++it) {
+    listener->notifyBlockInserted(&*it, /*previous=*/nullptr,
+                                  /*previousIt=*/{});
+    it->walk<WalkOrder::PreOrder>([&](Operation *walkedOp) {
+      listener->notifyOperationInserted(walkedOp, /*previous=*/{});
+      notifyBlockInsertions(walkedOp, listener);
+    });
+  }
+}
+
+void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
+                                  Region::iterator before) {
+  IRMapping mapping; // Scratch mapping; not exposed to the caller.
+  cloneRegionBefore(region, parent, before, mapping);
+}
+
+void OpBuilder::cloneRegionBefore(Region &region, Block *before) {
+  cloneRegionBefore(region, *before->getParent(), before->getIterator());
+}
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 22b5ad749f0c6..9204733c99bab 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -384,24 +384,6 @@ void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
   inlineRegionBefore(region, *before->getParent(), before->getIterator());
 }
 
-/// Clone the blocks that belong to "region" before the given position in
-/// another region "parent". The two regions must be different. The caller is
-/// responsible for creating or updating the operation transferring flow of
-/// control to the region and passing it the correct block arguments.
-void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
-                                     Region::iterator before,
-                                     IRMapping &mapping) {
-  region.cloneInto(&parent, before, mapping);
-}
-void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
-                                     Region::iterator before) {
-  IRMapping mapping;
-  cloneRegionBefore(region, parent, before, mapping);
-}
-void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
-  cloneRegionBefore(region, *before->getParent(), before->getIterator());
-}
-
 void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
   moveBlockBefore(block, anotherBlock->getParent(),
                   anotherBlock->getIterator());
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3928b98568bf3..346135fb44722 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1573,23 +1573,6 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
   eraseBlock(source);
 }
 
-void ConversionPatternRewriter::cloneRegionBefore(Region &region,
-                                                  Region &parent,
-                                                  Region::iterator before,
-                                                  IRMapping &mapping) {
-  if (region.empty())
-    return;
-
-  PatternRewriter::cloneRegionBefore(region, parent, before, mapping);
-
-  for (Block &b : ForwardDominanceIterator<>::makeIterable(region)) {
-    Block *cloned = mapping.lookup(&b);
-    impl->notifyInsertedBlock(cloned, /*previous=*/nullptr, /*previousIt=*/{});
-    cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
-        [&](Operation *op) { notifyOperationInserted(op, /*previous=*/{}); });
-  }
-}
-
 void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
                                                         InsertPoint previous) {
   assert(!previous.isSet() && "expected newly created op");
diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir
index ecb17d5f1b67d..74f312e8144a0 100644
--- a/mlir/test/Transforms/test-legalizer-full.mlir
+++ b/mlir/test/Transforms/test-legalizer-full.mlir
@@ -110,9 +110,11 @@ builtin.module {
     // expected-error at +1 {{failed to legalize operation 'test.region'}}
     "test.region"() ({
       ^bb1(%i0: i64):
-        cf.br ^bb2(%i0 : i64)
+        cf.br ^bb3(%i0 : i64)
       ^bb2(%i1: i64):
         "test.invalid"(%i1) : (i64) -> ()
+      ^bb3(%i2: i64):
+        cf.br ^bb2(%i2 : i64)
     }) {legalizer.should_clone, legalizer.erase_old_blocks} : () -> ()
 
     "test.return"() : () -> ()
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 5d889979a1f92..559561b34ceec 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -323,3 +323,34 @@ func.func @clone_op() {
   }) : () -> ()
   return
 }
+
+
+// -----
+
+// CHECK-AN: notifyBlockInserted into func.func: was unlinked
+// CHECK-AN: notifyOperationInserted: test.op_1, was unlinked
+// CHECK-AN: notifyBlockInserted into func.func: was unlinked
+// CHECK-AN: notifyOperationInserted: test.op_2, was unlinked
+// CHECK-AN: notifyBlockInserted into test.op_2: was unlinked
+// CHECK-AN: notifyOperationInserted: test.op_3, was unlinked
+// CHECK-AN: notifyOperationInserted: test.op_4, was unlinked
+// CHECK-AN-LABEL: func @test_clone_region_before(
+// CHECK-AN:         "test.op_1"() : () -> ()
+// CHECK-AN:       ^{{.*}}:
+// CHECK-AN:         "test.op_2"() ({
+// CHECK-AN:           "test.op_3"() : () -> ()
+// CHECK-AN:         }) : () -> ()
+// CHECK-AN:         "test.op_4"() : () -> ()
+// CHECK-AN:       ^{{.*}}:
+// CHECK-AN:         "test.clone_region_before"() ({
+func.func @test_clone_region_before() {
+  "test.clone_region_before"() ({
+    "test.op_1"() : () -> ()
+  ^bb0:
+    "test.op_2"() ({
+      "test.op_3"() : () -> ()
+    }) : () -> ()
+    "test.op_4"() : () -> ()
+  }) : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index e3978d3789cf0..d7e5d6db50c1f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -267,6 +267,24 @@ struct CloneOp : public RewritePattern {
   }
 };
 
+/// Clones all regions of a "test.clone_region_before" op in front of the op's
+/// parent block, then tags the op with "was_cloned" so it is processed once.
+struct CloneRegionBeforeOp : public RewritePattern {
+  CloneRegionBeforeOp(MLIRContext *context)
+      : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Ops are tagged once processed; skip them to avoid an infinite loop.
+    if (op->hasAttr("was_cloned"))
+      return failure();
+    for (Region &r : op->getRegions())
+      rewriter.cloneRegionBefore(r, op->getBlock());
+    op->setAttr("was_cloned", rewriter.getUnitAttr());
+    return success();
+  }
+};
+
 struct TestPatternDriver
     : public PassWrapper<TestPatternDriver, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -358,6 +376,7 @@ struct TestStrictPatternDriver
         // clang-format off
         ChangeBlockOp,
         CloneOp,
+        CloneRegionBeforeOp,
         EraseOp,
         ImplicitChangeOp,
         InlineBlocksIntoParent,
@@ -374,7 +393,8 @@ struct TestStrictPatternDriver
           opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
           opName == "test.move_before_parent_op" ||
           opName == "test.inline_blocks_into_parent" ||
-          opName == "test.split_block_here" || opName == "test.clone_me") {
+          opName == "test.split_block_here" || opName == "test.clone_me" ||
+          opName == "test.clone_region_before") {
         ops.push_back(op);
       }
     });



More information about the Mlir-commits mailing list