[Mlir-commits] [mlir] [mlir][IR] Trigger notifyOperationRemoved callback for nested ops (PR #66771)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 19 05:52:32 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

<details>
<summary>Changes</summary>

When cloning an op, the `notifyOperationInserted` callback is triggered for all nested ops. Similarly, the `notifyOperationRemoved` callback should be triggered for all nested ops when removing an op.

Listeners may inspect the IR during a `notifyOperationRemoved` callback. Therefore, when multiple ops are removed in a single `RewriterBase::eraseOp` call, the notifications must be triggered in an order in which the ops could have been removed one-by-one:

* Op removals must be interleaved with `notifyOperationRemoved` callbacks. A callback is triggered right before the respective op is removed.
* Ops are removed post-order and in reverse order. Other traversal orders could delete an op that still has uses. (Some remaining uses cannot be avoided in graph regions and in cyclic block graphs; such uses are dropped explicitly before erasure.)

Note: Imported from https://reviews.llvm.org/D144193.


---
Full diff: https://github.com/llvm/llvm-project/pull/66771.diff


7 Files Affected:

- (modified) mlir/include/mlir/IR/RegionKindInterface.h (+6) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+3-6) 
- (modified) mlir/lib/IR/PatternMatch.cpp (+68-4) 
- (modified) mlir/lib/IR/RegionKindInterface.cpp (+10-2) 
- (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+3-6) 
- (modified) mlir/test/Transforms/test-strict-pattern-driver.mlir (+153-7) 
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+8) 


``````````diff
diff --git a/mlir/include/mlir/IR/RegionKindInterface.h b/mlir/include/mlir/IR/RegionKindInterface.h
index 46bfe717533a84a..d6d3aeeb9bd0526 100644
--- a/mlir/include/mlir/IR/RegionKindInterface.h
+++ b/mlir/include/mlir/IR/RegionKindInterface.h
@@ -43,6 +43,12 @@ class HasOnlyGraphRegion : public TraitBase<ConcreteType, HasOnlyGraphRegion> {
 /// not implement the RegionKindInterface.
 bool mayHaveSSADominance(Region &region);
 
+/// Return "true" if the given region may be a graph region without SSA
+/// dominance. This function returns "true" in case the owner op is an
+/// unregistered op. It returns "false" if it is a registered op that does not
+/// implement the RegionKindInterface.
+bool mayBeGraphRegion(Region &region);
+
 } // namespace mlir
 
 #include "mlir/IR/RegionKindInterface.h.inc"
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index cad78b3e65b2313..c34f422292cb4f0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -394,12 +394,9 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
 
 protected:
   void notifyOperationRemoved(Operation *op) override {
-    // TODO: Walk can be removed when D144193 has landed.
-    op->walk([&](Operation *op) {
-      erasedOps.insert(op);
-      // Erase if present.
-      toMemrefOps.erase(op);
-    });
+    erasedOps.insert(op);
+    // Erase if present.
+    toMemrefOps.erase(op);
   }
 
   void notifyOperationInserted(Operation *op) override {
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index db920c14ea08dc7..5e9b9b2a810a4c5 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -8,6 +8,8 @@
 
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Iterators.h"
+#include "mlir/IR/RegionKindInterface.h"
 
 using namespace mlir;
 
@@ -275,7 +277,7 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
   for (auto it : llvm::zip(op->getResults(), newValues))
     replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
 
-  // Erase the op.
+  // Erase op and notify listener.
   eraseOp(op);
 }
 
@@ -295,7 +297,7 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
   for (auto it : llvm::zip(op->getResults(), newOp->getResults()))
     replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
 
-  // Erase the old op.
+  // Erase op and notify listener.
   eraseOp(op);
 }
 
@@ -303,9 +305,71 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
 /// the given operation *must* be known to be dead.
 void RewriterBase::eraseOp(Operation *op) {
   assert(op->use_empty() && "expected 'op' to have no uses");
-  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+  auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
+
+  // Fast path: If no listener is attached, the op can be dropped in one go.
+  if (!rewriteListener) {
+    op->erase();
+    return;
+  }
+
+  // Helper function that erases a single op.
+  auto eraseSingleOp = [&](Operation *op) {
+#ifndef NDEBUG
+    // All nested ops should have been erased already.
+    assert(
+        llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
+        "expected empty regions");
+    // All users should have been erased already if the op is in a region with
+    // SSA dominance.
+    if (!op->use_empty() && op->getParentOp())
+      assert(mayBeGraphRegion(*op->getParentRegion()) &&
+             "expected that op has no uses");
+#endif // NDEBUG
     rewriteListener->notifyOperationRemoved(op);
-  op->erase();
+
+    // Explicitly drop all uses in case the op is in a graph region.
+    op->dropAllUses();
+    op->erase();
+  };
+
+  // Nested ops must be erased one-by-one, so that listeners have a consistent
+  // view of the IR every time a notification is triggered. Users must be
+  // erased before definitions. I.e., post-order, reverse dominance.
+  std::function<void(Operation *)> eraseTree = [&](Operation *op) {
+    // Erase nested ops.
+    for (Region &r : llvm::reverse(op->getRegions())) {
+      // Erase all blocks in the right order. Successors should be erased
+      // before predecessors because successor blocks may use values defined
+      // in predecessor blocks. A post-order traversal of blocks within a
+      // region visits successors before predecessors. Repeat the traversal
+      // until the region is empty. (The block graph could be disconnected.)
+      while (!r.empty()) {
+        SmallVector<Block *> erasedBlocks;
+        for (Block *b : llvm::post_order(&r.front())) {
+          // Visit ops in reverse order.
+          for (Operation &op :
+               llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
+            eraseTree(&op);
+          // Do not erase the block immediately. This is not supported by the
+          // post_order iterator.
+          erasedBlocks.push_back(b);
+        }
+        for (Block *b : erasedBlocks) {
+          // Explicitly drop all uses in case there is a cycle in the block
+          // graph.
+          for (BlockArgument bbArg : b->getArguments())
+            bbArg.dropAllUses();
+          b->dropAllUses();
+          b->erase();
+        }
+      }
+    }
+    // Then erase the enclosing op.
+    eraseSingleOp(op);
+  };
+
+  eraseTree(op);
 }
 
 void RewriterBase::eraseBlock(Block *block) {
diff --git a/mlir/lib/IR/RegionKindInterface.cpp b/mlir/lib/IR/RegionKindInterface.cpp
index cbef3025a5dd626..007f4cf92dbc7ae 100644
--- a/mlir/lib/IR/RegionKindInterface.cpp
+++ b/mlir/lib/IR/RegionKindInterface.cpp
@@ -18,9 +18,17 @@ using namespace mlir;
 #include "mlir/IR/RegionKindInterface.cpp.inc"
 
 bool mlir::mayHaveSSADominance(Region &region) {
-  auto regionKindOp =
-      dyn_cast_if_present<RegionKindInterface>(region.getParentOp());
+  auto regionKindOp = dyn_cast<RegionKindInterface>(region.getParentOp());
   if (!regionKindOp)
     return true;
   return regionKindOp.hasSSADominance(region.getRegionNumber());
 }
+
+bool mlir::mayBeGraphRegion(Region &region) {
+  if (!region.getParentOp()->isRegistered())
+    return true;
+  auto regionKindOp = dyn_cast<RegionKindInterface>(region.getParentOp());
+  if (!regionKindOp)
+    return false;
+  return !regionKindOp.hasSSADominance(region.getRegionNumber());
+}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index fba4944f130c230..8e2bfe557c555f3 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -421,8 +421,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
 
     // If the operation is trivially dead - remove it.
     if (isOpTriviallyDead(op)) {
-      notifyOperationRemoved(op);
-      op->erase();
+      eraseOp(op);
       changed = true;
 
       LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
@@ -567,10 +566,8 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
     config.listener->notifyOperationRemoved(op);
 
   addOperandsToWorklist(op->getOperands());
-  op->walk([this](Operation *operation) {
-    worklist.remove(operation);
-    folder.notifyRemoval(operation);
-  });
+  worklist.remove(op);
+  folder.notifyRemoval(op);
 
   if (config.strictMode != GreedyRewriteStrictness::AnyOp)
     strictModeFilteredOps.erase(op);
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 5df2d6d1fdeeb38..a5ab8f97c74ce33 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -12,9 +12,9 @@
 
 // CHECK-EN-LABEL: func @test_erase
 //  CHECK-EN-SAME:     pattern_driver_all_erased = true, pattern_driver_changed = true}
-//       CHECK-EN:   test.arg0
-//       CHECK-EN:   test.arg1
-//   CHECK-EN-NOT:   test.erase_op
+//       CHECK-EN:   "test.arg0"
+//       CHECK-EN:   "test.arg1"
+//   CHECK-EN-NOT:   "test.erase_op"
 func.func @test_erase() {
   %0 = "test.arg0"() : () -> (i32)
   %1 = "test.arg1"() : () -> (i32)
@@ -51,13 +51,13 @@ func.func @test_replace_with_new_op() {
 
 // CHECK-EN-LABEL: func @test_replace_with_erase_op
 //  CHECK-EN-SAME:     {pattern_driver_all_erased = true, pattern_driver_changed = true}
-//   CHECK-EN-NOT:   test.replace_with_new_op
-//   CHECK-EN-NOT:   test.erase_op
+//   CHECK-EN-NOT:   "test.replace_with_new_op"
+//   CHECK-EN-NOT:   "test.erase_op"
 
 // CHECK-EX-LABEL: func @test_replace_with_erase_op
 //  CHECK-EX-SAME:     {pattern_driver_all_erased = true, pattern_driver_changed = true}
-//   CHECK-EX-NOT:   test.replace_with_new_op
-//       CHECK-EX:   test.erase_op
+//   CHECK-EX-NOT:   "test.replace_with_new_op"
+//       CHECK-EX:   "test.erase_op"
 func.func @test_replace_with_erase_op() {
   "test.replace_with_new_op"() {create_erase_op} : () -> ()
   return
@@ -83,3 +83,149 @@ func.func @test_trigger_rewrite_through_block() {
   // in turn, replaces the successor with bb3.
   "test.implicit_change_op"() [^bb1] : () -> ()
 }
+
+// -----
+
+// CHECK-AN: notifyOperationRemoved: test.foo_b
+// CHECK-AN: notifyOperationRemoved: test.foo_a
+// CHECK-AN: notifyOperationRemoved: test.graph_region
+// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN-LABEL: func @test_remove_graph_region()
+//  CHECK-AN-NEXT:   return
+func.func @test_remove_graph_region() {
+  "test.erase_op"() ({
+    test.graph_region {
+      %0 = "test.foo_a"(%1) : (i1) -> (i1)
+      %1 = "test.foo_b"(%0) : (i1) -> (i1)
+    }
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.bar
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.foo
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.dummy_op
+// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN-LABEL: func @test_remove_cyclic_blocks()
+//  CHECK-AN-NEXT:   return
+func.func @test_remove_cyclic_blocks() {
+  "test.erase_op"() ({
+    %x = "test.dummy_op"() : () -> (i1)
+    cf.br ^bb1(%x: i1)
+  ^bb1(%arg0: i1):
+    "test.foo"(%x) : (i1) -> ()
+    cf.br ^bb2(%arg0: i1)
+  ^bb2(%arg1: i1):
+    "test.bar"(%x) : (i1) -> ()
+    cf.br ^bb1(%arg1: i1)
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationRemoved: test.dummy_op
+// CHECK-AN: notifyOperationRemoved: test.bar
+// CHECK-AN: notifyOperationRemoved: test.qux
+// CHECK-AN: notifyOperationRemoved: test.qux_unreachable
+// CHECK-AN: notifyOperationRemoved: test.nested_dummy
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.foo
+// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN-LABEL: func @test_remove_dead_blocks()
+//  CHECK-AN-NEXT:   return
+func.func @test_remove_dead_blocks() {
+  "test.erase_op"() ({
+    "test.dummy_op"() : () -> (i1)
+  // The following blocks are not reachable. Still, ^bb2 should be deleted
+  // before ^bb1.
+  ^bb1(%arg0: i1):
+    "test.foo"() : () -> ()
+    cf.br ^bb2(%arg0: i1)
+  ^bb2(%arg1: i1):
+    "test.nested_dummy"() ({
+      "test.qux"() : () -> ()
+    // The following block is unreachable.
+    ^bb3:
+      "test.qux_unreachable"() : () -> ()
+    }) : () -> ()
+    "test.bar"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// test.nested_* must be deleted before test.foo.
+// test.bar must be deleted before test.foo.
+
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.bar
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.nested_b
+// CHECK-AN: notifyOperationRemoved: test.nested_a
+// CHECK-AN: notifyOperationRemoved: test.nested_d
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.nested_e
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.nested_c
+// CHECK-AN: notifyOperationRemoved: test.foo
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.dummy_op
+// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN-LABEL: func @test_remove_nested_ops()
+//  CHECK-AN-NEXT:   return
+func.func @test_remove_nested_ops() {
+  "test.erase_op"() ({
+    %x = "test.dummy_op"() : () -> (i1)
+    cf.br ^bb1(%x: i1)
+  ^bb1(%arg0: i1):
+    "test.foo"() ({
+      "test.nested_a"() : () -> ()
+      "test.nested_b"() : () -> ()
+    ^dead1:
+      "test.nested_c"() : () -> ()
+      cf.br ^dead3
+    ^dead2:
+      "test.nested_d"() : () -> ()
+    ^dead3:
+      "test.nested_e"() : () -> ()
+      cf.br ^dead2
+    }) : () -> ()
+    cf.br ^bb2(%arg0: i1)
+  ^bb2(%arg1: i1):
+    "test.bar"(%x) : (i1) -> ()
+    cf.br ^bb1(%arg1: i1)
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationRemoved: test.qux
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.foo
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.bar
+// CHECK-AN: notifyOperationRemoved: cf.cond_br
+// CHECK-AN-LABEL: func @test_remove_diamond(
+//  CHECK-AN-NEXT:   return
+func.func @test_remove_diamond(%c: i1) {
+  "test.erase_op"() ({
+    cf.cond_br %c, ^bb1, ^bb2
+  ^bb1:
+    "test.foo"() : () -> ()
+    cf.br ^bb3
+  ^bb2:
+    "test.bar"() : () -> ()
+    cf.br ^bb3
+  ^bb3:
+    "test.qux"() : () -> ()
+  }) : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index e23ed105e383390..2e3bc76009ca208 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -239,6 +239,12 @@ struct TestPatternDriver
       llvm::cl::init(GreedyRewriteConfig().maxIterations)};
 };
 
+struct DumpNotifications : public RewriterBase::Listener {
+  void notifyOperationRemoved(Operation *op) override {
+    llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
+  }
+};
+
 struct TestStrictPatternDriver
     : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
 public:
@@ -275,7 +281,9 @@ struct TestStrictPatternDriver
       }
     });
 
+    DumpNotifications dumpNotifications;
     GreedyRewriteConfig config;
+    config.listener = &dumpNotifications;
     if (strictMode == "AnyOp") {
       config.strictMode = GreedyRewriteStrictness::AnyOp;
     } else if (strictMode == "ExistingAndNewOps") {

``````````

</details>


https://github.com/llvm/llvm-project/pull/66771


More information about the Mlir-commits mailing list