[Mlir-commits] [mlir] [mlir][IR] Notify about block insertion when cloning an op (PR #80262)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 1 01:43:25 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

`OpBuilder::clone(Operation &)` should trigger not only `notifyOperationInserted` but also `notifyBlockInserted` (for all blocks contained in `op`, including blocks of nested ops).

---
Full diff: https://github.com/llvm/llvm-project/pull/80262.diff


3 Files Affected:

- (modified) mlir/lib/IR/Builders.cpp (+14) 
- (modified) mlir/test/Transforms/test-strict-pattern-driver.mlir (+36-7) 
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+28-1) 


``````````diff
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 7acef1073c6de..589d41de9b8bc 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -525,16 +525,30 @@ LogicalResult OpBuilder::tryFold(Operation *op,
 Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
   Operation *newOp = op.clone(mapper);
   newOp = insert(newOp);
+
   // The `insert` call above handles the notification for inserting `newOp`
   // itself. But if `newOp` has any regions, we need to notify the listener
   // about any ops that got inserted inside those regions as part of cloning.
   if (listener) {
+    // Helper function that sends block insertion notifications for every block
+    // within the given op.
+    auto notifyBlockInsertions = [&](Operation *op) {
+      for (Region &r : op->getRegions())
+        for (Block &b : r.getBlocks())
+          listener->notifyBlockInserted(&b, /*previous=*/nullptr,
+                                        /*previousIt=*/{});
+    };
+    // The `insert` call above notifies about op insertion, but not about block
+    // insertion.
+    notifyBlockInsertions(newOp);
     auto walkFn = [&](Operation *walkedOp) {
       listener->notifyOperationInserted(walkedOp, /*previous=*/{});
+      notifyBlockInsertions(walkedOp);
     };
     for (Region &region : newOp->getRegions())
       region.walk<WalkOrder::PreOrder>(walkFn);
   }
+
   return newOp;
 }
 
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 6d7ccf161c35d..5d889979a1f92 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -272,19 +272,20 @@ func.func @test_inline_block_before() {
 
 // -----
 
+// CHECK-AN: notifyBlockInserted into test.op_with_region: was unlinked
 // CHECK-AN: notifyOperationInserted: test.op_3, was last in block
 // CHECK-AN: notifyOperationInserted: test.op_2, was last in block
 // CHECK-AN: notifyOperationInserted: test.split_block_here, was last in block
 // CHECK-AN: notifyOperationInserted: test.new_op, was unlinked
 // CHECK-AN: notifyOperationRemoved: test.split_block_here
 // CHECK-AN-LABEL: func @test_split_block(
-//          CHECK:   "test.op_with_region"() ({
-//          CHECK:     test.op_1
-//          CHECK:   ^{{.*}}:
-//          CHECK:     test.new_op
-//          CHECK:     test.op_2
-//          CHECK:     test.op_3
-//          CHECK:   }) : () -> ()
+//       CHECK-AN:   "test.op_with_region"() ({
+//       CHECK-AN:     test.op_1
+//       CHECK-AN:   ^{{.*}}:
+//       CHECK-AN:     test.new_op
+//       CHECK-AN:     test.op_2
+//       CHECK-AN:     test.op_3
+//       CHECK-AN:   }) : () -> ()
 func.func @test_split_block() {
   "test.op_with_region"() ({
     "test.op_1"() : () -> ()
@@ -294,3 +295,31 @@ func.func @test_split_block() {
   }) : () -> ()
   return
 }
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.clone_me, was unlinked
+// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
+// CHECK-AN: notifyBlockInserted into test.clone_me: was unlinked
+// CHECK-AN: notifyOperationInserted: test.foo, was unlinked
+// CHECK-AN: notifyOperationInserted: test.bar, was unlinked
+// CHECK-AN-LABEL: func @clone_op(
+// CHECK-AN:         "test.clone_me"() ({
+// CHECK-AN:           "test.foo"() : () -> ()
+// CHECK-AN:         ^bb1:  // no predecessors
+// CHECK-AN:           "test.bar"() : () -> ()
+// CHECK-AN:         }) {was_cloned} : () -> ()
+// CHECK-AN:         "test.clone_me"() ({
+// CHECK-AN:           "test.foo"() : () -> ()
+// CHECK-AN:         ^bb1:  // no predecessors
+// CHECK-AN:           "test.bar"() : () -> ()
+// CHECK-AN:         }) : () -> ()
+func.func @clone_op() {
+  "test.clone_me"() ({
+  ^bb0:
+    "test.foo"() : () -> ()
+  ^bb1:
+    "test.bar"() : () -> ()
+  }) : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 307ae58ba74c5..e3978d3789cf0 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -251,6 +251,22 @@ struct SplitBlockHere : public RewritePattern {
   }
 };
 
+/// This pattern clones "test.clone_me" ops.
+struct CloneOp : public RewritePattern {
+  CloneOp(MLIRContext *context)
+      : RewritePattern("test.clone_me", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Do not clone already cloned ops to avoid going into an infinite loop.
+    if (op->hasAttr("was_cloned"))
+      return failure();
+    Operation *cloned = rewriter.clone(*op);
+    cloned->setAttr("was_cloned", rewriter.getUnitAttr());
+    return success();
+  }
+};
+
 struct TestPatternDriver
     : public PassWrapper<TestPatternDriver, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -291,6 +307,16 @@ struct TestPatternDriver
 };
 
 struct DumpNotifications : public RewriterBase::Listener {
+  void notifyBlockInserted(Block *block, Region *previous,
+                           Region::iterator previousIt) override {
+    llvm::outs() << "notifyBlockInserted into "
+                 << block->getParentOp()->getName() << ": ";
+    if (previous == nullptr) {
+      llvm::outs() << "was unlinked\n";
+    } else {
+      llvm::outs() << "was linked\n";
+    }
+  }
   void notifyOperationInserted(Operation *op,
                                OpBuilder::InsertPoint previous) override {
     llvm::outs() << "notifyOperationInserted: " << op->getName();
@@ -331,6 +357,7 @@ struct TestStrictPatternDriver
     patterns.add<
         // clang-format off
         ChangeBlockOp,
+        CloneOp,
         EraseOp,
         ImplicitChangeOp,
         InlineBlocksIntoParent,
@@ -347,7 +374,7 @@ struct TestStrictPatternDriver
           opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
           opName == "test.move_before_parent_op" ||
           opName == "test.inline_blocks_into_parent" ||
-          opName == "test.split_block_here") {
+          opName == "test.split_block_here" || opName == "test.clone_me") {
         ops.push_back(op);
       }
     });

``````````

</details>


https://github.com/llvm/llvm-project/pull/80262


More information about the Mlir-commits mailing list