[llvm-branch-commits] [mlir] [mlir][Transforms] Add support for `ConversionPatternRewriter::replaceAllUsesWith` (PR #155244)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Aug 25 06:30:09 PDT 2025
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/155244
Depends on #155242.
>From b217bce2ba7ecaf94d1e6364cac7b75f4ffb3f41 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sat, 23 Aug 2025 10:36:37 +0000
Subject: [PATCH] [mlir][Transforms] Add support for
`ConversionPatternRewriter::replaceAllUsesWith`
---
mlir/include/mlir/IR/PatternMatch.h | 2 +-
.../mlir/Transforms/DialectConversion.h | 17 +-
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 2 +-
.../Transforms/Utils/DialectConversion.cpp | 158 +++++++++++-------
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 5 +-
5 files changed, 112 insertions(+), 72 deletions(-)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 57e73c1d8c7c1..7b0b9cef9c5bd 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -633,7 +633,7 @@ class RewriterBase : public OpBuilder {
/// Find uses of `from` and replace them with `to`. Also notify the listener
/// about every in-place op modification (for every use that was replaced).
- void replaceAllUsesWith(Value from, Value to) {
+ virtual void replaceAllUsesWith(Value from, Value to) {
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
Operation *op = operand.getOwner();
modifyOpInPlace(op, [&]() { operand.set(to); });
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f23a70601fc0a..ffad78db3ca87 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -780,15 +780,18 @@ class ConversionPatternRewriter final : public PatternRewriter {
Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion = nullptr);
- /// Replace all the uses of the block argument `from` with `to`. This
- /// function supports both 1:1 and 1:N replacements.
+ /// Replace all the uses of `from` with `to`. This function supports both 1:1
+ /// and 1:N replacements.
///
/// Note: If `allowPatternRollback` is set to "true", this function replaces
- /// all current and future uses of the block argument. This same block
- /// block argument must not be replaced multiple times. Uses are not replaced
- /// immediately but in a delayed fashion. Patterns may still see the original
- /// uses when inspecting IR.
- void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
+ /// all current and future uses of the `from` value. The same value must not
+ /// be replaced multiple times. Uses are not replaced immediately but in a
+ /// delayed fashion. Patterns may still see the original uses when inspecting
+ /// IR.
+ void replaceAllUsesWith(Value from, ValueRange to);
+ void replaceAllUsesWith(Value from, Value to) override {
+ replaceAllUsesWith(from, ValueRange{to});
+ }
/// Return the converted value of 'key' with a type defined by the type
/// converter of the currently executing pattern. Return nullptr in the case
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 42c76ed475b4c..93fe2edad5274 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -284,7 +284,7 @@ static void restoreByValRefArgumentType(
cast<TypeAttr>(byValRefAttr->getValue()).getValue());
Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
- rewriter.replaceUsesOfBlockArgument(arg, valueArg);
+ rewriter.replaceAllUsesWith(arg, valueArg);
}
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index e3248204d6694..ce8e314ed6f7b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -277,13 +277,14 @@ class IRRewrite {
InlineBlock,
MoveBlock,
BlockTypeConversion,
- ReplaceBlockArg,
// Operation rewrites
MoveOperation,
ModifyOperation,
ReplaceOperation,
CreateOperation,
- UnresolvedMaterialization
+ UnresolvedMaterialization,
+ // Value rewrites
+ ReplaceValue
};
virtual ~IRRewrite() = default;
@@ -330,7 +331,7 @@ class BlockRewrite : public IRRewrite {
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::CreateBlock &&
- rewrite->getKind() <= Kind::ReplaceBlockArg;
+ rewrite->getKind() <= Kind::BlockTypeConversion;
}
protected:
@@ -342,6 +343,25 @@ class BlockRewrite : public IRRewrite {
Block *block;
};
+/// Base class for value rewrites, i.e., rewrites that operate on a `Value`.
+class ValueRewrite : public IRRewrite {
+public:
+ /// Return the value that this rewrite operates on.
+ Value getValue() const { return value; }
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::ReplaceValue;
+ }
+
+protected:
+ ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+ Value value)
+ : IRRewrite(kind, rewriterImpl), value(value) {}
+
+ /// The value that this rewrite operates on.
+ Value value;
+};
+
/// Creation of a block. Block creations are immediately reflected in the IR.
/// There is no extra work to commit the rewrite. During rollback, the newly
/// created block is erased.
@@ -548,19 +568,18 @@ class BlockTypeConversionRewrite : public BlockRewrite {
Block *newBlock;
};
-/// Replacing a block argument. This rewrite is not immediately reflected in the
+/// Replacing a value. This rewrite is not immediately reflected in the
/// IR. An internal IR mapping is updated, but the actual replacement is delayed
/// until the rewrite is committed.
-class ReplaceBlockArgRewrite : public BlockRewrite {
+class ReplaceValueRewrite : public ValueRewrite {
public:
- ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- Block *block, BlockArgument arg,
- const TypeConverter *converter)
- : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
+ ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
+ const TypeConverter *converter)
+ : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
converter(converter) {}
static bool classof(const IRRewrite *rewrite) {
- return rewrite->getKind() == Kind::ReplaceBlockArg;
+ return rewrite->getKind() == Kind::ReplaceValue;
}
void commit(RewriterBase &rewriter) override;
@@ -568,9 +587,7 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
void rollback() override;
private:
- BlockArgument arg;
-
- /// The current type converter when the block argument was replaced.
+ /// The current type converter when the value was replaced.
const TypeConverter *converter;
};
@@ -942,10 +959,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// uses.
void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
- /// Replace the given block argument with the given values. The specified
+ /// Replace the uses of the given value with the given values. The specified
/// converter is used to build materializations (if necessary).
- void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
- const TypeConverter *converter);
+ void replaceAllUsesWith(Value from, ValueRange to,
+ const TypeConverter *converter);
/// Erase the given block and its contents.
void eraseBlock(Block *block);
@@ -1132,10 +1149,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
IRRewriter notifyingRewriter;
#ifndef NDEBUG
- /// A set of replaced block arguments. This set is for debugging purposes
- /// only and it is maintained only if `allowPatternRollback` is set to
- /// "true".
- DenseSet<BlockArgument> replacedArgs;
+ /// A set of replaced values. This set is for debugging purposes only and it
+ /// is maintained only if `allowPatternRollback` is set to "true".
+ DenseSet<Value> replacedValues;
/// A set of operations that have pending updates. This tracking isn't
/// strictly necessary, and is thus only active during debug builds for extra
@@ -1172,32 +1188,54 @@ void BlockTypeConversionRewrite::rollback() {
getNewBlock()->replaceAllUsesWith(getOrigBlock());
}
-static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg,
- Value repl) {
+/// Replace all uses of `from` with `repl`.
+static void performReplaceValue(RewriterBase &rewriter, Value from,
+ Value repl) {
if (isa<BlockArgument>(repl)) {
- rewriter.replaceAllUsesWith(arg, repl);
+ // `repl` is a block argument. Directly replace all uses.
+ rewriter.replaceAllUsesWith(from, repl);
return;
}
- // If the replacement value is an operation, we check to make sure that we
- // don't replace uses that are within the parent operation of the
- // replacement value.
- Operation *replOp = cast<OpResult>(repl).getOwner();
+ // If the replacement value is an op result, only replace those uses that:
+ // - are in a different block than the replacement operation, or
+ // - are in the same block but after the replacement operation.
+ //
+ // Example:
+ // ^bb0(%arg0: i32):
+ // %0 = "consumer"(%arg0) : (i32) -> (i32)
+ // "another_consumer"(%arg0) : (i32) -> ()
+ //
+ // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
+ // use in "another_consumer" but not the use in "consumer". When using the
+ // normal RewriterBase API, this would typically be done with
+ // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
+ // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
+ // it cannot be supported efficiently with `allowPatternRollback` set to
+ // "true". Therefore, the conversion driver tries to be smart and replaces
+ // only those uses that do not lead to a dominance violation. E.g., the
+ // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
+ // behavior.
+ //
+ // TODO: As we move more and more towards `allowPatternRollback` set to
+ // "false", we should remove this special handling, in order to align the
+ // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
+ Operation *replOp = repl.getDefiningOp();
Block *replBlock = replOp->getBlock();
- rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
+ rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
Operation *user = operand.getOwner();
return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
});
}
-void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
- Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
+ Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
if (!repl)
return;
- performReplaceBlockArg(rewriter, arg, repl);
+ performReplaceValue(rewriter, value, repl);
}
-void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
+void ReplaceValueRewrite::rollback() { rewriterImpl.mapping.erase({value}); }
void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
auto *listener =
@@ -1590,7 +1628,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
/*castOp=*/nullptr, /*isPureTypeConversion=*/false)
.front();
- replaceUsesOfBlockArgument(origArg, mat, converter);
+ replaceAllUsesWith(origArg, mat, converter);
continue;
}
@@ -1599,15 +1637,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
- replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
- converter);
+ replaceAllUsesWith(origArg, inputMap->replacementValues, converter);
continue;
}
// This is a 1->1+ mapping.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
- replaceUsesOfBlockArgument(origArg, replArgs, converter);
+ replaceAllUsesWith(origArg, replArgs, converter);
}
if (config.allowPatternRollback)
@@ -1882,8 +1919,8 @@ void ConversionPatternRewriterImpl::replaceOp(
op->walk([&](Operation *op) { replacedOps.insert(op); });
}
-void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
- BlockArgument from, ValueRange to, const TypeConverter *converter) {
+void ConversionPatternRewriterImpl::replaceAllUsesWith(
+ Value from, ValueRange to, const TypeConverter *converter) {
if (!config.allowPatternRollback) {
SmallVector<Value> toConv = llvm::to_vector(to);
SmallVector<Value> repls =
@@ -1893,25 +1930,25 @@ void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
if (!repl)
return;
- performReplaceBlockArg(r, from, repl);
+ performReplaceValue(r, from, repl);
return;
}
#ifndef NDEBUG
- // Make sure that a block argument is not replaced multiple times. In
- // rollback mode, `replaceUsesOfBlockArgument` replaces not only all current
- // uses of the given block argument, but also all future uses that may be
- // introduced by future pattern applications. Therefore, it does not make
- // sense to call `replaceUsesOfBlockArgument` multiple times with the same
- // block argument. Doing so would overwrite the mapping and mess with the
- // internal state of the dialect conversion driver.
- assert(!replacedArgs.contains(from) &&
- "attempting to replace a block argument that was already replaced");
- replacedArgs.insert(from);
+ // Make sure that a value is not replaced multiple times. In rollback mode,
+ // `replaceAllUsesWith` replaces not only all current uses of the given value,
+ // but also all future uses that may be introduced by future pattern
+ // applications. Therefore, it does not make sense to call
+ // `replaceAllUsesWith` multiple times with the same value. Doing so would
+ // overwrite the mapping and mess with the internal state of the dialect
+ // conversion driver.
+ assert(!replacedValues.contains(from) &&
+ "attempting to replace a value that was already replaced");
+ replacedValues.insert(from);
#endif // NDEBUG
- appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
mapping.map(from, to);
+ appendRewrite<ReplaceValueRewrite>(from, converter);
}
void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
@@ -2116,18 +2153,19 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
return impl->convertRegionTypes(*this, region, converter, entryConversion);
}
-void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
- ValueRange to) {
+void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
LLVM_DEBUG({
- impl->logger.startLine() << "** Replace Argument : '" << from << "'";
- if (Operation *parentOp = from.getOwner()->getParentOp()) {
- impl->logger.getOStream() << " (in region of '" << parentOp->getName()
- << "' (" << parentOp << ")\n";
- } else {
- impl->logger.getOStream() << " (unlinked block)\n";
+ impl->logger.startLine() << "** Replace Value : '" << from << "'";
+ if (auto blockArg = dyn_cast<BlockArgument>(from)) {
+ if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
+ impl->logger.getOStream() << " (in region of '" << parentOp->getName()
+ << "' (" << parentOp << ")\n";
+ } else {
+ impl->logger.getOStream() << " (unlinked block)\n";
+ }
}
});
- impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
+ impl->replaceAllUsesWith(from, to, impl->currentTypeConverter);
}
Value ConversionPatternRewriter::getRemappedValue(Value key) {
@@ -2185,7 +2223,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
// Replace all uses of block arguments.
for (auto it : llvm::zip(source->getArguments(), argValues))
- replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
+ replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
if (fastPath) {
// Move all ops at once.
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index b6f16ac1b5c48..e0a004b706be4 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -951,7 +951,7 @@ struct TestCreateIllegalBlock : public RewritePattern {
}
};
-/// A simple pattern that tests the "replaceUsesOfBlockArgument" API.
+/// A simple pattern that tests the "replaceAllUsesWith" API.
struct TestBlockArgReplace : public ConversionPattern {
TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter)
: ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1,
@@ -962,8 +962,7 @@ struct TestBlockArgReplace : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
// Replace the first block argument with 2x the second block argument.
Value repl = op->getRegion(0).getArgument(1);
- rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
- {repl, repl});
+ rewriter.replaceAllUsesWith(op->getRegion(0).getArgument(0), {repl, repl});
rewriter.modifyOpInPlace(op, [&] {
// If the "trigger_rollback" attribute is set, keep the op illegal, so
// that a rollback is triggered.
More information about the llvm-branch-commits
mailing list