[Mlir-commits] [mlir] 21f4b84 - [mlir][IR] Trigger notifyOperationModified for replacements
Matthias Springer
llvmlistbot at llvm.org
Mon Mar 6 01:08:50 PST 2023
Author: Matthias Springer
Date: 2023-03-06T10:07:48+01:00
New Revision: 21f4b84c456b471cc52016cf360e14d45f7f2960
URL: https://github.com/llvm/llvm-project/commit/21f4b84c456b471cc52016cf360e14d45f7f2960
DIFF: https://github.com/llvm/llvm-project/commit/21f4b84c456b471cc52016cf360e14d45f7f2960.diff
LOG: [mlir][IR] Trigger notifyOperationModified for replacements
Each user of the original value is modified in place. Therefore, the corresponding notification (`notifyOperationModified`) should be triggered for each such user.
Also fixes a bug where `RewriterBase::mergeBlocks` did not notify the GreedyPatternRewriteDriver when replacing uses of block arguments. This function does not trigger "operation replaced" notifications, so the GreedyPatternRewriteDriver was not made aware of such IR changes.
Differential Revision: https://reviews.llvm.org/D144549
Added:
Modified:
mlir/include/mlir/IR/PatternMatch.h
mlir/lib/IR/PatternMatch.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index acc7e3ef8e1d..ed431badf05b 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -555,9 +555,8 @@ class RewriterBase : public OpBuilder {
/// Find uses of `from` and replace them with `to` if the `functor` returns
/// true. It also marks every modified uses and notifies the rewriter that an
/// in-place operation modification is about to happen.
- void
- replaceUsesWithIf(Value from, Value to,
- llvm::unique_function<bool(OpOperand &) const> functor);
+ void replaceUsesWithIf(Value from, Value to,
+ function_ref<bool(OpOperand &)> functor);
/// Find uses of `from` and replace them with `to` except if the user is
/// `exceptedUser`. It also marks every modified uses and notifies the
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index a01ccca5d33a..1fc234e10421 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -235,14 +235,14 @@ void RewriterBase::replaceOpWithIf(
assert(op->getNumResults() == newValues.size() &&
"incorrect number of values to replace operation");
- // Notify the rewriter subclass that we're about to replace this root.
+ // Notify the listener that we're about to replace this op.
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(op, newValues);
// Replace each use of the results when the functor is true.
bool replacedAllUses = true;
for (auto it : llvm::zip(op->getResults(), newValues)) {
- std::get<0>(it).replaceUsesWithIf(std::get<1>(it), functor);
+ replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
replacedAllUses &= std::get<0>(it).use_empty();
}
if (allUsesReplaced)
@@ -264,13 +264,16 @@ void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
/// values. The number of provided values must match the number of results of
/// the operation.
void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
- // Notify the rewriter subclass that we're about to replace this root.
+ assert(op->getNumResults() == newValues.size() &&
+ "incorrect # of replacement values");
+
+ // Notify the listener that we're about to remove this op.
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(op, newValues);
- assert(op->getNumResults() == newValues.size() &&
- "incorrect # of replacement values");
- op->replaceAllUsesWith(newValues);
+ // Replace results one-by-one. Also notifies the listener of modifications.
+ for (auto it : llvm::zip(op->getResults(), newValues))
+ replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationRemoved(op);
@@ -314,7 +317,7 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
// Replace all of the successor arguments with the provided values.
for (auto it : llvm::zip(source->getArguments(), argValues))
- std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+ replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
// Splice the operations of the 'source' block into the 'dest' block and erase
// it.
@@ -326,9 +329,8 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
/// Find uses of `from` and replace them with `to` if the `functor` returns
/// true. It also marks every modified uses and notifies the rewriter that an
/// in-place operation modification is about to happen.
-void RewriterBase::replaceUsesWithIf(
- Value from, Value to,
- llvm::unique_function<bool(OpOperand &) const> functor) {
+void RewriterBase::replaceUsesWithIf(Value from, Value to,
+ function_ref<bool(OpOperand &)> functor) {
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
if (functor(operand))
updateRootInPlace(operand.getOwner(), [&]() { operand.set(to); });
More information about the Mlir-commits
mailing list