[Mlir-commits] [mlir] 38113a0 - [mlir][IR] Trigger `notifyOperationReplaced` on `replaceAllOpUsesWith` (#84721)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 1 18:54:01 PDT 2024
Author: Matthias Springer
Date: 2024-04-02T10:53:57+09:00
New Revision: 38113a083283d2f30a677befaa5fb86dce731c8b
URL: https://github.com/llvm/llvm-project/commit/38113a083283d2f30a677befaa5fb86dce731c8b
DIFF: https://github.com/llvm/llvm-project/commit/38113a083283d2f30a677befaa5fb86dce731c8b.diff
LOG: [mlir][IR] Trigger `notifyOperationReplaced` on `replaceAllOpUsesWith` (#84721)
Before this change: `notifyOperationReplaced` was triggered when calling
`RewriterBase::replaceOp`.
After this change: `notifyOperationReplaced` is triggered when
`RewriterBase::replaceAllOpUsesWith` or `RewriterBase::replaceOp` is
called.
Until now, every `notifyOperationReplaced` was sent together with
a `notifyOperationErased`, which made that `notifyOperationErased`
callback redundant. More importantly, when a user called
`RewriterBase::replaceAllOpUsesWith`+`RewriterBase::eraseOp` instead of
`RewriterBase::replaceOp`, no `notifyOperationReplaced` callback was
sent, even though the two notations are semantically equivalent. As an
example, this can be a problem when applying patterns with the transform
dialect because the `TrackingListener` will only see the
`notifyOperationErased` callback and the payload op is dropped from the
mappings.
Note: It is still possible to write semantically equivalent code that
does not trigger a `notifyOperationReplaced` (e.g., when op results are
replaced one-by-one), but this commit already improves the situation a
lot.
Added:
Modified:
mlir/include/mlir/IR/PatternMatch.h
mlir/lib/IR/PatternMatch.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 070e6ed702f86a..ac2b0d5a38375a 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -409,9 +409,9 @@ class RewriterBase : public OpBuilder {
/// Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationModified(Operation *op) {}
- /// Notify the listener that the specified operation is about to be replaced
- /// with another operation. This is called before the uses of the old
- /// operation have been changed.
+ /// Notify the listener that all uses of the specified operation's results
+ /// are about to be replaced with the results of another operation. This is
+ /// called before the uses of the old operation have been changed.
///
/// By default, this function calls the "operation replaced with values"
/// notification.
@@ -420,9 +420,10 @@ class RewriterBase : public OpBuilder {
notifyOperationReplaced(op, replacement->getResults());
}
- /// Notify the listener that the specified operation is about to be replaced
- /// with the a range of values, potentially produced by other operations.
- /// This is called before the uses of the operation have been changed.
+ /// Notify the listener that all uses of the specified operation's results
+ /// are about to be replaced with a range of values, potentially
+ /// produced by other operations. This is called before the uses of the
+ /// operation have been changed.
virtual void notifyOperationReplaced(Operation *op,
ValueRange replacement) {}
@@ -648,12 +649,16 @@ class RewriterBase : public OpBuilder {
for (auto it : llvm::zip(from, to))
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
}
- // Note: This function cannot be called `replaceAllUsesWith` because the
- // overload resolution, when called with an op that can be implicitly
- // converted to a Value, would be ambiguous.
- void replaceAllOpUsesWith(Operation *from, ValueRange to) {
- replaceAllUsesWith(from->getResults(), to);
- }
+
+ /// Find uses of `from` and replace them with `to`. Also notify the listener
+ /// about every in-place op modification (for every use that was replaced)
+ /// and that the `from` operation is about to be replaced.
+ ///
+ /// Note: This function cannot be called `replaceAllUsesWith` because the
+ /// overload resolution, when called with an op that can be implicitly
+ /// converted to a Value, would be ambiguous.
+ void replaceAllOpUsesWith(Operation *from, ValueRange to);
+ void replaceAllOpUsesWith(Operation *from, Operation *to);
/// Find uses of `from` and replace them with `to` if the `functor` returns
/// true. Also notify the listener about every in-place op modification (for
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 4079ccc7567256..5944a0ea46a143 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -110,6 +110,22 @@ RewriterBase::~RewriterBase() {
// Out of line to provide a vtable anchor for the class.
}
+void RewriterBase::replaceAllOpUsesWith(Operation *from, ValueRange to) {
+ // Notify the listener that we're about to replace this op.
+ if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+ rewriteListener->notifyOperationReplaced(from, to);
+
+ replaceAllUsesWith(from->getResults(), to);
+}
+
+void RewriterBase::replaceAllOpUsesWith(Operation *from, Operation *to) {
+ // Notify the listener that we're about to replace this op.
+ if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+ rewriteListener->notifyOperationReplaced(from, to);
+
+ replaceAllUsesWith(from->getResults(), to->getResults());
+}
+
/// This method replaces the results of the operation with the specified list of
/// values. The number of provided values must match the number of results of
/// the operation. The replaced op is erased.
@@ -117,10 +133,6 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
assert(op->getNumResults() == newValues.size() &&
"incorrect # of replacement values");
- // Notify the listener that we're about to replace this op.
- if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
- rewriteListener->notifyOperationReplaced(op, newValues);
-
// Replace all result uses. Also notifies the listener of modifications.
replaceAllOpUsesWith(op, newValues);
@@ -136,10 +148,6 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
assert(op->getNumResults() == newOp->getNumResults() &&
"ops have different number of results");
- // Notify the listener that we're about to replace this op.
- if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
- rewriteListener->notifyOperationReplaced(op, newOp);
-
// Replace all result uses. Also notifies the listener of modifications.
replaceAllOpUsesWith(op, newOp->getResults());
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 2da184bc3d85ba..76dc825fe44515 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -489,7 +489,10 @@ struct TestStrictPatternDriver
OperationName("test.new_op", op->getContext()).getIdentifier(),
op->getOperands(), op->getResultTypes());
}
- rewriter.replaceOp(op, newOp->getResults());
+ // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
+ // A "notifyOperationReplaced" callback is triggered in either case.
+ rewriter.replaceAllOpUsesWith(op, newOp->getResults());
+ rewriter.eraseOp(op);
return success();
}
};
More information about the Mlir-commits
mailing list