[Mlir-commits] [mlir] 695a5a6 - [mlir][IR] Trigger `notifyOperationRemoved` callback for nested ops (#66771)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 19 23:45:50 PDT 2023
Author: Matthias Springer
Date: 2023-09-20T08:45:46+02:00
New Revision: 695a5a6a66396b83263bbb3f1946fbaf41e422c3
URL: https://github.com/llvm/llvm-project/commit/695a5a6a66396b83263bbb3f1946fbaf41e422c3
DIFF: https://github.com/llvm/llvm-project/commit/695a5a6a66396b83263bbb3f1946fbaf41e422c3.diff
LOG: [mlir][IR] Trigger `notifyOperationRemoved` callback for nested ops (#66771)
When cloning an op, the `notifyOperationInserted` callback is triggered
for all nested ops. Similarly, the `notifyOperationRemoved` callback
should be triggered for all nested ops when removing an op.
Listeners may inspect the IR during a `notifyOperationRemoved` callback.
Therefore, when multiple ops are removed in a single
`RewriterBase::eraseOp` call, the notifications must be triggered in an
order in which the ops could have been removed one-by-one:
* Op removals must be interleaved with `notifyOperationRemoved`
callbacks. A callback is triggered right before the respective op is
removed.
* Ops are removed post-order and in reverse order. Other traversal
orders could delete an op that still has uses. (This is not avoidable in
graph regions and with cyclic block graphs.)
Differential Revision: Imported from https://reviews.llvm.org/D144193.
Added:
Modified:
mlir/include/mlir/IR/RegionKindInterface.h
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/IR/PatternMatch.cpp
mlir/lib/IR/RegionKindInterface.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/Transforms/test-strict-pattern-driver.mlir
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/RegionKindInterface.h b/mlir/include/mlir/IR/RegionKindInterface.h
index 46bfe717533a84a..d6d3aeeb9bd0526 100644
--- a/mlir/include/mlir/IR/RegionKindInterface.h
+++ b/mlir/include/mlir/IR/RegionKindInterface.h
@@ -43,6 +43,12 @@ class HasOnlyGraphRegion : public TraitBase<ConcreteType, HasOnlyGraphRegion> {
/// not implement the RegionKindInterface.
bool mayHaveSSADominance(Region &region);
+/// Return "true" if the given region may be a graph region without SSA
+/// dominance. This function returns "true" in case the owner op is an
+/// unregistered op. It returns "false" if it is a registered op that does not
+/// implement the RegionKindInterface.
+bool mayBeGraphRegion(Region &region);
+
} // namespace mlir
#include "mlir/IR/RegionKindInterface.h.inc"
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index cad78b3e65b2313..c34f422292cb4f0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -394,12 +394,9 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
protected:
void notifyOperationRemoved(Operation *op) override {
- // TODO: Walk can be removed when D144193 has landed.
- op->walk([&](Operation *op) {
- erasedOps.insert(op);
- // Erase if present.
- toMemrefOps.erase(op);
- });
+ erasedOps.insert(op);
+ // Erase if present.
+ toMemrefOps.erase(op);
}
void notifyOperationInserted(Operation *op) override {
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index db920c14ea08dc7..5e9b9b2a810a4c5 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -8,6 +8,8 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Iterators.h"
+#include "mlir/IR/RegionKindInterface.h"
using namespace mlir;
@@ -275,7 +277,7 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
for (auto it : llvm::zip(op->getResults(), newValues))
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
- // Erase the op.
+ // Erase op and notify listener.
eraseOp(op);
}
@@ -295,7 +297,7 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
for (auto it : llvm::zip(op->getResults(), newOp->getResults()))
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
- // Erase the old op.
+ // Erase op and notify listener.
eraseOp(op);
}
@@ -303,9 +305,71 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
/// the given operation *must* be known to be dead.
void RewriterBase::eraseOp(Operation *op) {
assert(op->use_empty() && "expected 'op' to have no uses");
- if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+ auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
+
+ // Fast path: If no listener is attached, the op can be dropped in one go.
+ if (!rewriteListener) {
+ op->erase();
+ return;
+ }
+
+ // Helper function that erases a single op.
+ auto eraseSingleOp = [&](Operation *op) {
+#ifndef NDEBUG
+ // All nested ops should have been erased already.
+ assert(
+ llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
+ "expected empty regions");
+ // All users should have been erased already if the op is in a region with
+ // SSA dominance.
+ if (!op->use_empty() && op->getParentOp())
+ assert(mayBeGraphRegion(*op->getParentRegion()) &&
+ "expected that op has no uses");
+#endif // NDEBUG
rewriteListener->notifyOperationRemoved(op);
- op->erase();
+
+ // Explicitly drop all uses in case the op is in a graph region.
+ op->dropAllUses();
+ op->erase();
+ };
+
+ // Nested ops must be erased one-by-one, so that listeners have a consistent
+ // view of the IR every time a notification is triggered. Users must be
+ // erased before definitions. I.e., post-order, reverse dominance.
+ std::function<void(Operation *)> eraseTree = [&](Operation *op) {
+ // Erase nested ops.
+ for (Region &r : llvm::reverse(op->getRegions())) {
+ // Erase all blocks in the right order. Successors should be erased
+ // before predecessors because successor blocks may use values defined
+ // in predecessor blocks. A post-order traversal of blocks within a
+ // region visits successors before predecessors. Repeat the traversal
+ // until the region is empty. (The block graph could be disconnected.)
+ while (!r.empty()) {
+ SmallVector<Block *> erasedBlocks;
+ for (Block *b : llvm::post_order(&r.front())) {
+ // Visit ops in reverse order.
+ for (Operation &op :
+ llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
+ eraseTree(&op);
+ // Do not erase the block immediately. This is not supported by the
+ // post_order iterator.
+ erasedBlocks.push_back(b);
+ }
+ for (Block *b : erasedBlocks) {
+ // Explicitly drop all uses in case there is a cycle in the block
+ // graph.
+ for (BlockArgument bbArg : b->getArguments())
+ bbArg.dropAllUses();
+ b->dropAllUses();
+ b->erase();
+ }
+ }
+ }
+ // Then erase the enclosing op.
+ eraseSingleOp(op);
+ };
+
+ eraseTree(op);
}
void RewriterBase::eraseBlock(Block *block) {
diff --git a/mlir/lib/IR/RegionKindInterface.cpp b/mlir/lib/IR/RegionKindInterface.cpp
index cbef3025a5dd626..007f4cf92dbc7ae 100644
--- a/mlir/lib/IR/RegionKindInterface.cpp
+++ b/mlir/lib/IR/RegionKindInterface.cpp
@@ -18,9 +18,17 @@ using namespace mlir;
#include "mlir/IR/RegionKindInterface.cpp.inc"
bool mlir::mayHaveSSADominance(Region &region) {
- auto regionKindOp =
- dyn_cast_if_present<RegionKindInterface>(region.getParentOp());
+ auto regionKindOp = dyn_cast<RegionKindInterface>(region.getParentOp());
if (!regionKindOp)
return true;
return regionKindOp.hasSSADominance(region.getRegionNumber());
}
+
+bool mlir::mayBeGraphRegion(Region &region) {
+ if (!region.getParentOp()->isRegistered())
+ return true;
+ auto regionKindOp = dyn_cast<RegionKindInterface>(region.getParentOp());
+ if (!regionKindOp)
+ return false;
+ return !regionKindOp.hasSSADominance(region.getRegionNumber());
+}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index fba4944f130c230..8e2bfe557c555f3 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -421,8 +421,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// If the operation is trivially dead - remove it.
if (isOpTriviallyDead(op)) {
- notifyOperationRemoved(op);
- op->erase();
+ eraseOp(op);
changed = true;
LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
@@ -567,10 +566,8 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
config.listener->notifyOperationRemoved(op);
addOperandsToWorklist(op->getOperands());
- op->walk([this](Operation *operation) {
- worklist.remove(operation);
- folder.notifyRemoval(operation);
- });
+ worklist.remove(op);
+ folder.notifyRemoval(op);
if (config.strictMode != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.erase(op);
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 5df2d6d1fdeeb38..a5ab8f97c74ce33 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -12,9 +12,9 @@
// CHECK-EN-LABEL: func @test_erase
// CHECK-EN-SAME: pattern_driver_all_erased = true, pattern_driver_changed = true}
-// CHECK-EN: test.arg0
-// CHECK-EN: test.arg1
-// CHECK-EN-NOT: test.erase_op
+// CHECK-EN: "test.arg0"
+// CHECK-EN: "test.arg1"
+// CHECK-EN-NOT: "test.erase_op"
func.func @test_erase() {
%0 = "test.arg0"() : () -> (i32)
%1 = "test.arg1"() : () -> (i32)
@@ -51,13 +51,13 @@ func.func @test_replace_with_new_op() {
// CHECK-EN-LABEL: func @test_replace_with_erase_op
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
-// CHECK-EN-NOT: test.replace_with_new_op
-// CHECK-EN-NOT: test.erase_op
+// CHECK-EN-NOT: "test.replace_with_new_op"
+// CHECK-EN-NOT: "test.erase_op"
// CHECK-EX-LABEL: func @test_replace_with_erase_op
// CHECK-EX-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
-// CHECK-EX-NOT: test.replace_with_new_op
-// CHECK-EX: test.erase_op
+// CHECK-EX-NOT: "test.replace_with_new_op"
+// CHECK-EX: "test.erase_op"
func.func @test_replace_with_erase_op() {
"test.replace_with_new_op"() {create_erase_op} : () -> ()
return
@@ -83,3 +83,149 @@ func.func @test_trigger_rewrite_through_block() {
// in turn, replaces the successor with bb3.
"test.implicit_change_op"() [^bb1] : () -> ()
}
+
+// -----
+
+// CHECK-AN: notifyOperationRemoved: test.foo_b
+// CHECK-AN: notifyOperationRemoved: test.foo_a
+// CHECK-AN: notifyOperationRemoved: test.graph_region
+// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN-LABEL: func @test_remove_graph_region()
+// CHECK-AN-NEXT: return
+func.func @test_remove_graph_region() {
+ "test.erase_op"() ({
+ test.graph_region {
+ %0 = "test.foo_a"(%1) : (i1) -> (i1)
+ %1 = "test.foo_b"(%0) : (i1) -> (i1)
+ }
+ }) : () -> ()
+ return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.bar
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.foo
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.dummy_op
+// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN-LABEL: func @test_remove_cyclic_blocks()
+// CHECK-AN-NEXT: return
+func.func @test_remove_cyclic_blocks() {
+ "test.erase_op"() ({
+ %x = "test.dummy_op"() : () -> (i1)
+ cf.br ^bb1(%x: i1)
+ ^bb1(%arg0: i1):
+ "test.foo"(%x) : (i1) -> ()
+ cf.br ^bb2(%arg0: i1)
+ ^bb2(%arg1: i1):
+ "test.bar"(%x) : (i1) -> ()
+ cf.br ^bb1(%arg1: i1)
+ }) : () -> ()
+ return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationRemoved: test.dummy_op
+// CHECK-AN: notifyOperationRemoved: test.bar
+// CHECK-AN: notifyOperationRemoved: test.qux
+// CHECK-AN: notifyOperationRemoved: test.qux_unreachable
+// CHECK-AN: notifyOperationRemoved: test.nested_dummy
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.foo
+// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN-LABEL: func @test_remove_dead_blocks()
+// CHECK-AN-NEXT: return
+func.func @test_remove_dead_blocks() {
+ "test.erase_op"() ({
+ "test.dummy_op"() : () -> (i1)
+ // The following blocks are not reachable. Still, ^bb2 should be deleted
+ // before ^bb1.
+ ^bb1(%arg0: i1):
+ "test.foo"() : () -> ()
+ cf.br ^bb2(%arg0: i1)
+ ^bb2(%arg1: i1):
+ "test.nested_dummy"() ({
+ "test.qux"() : () -> ()
+ // The following block is unreachable.
+ ^bb3:
+ "test.qux_unreachable"() : () -> ()
+ }) : () -> ()
+ "test.bar"() : () -> ()
+ }) : () -> ()
+ return
+}
+
+// -----
+
+// test.nested_* must be deleted before test.foo.
+// test.bar must be deleted before test.foo.
+
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.bar
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.nested_b
+// CHECK-AN: notifyOperationRemoved: test.nested_a
+// CHECK-AN: notifyOperationRemoved: test.nested_d
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.nested_e
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.nested_c
+// CHECK-AN: notifyOperationRemoved: test.foo
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.dummy_op
+// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN-LABEL: func @test_remove_nested_ops()
+// CHECK-AN-NEXT: return
+func.func @test_remove_nested_ops() {
+ "test.erase_op"() ({
+ %x = "test.dummy_op"() : () -> (i1)
+ cf.br ^bb1(%x: i1)
+ ^bb1(%arg0: i1):
+ "test.foo"() ({
+ "test.nested_a"() : () -> ()
+ "test.nested_b"() : () -> ()
+ ^dead1:
+ "test.nested_c"() : () -> ()
+ cf.br ^dead3
+ ^dead2:
+ "test.nested_d"() : () -> ()
+ ^dead3:
+ "test.nested_e"() : () -> ()
+ cf.br ^dead2
+ }) : () -> ()
+ cf.br ^bb2(%arg0: i1)
+ ^bb2(%arg1: i1):
+ "test.bar"(%x) : (i1) -> ()
+ cf.br ^bb1(%arg1: i1)
+ }) : () -> ()
+ return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationRemoved: test.qux
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.foo
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.bar
+// CHECK-AN: notifyOperationRemoved: cf.cond_br
+// CHECK-AN-LABEL: func @test_remove_diamond(
+// CHECK-AN-NEXT: return
+func.func @test_remove_diamond(%c: i1) {
+ "test.erase_op"() ({
+ cf.cond_br %c, ^bb1, ^bb2
+ ^bb1:
+ "test.foo"() : () -> ()
+ cf.br ^bb3
+ ^bb2:
+ "test.bar"() : () -> ()
+ cf.br ^bb3
+ ^bb3:
+ "test.qux"() : () -> ()
+ }) : () -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index e23ed105e383390..2e3bc76009ca208 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -239,6 +239,12 @@ struct TestPatternDriver
llvm::cl::init(GreedyRewriteConfig().maxIterations)};
};
+struct DumpNotifications : public RewriterBase::Listener {
+ void notifyOperationRemoved(Operation *op) override {
+ llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
+ }
+};
+
struct TestStrictPatternDriver
: public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
public:
@@ -275,7 +281,9 @@ struct TestStrictPatternDriver
}
});
+ DumpNotifications dumpNotifications;
GreedyRewriteConfig config;
+ config.listener = &dumpNotifications;
if (strictMode == "AnyOp") {
config.strictMode = GreedyRewriteStrictness::AnyOp;
} else if (strictMode == "ExistingAndNewOps") {
More information about the Mlir-commits
mailing list