[Mlir-commits] [mlir] [mlir][IR] Make `replaceOp` / `replaceAllUsesWith` API consistent (PR #82629)
Matthias Springer
llvmlistbot at llvm.org
Wed Mar 6 17:08:38 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/82629
>From 5074d55305ea160a145567234cda4e25a7644324 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 22 Feb 2024 15:03:30 +0000
Subject: [PATCH] [mlir][IR] Make `replaceOp` / `replaceAllUsesWith` API
consistent
* `replaceOp` replaces all uses of the original op and erases the old op.
* `replaceAllUsesWith` replaces all uses of the original op/value/block. It does not erase any IR.
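This commit removes `replaceOpWithIf` in favor of `replaceUsesWithIf`, which gains an `Operation *` overload and an optional `allUsesReplaced` out-parameter. `replaceOpWithIf` was a misnomer because the function never erased the original op. Similarly, `replaceOpWithinBlock` (which also never erased the op, despite what its comment claimed) is renamed to `replaceUsesWithinBlock`. Note that, unlike `replaceOpWithIf`, `replaceUsesWithIf` does not send a `notifyOperationReplaced` notification; the listener is only notified about the individual in-place op modifications. The `ConversionPatternRewriter` override of `replaceOpWithIf`, which aborted with `llvm_unreachable`, is removed as well.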
Also improve comments.
BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
mlir/include/mlir/IR/PatternMatch.h | 87 +++++++++----------
.../mlir/Transforms/DialectConversion.h | 6 --
.../Linalg/Transforms/DecomposeLinalgOps.cpp | 4 +-
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 2 +-
mlir/lib/IR/PatternMatch.cpp | 73 ++++++----------
.../Transforms/Utils/DialectConversion.cpp | 13 ---
mlir/lib/Transforms/Utils/RegionUtils.cpp | 2 +-
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 2 +-
8 files changed, 74 insertions(+), 115 deletions(-)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index f8d22cfb22afd0..e3500b3f9446d8 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -497,42 +497,19 @@ class RewriterBase : public OpBuilder {
Region::iterator before);
void inlineRegionBefore(Region ®ion, Block *before);
- /// This method replaces the uses of the results of `op` with the values in
- /// `newValues` when the provided `functor` returns true for a specific use.
- /// The number of values in `newValues` is required to match the number of
- /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
- /// the uses of `op` were replaced. Note that in some rewriters, the given
- /// 'functor' may be stored beyond the lifetime of the rewrite being applied.
- /// As such, the function should not capture by reference and instead use
- /// value capture as necessary.
- virtual void
- replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
- llvm::unique_function<bool(OpOperand &) const> functor);
- void replaceOpWithIf(Operation *op, ValueRange newValues,
- llvm::unique_function<bool(OpOperand &) const> functor) {
- replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
- std::move(functor));
- }
-
- /// This method replaces the uses of the results of `op` with the values in
- /// `newValues` when a use is nested within the given `block`. The number of
- /// values in `newValues` is required to match the number of results of `op`.
- /// If all uses of this operation are replaced, the operation is erased.
- void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
- bool *allUsesReplaced = nullptr);
-
- /// This method replaces the results of the operation with the specified list
- /// of values. The number of provided values must match the number of results
- /// of the operation. The replaced op is erased.
+ /// Replace the results of the given (original) operation with the specified
+ /// list of values (replacements). The result types of the given op and the
+ /// replacements must match. The original op is erased.
virtual void replaceOp(Operation *op, ValueRange newValues);
- /// This method replaces the results of the operation with the specified
- /// new op (replacement). The number of results of the two operations must
- /// match. The replaced op is erased.
+ /// Replace the results of the given (original) operation with the specified
+ /// new op (replacement). The result types of the two ops must match. The
+ /// original op is erased.
virtual void replaceOp(Operation *op, Operation *newOp);
- /// Replaces the result op with a new op that is created without verification.
- /// The result values of the two ops must be the same types.
+ /// Replace the results of the given (original) op with a new op that is
+ /// created without verification (replacement). The result types of the two
+ /// ops must match. The original op is erased.
template <typename OpTy, typename... Args>
OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
@@ -634,9 +611,8 @@ class RewriterBase : public OpBuilder {
finalizeOpModification(root);
}
- /// Find uses of `from` and replace them with `to`. It also marks every
- /// modified uses and notifies the rewriter that an in-place operation
- /// modification is about to happen.
+ /// Find uses of `from` and replace them with `to`. Also notify the listener
+ /// about every in-place op modification (for every use that was replaced).
void replaceAllUsesWith(Value from, Value to) {
return replaceAllUsesWith(from.getImpl(), to);
}
@@ -652,22 +628,43 @@ class RewriterBase : public OpBuilder {
for (auto it : llvm::zip(from, to))
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
}
+ void replaceAllUsesWith(Operation *from, ValueRange to) {
+ replaceAllUsesWith(from->getResults(), to);
+ }
/// Find uses of `from` and replace them with `to` if the `functor` returns
- /// true. It also marks every modified uses and notifies the rewriter that an
- /// in-place operation modification is about to happen.
+ /// true. Also notify the listener about every in-place op modification (for
+ /// every use that was replaced). The optional `allUsesReplaced` flag is set
+ /// to "true" if all uses were replaced.
void replaceUsesWithIf(Value from, Value to,
- function_ref<bool(OpOperand &)> functor);
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced = nullptr);
void replaceUsesWithIf(ValueRange from, ValueRange to,
- function_ref<bool(OpOperand &)> functor) {
- assert(from.size() == to.size() && "incorrect number of replacements");
- for (auto it : llvm::zip(from, to))
- replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced = nullptr);
+ void replaceUsesWithIf(Operation *from, ValueRange to,
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced = nullptr) {
+ replaceUsesWithIf(from->getResults(), to, functor, allUsesReplaced);
+ }
+
+ /// Find uses of the results of `op` nested within the parent op of `block`
+ /// and replace them with `newValues`. Also notify the listener about every
+ /// in-place op modification (for every use that was replaced). The optional
+ /// `allUsesReplaced` flag is set to "true" if all uses were replaced.
+ void replaceUsesWithinBlock(Operation *op, ValueRange newValues, Block *block,
+ bool *allUsesReplaced = nullptr) {
+ replaceUsesWithIf(
+ op, newValues,
+ [block](OpOperand &use) {
+ return block->getParentOp()->isProperAncestor(use.getOwner());
+ },
+ allUsesReplaced);
}
/// Find uses of `from` and replace them with `to` except if the user is
- /// `exceptedUser`. It also marks every modified uses and notifies the
- /// rewriter that an in-place operation modification is about to happen.
+ /// `exceptedUser`. Also notify the listener about every in-place op
+ /// modification (for every use that was replaced).
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) {
return replaceUsesWithIf(from, to, [&](OpOperand &use) {
Operation *user = use.getOwner();
@@ -675,7 +672,7 @@ class RewriterBase : public OpBuilder {
});
}
- /// Used to notify the rewriter that the IR failed to be rewritten because of
+ /// Used to notify the listener that the IR failed to be rewritten because of
/// a match failure, and provide a callback to populate a diagnostic with the
/// reason why the failure occurred. This method allows for derived rewriters
/// to optionally hook into the reason why a rewrite failed, and display it to
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 84396529eb7c2e..01fde101ef3cb6 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -720,12 +720,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// patterns even if a failure is encountered during the rewrite step.
bool canRecoverFromRewriteFailure() const override { return true; }
- /// PatternRewriter hook for replacing an operation when the given functor
- /// returns "true".
- void replaceOpWithIf(
- Operation *op, ValueRange newValues, bool *allUsesReplaced,
- llvm::unique_function<bool(OpOperand &) const> functor) override;
-
/// PatternRewriter hook for replacing an operation.
void replaceOp(Operation *op, ValueRange newValues) override;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
index 5cd6d4597affaf..1658ea67a46077 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -370,8 +370,8 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
scalarReplacements.push_back(
residualGenericOpBody->getArgument(num + origNumInputs));
bool allUsesReplaced = false;
- rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements,
- residualGenericOpBody, &allUsesReplaced);
+ rewriter.replaceUsesWithinBlock(peeledScalarOperation, scalarReplacements,
+ residualGenericOpBody, &allUsesReplaced);
assert(!allUsesReplaced &&
"peeled scalar operation is erased when it wasnt expected to be");
}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 43c408a97687ce..74f6d97aeea53c 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -870,7 +870,7 @@ void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
{getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
Value materialized =
getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
- b.replaceOpWithIf(indexOp, materialized, [&](OpOperand &use) {
+ b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
return use.getOwner() != materialized.getDefiningOp();
});
}
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 8796289d725707..0a88e40f73ec6c 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -110,41 +110,6 @@ RewriterBase::~RewriterBase() {
// Out of line to provide a vtable anchor for the class.
}
-/// This method replaces the uses of the results of `op` with the values in
-/// `newValues` when the provided `functor` returns true for a specific use.
-/// The number of values in `newValues` is required to match the number of
-/// results of `op`.
-void RewriterBase::replaceOpWithIf(
- Operation *op, ValueRange newValues, bool *allUsesReplaced,
- llvm::unique_function<bool(OpOperand &) const> functor) {
- assert(op->getNumResults() == newValues.size() &&
- "incorrect number of values to replace operation");
-
- // Notify the listener that we're about to replace this op.
- if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
- rewriteListener->notifyOperationReplaced(op, newValues);
-
- // Replace each use of the results when the functor is true.
- bool replacedAllUses = true;
- for (auto it : llvm::zip(op->getResults(), newValues)) {
- replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
- replacedAllUses &= std::get<0>(it).use_empty();
- }
- if (allUsesReplaced)
- *allUsesReplaced = replacedAllUses;
-}
-
-/// This method replaces the uses of the results of `op` with the values in
-/// `newValues` when a use is nested within the given `block`. The number of
-/// values in `newValues` is required to match the number of results of `op`.
-/// If all uses of this operation are replaced, the operation is erased.
-void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
- Block *block, bool *allUsesReplaced) {
- replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
- return block->getParentOp()->isProperAncestor(use.getOwner());
- });
-}
-
/// This method replaces the results of the operation with the specified list of
/// values. The number of provided values must match the number of results of
/// the operation. The replaced op is erased.
@@ -156,9 +121,8 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(op, newValues);
- // Replace results one-by-one. Also notifies the listener of modifications.
- for (auto it : llvm::zip(op->getResults(), newValues))
- replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+ // Replace all result uses. Also notifies the listener of modifications.
+ replaceAllUsesWith(op, newValues);
// Erase op and notify listener.
eraseOp(op);
@@ -176,9 +140,8 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyOperationReplaced(op, newOp);
- // Replace results one-by-one. Also notifies the listener of modifications.
- for (auto it : llvm::zip(op->getResults(), newOp->getResults()))
- replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+ // Replace all result uses. Also notifies the listener of modifications.
+ replaceAllUsesWith(op, newOp->getResults());
// Erase op and notify listener.
eraseOp(op);
@@ -279,15 +242,33 @@ void RewriterBase::finalizeOpModification(Operation *op) {
rewriteListener->notifyOperationModified(op);
}
-/// Find uses of `from` and replace them with `to` if the `functor` returns
-/// true. It also marks every modified uses and notifies the rewriter that an
-/// in-place operation modification is about to happen.
void RewriterBase::replaceUsesWithIf(Value from, Value to,
- function_ref<bool(OpOperand &)> functor) {
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced) {
+ bool allReplaced = true;
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
- if (functor(operand))
+ bool replace = functor(operand);
+ if (replace)
modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
+ allReplaced &= replace;
}
+ if (allUsesReplaced)
+ *allUsesReplaced = allReplaced;
+}
+
+void RewriterBase::replaceUsesWithIf(ValueRange from, ValueRange to,
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced) {
+ assert(from.size() == to.size() && "incorrect number of replacements");
+ bool allReplaced = true;
+ for (auto it : llvm::zip_equal(from, to)) {
+ bool r;
+ replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
+ /*allUsesReplaced=*/&r);
+ allReplaced &= r;
+ }
+ if (allUsesReplaced)
+ *allUsesReplaced = allReplaced;
}
void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 4741110bc60682..d7dc902a9a5ebd 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1528,19 +1528,6 @@ ConversionPatternRewriter::ConversionPatternRewriter(
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
-void ConversionPatternRewriter::replaceOpWithIf(
- Operation *op, ValueRange newValues, bool *allUsesReplaced,
- llvm::unique_function<bool(OpOperand &) const> functor) {
- // TODO: To support this we will need to rework a bit of how replacements are
- // tracked, given that this isn't guranteed to replace all of the uses of an
- // operation. The main change is that now an operation can be replaced
- // multiple times, in parts. The current "set" based tracking is mainly useful
- // for tracking if a replaced operation should be ignored, i.e. if all of the
- // uses will be replaced.
- llvm_unreachable(
- "replaceOpWithIf is currently not supported by DialectConversion");
-}
-
void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
assert(op && newOp && "expected non-null op");
replaceOp(op, newOp->getResults());
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index e8b07143fc60bd..eff8acdfb33d20 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -161,7 +161,7 @@ SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
rewriter.setInsertionPointToStart(newEntryBlock);
for (auto *clonedOp : clonedOperations) {
Operation *newOp = rewriter.clone(*clonedOp, map);
- rewriter.replaceOpWithIf(clonedOp, newOp->getResults(), replaceIfFn);
+ rewriter.replaceUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn);
}
rewriter.mergeBlocks(
entryBlock, newEntryBlock,
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index abc0e43c7b7f2d..27eae2ffd694b5 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1836,7 +1836,7 @@ struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
OperandRange operands = op.getOperands();
// Replace non-terminator uses with the first operand.
- rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
+ rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
});
// Replace everything else with the second operand if the operation isn't
More information about the Mlir-commits
mailing list