[Mlir-commits] [mlir] [mlir][IR] Send missing notifications when inlining a block (PR #79593)

Matthias Springer llvmlistbot at llvm.org
Wed Jan 31 05:21:25 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/79593

>From 4c14106722fef1c1798177c876f0a7e01e5c7e0d Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 26 Jan 2024 13:12:55 +0000
Subject: [PATCH] [mlir][IR] Send missing notifications when inlining a block

When a block is inlined into another block, its operations are moved into the destination block, so the `notifyOperationInserted` callback should be triggered for each moved operation. This commit adds the missing notifications for:
* `RewriterBase::inlineBlockBefore`
* `RewriterBase::mergeBlocks`
---
 mlir/lib/IR/PatternMatch.cpp                  | 11 ++++++-
 .../Dialect/Affine/simplify-structures.mlir   |  2 --
 .../test-strict-pattern-driver.mlir           | 20 ++++++++++++
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   | 32 ++++++++++++++++---
 4 files changed, 57 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index ee285f094cdd1..0bf79f5817b5a 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -317,7 +317,16 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
 
   // Move operations from the source block to the dest block and erase the
   // source block.
-  dest->getOperations().splice(before, source->getOperations());
+  if (!listener) {
+    // Fast path: If no listener is attached, move all operations at once.
+    dest->getOperations().splice(before, source->getOperations());
+  } else {
+    while (!source->empty())
+      moveOpBefore(&source->front(), dest, before);
+  }
+
+  // Erase the source block.
+  assert(source->empty() && "expected 'source' to be empty");
   eraseBlock(source);
 }
 
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index 2c693ea1551c0..92d3d86bc9306 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -411,8 +411,6 @@ func.func @test_trivially_false_returning_two_results(%arg0: index) -> (index, i
   // CHECK: %[[c13:.*]] = arith.constant 13 : index
   %c7 = arith.constant 7 : index
   %c13 = arith.constant 13 : index
-  // CHECK: %[[c2:.*]] = arith.constant 2 : index
-  // CHECK: %[[c3:.*]] = arith.constant 3 : index
   %res:2 = affine.if affine_set<(d0, d1) : (5 >= 0, -2 >= 0)> (%c7, %c13) -> (index, index) {
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 2ebc66b4f26d0..4785795121653 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -249,3 +249,23 @@ func.func @test_move_op_before() {
   }) : () -> ()
   return
 }
+
+// -----
+
+// CHECK-AN: notifyOperationInserted: test.op_1, previous = test.op_2
+// CHECK-AN: notifyOperationInserted: test.op_2, previous = test.op_3
+// CHECK-AN: notifyOperationInserted: test.op_3, was last in block
+// CHECK-AN-LABEL: func @test_inline_block_before(
+//       CHECK-AN:   test.op_1
+//       CHECK-AN:   test.op_2
+//       CHECK-AN:   test.op_3
+//       CHECK-AN:   test.inline_blocks_into_parent
+//       CHECK-AN:   return
+func.func @test_inline_block_before() {
+  "test.inline_blocks_into_parent"() ({
+    "test.op_1"() : () -> ()
+    "test.op_2"() : () -> ()
+    "test.op_3"() : () -> ()
+  }) : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 2165b388d559c..9ce2ab8daa234 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -213,6 +213,26 @@ struct MoveBeforeParentOp : public RewritePattern {
   }
 };
 
+/// Inlines every block of every region of a "test.inline_blocks_into_parent"
+/// op into the op's parent block, right before the op itself.
+struct InlineBlocksIntoParent : public RewritePattern {
+  InlineBlocksIntoParent(MLIRContext *context)
+      : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1,
+                       context) {}
+  // Passes no argValues: presumably only argument-free blocks are supported.
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    bool changed = false;
+    for (Region &r : op->getRegions()) {
+      while (!r.empty()) { // Each inlining erases r.front(), draining r.
+        rewriter.inlineBlockBefore(&r.front(), op);
+        changed = true;
+      }
+    }
+    return success(changed); // Nothing was inlined => report no match.
+  }
+};
+
 struct TestPatternDriver
     : public PassWrapper<TestPatternDriver, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -292,12 +312,13 @@ struct TestStrictPatternDriver
     mlir::RewritePatternSet patterns(ctx);
     patterns.add<
         // clang-format off
-        InsertSameOp,
-        ReplaceWithNewOp,
-        EraseOp,
         ChangeBlockOp,
+        EraseOp,
         ImplicitChangeOp,
-        MoveBeforeParentOp
+        InlineBlocksIntoParent,
+        InsertSameOp,
+        MoveBeforeParentOp,
+        ReplaceWithNewOp
         // clang-format on
         >(ctx);
     SmallVector<Operation *> ops;
@@ -305,7 +326,8 @@ struct TestStrictPatternDriver
       StringRef opName = op->getName().getStringRef();
       if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
           opName == "test.replace_with_new_op" || opName == "test.erase_op" ||
-          opName == "test.move_before_parent_op") {
+          opName == "test.move_before_parent_op" ||
+          opName == "test.inline_blocks_into_parent") {
         ops.push_back(op);
       }
     });



More information about the Mlir-commits mailing list