[Mlir-commits] [mlir] [mlir][Transforms] Remove `replaceAllUsesWith` workaround (PR #169609)
Matthias Springer
llvmlistbot at llvm.org
Sun Dec 7 00:49:56 PST 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/169609
>From ad31a2510ca570964dd2a1af1867d01ce1cfe374 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 26 Nov 2025 06:01:37 +0000
Subject: [PATCH] [mlir][Transforms] Remove `replaceAllUsesWith` workaround
---
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 12 +++++--
.../Transforms/Utils/DialectConversion.cpp | 33 +++++++------------
.../test/Transforms/test-convert-func-op.mlir | 3 +-
.../FuncToLLVM/TestConvertFuncOp.cpp | 11 ++++++-
4 files changed, 34 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 2220f61ed8a07..ddd94f5d03042 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -283,8 +283,16 @@ static void restoreByValRefArgumentType(
Type resTy = typeConverter.convertType(
cast<TypeAttr>(byValRefAttr->getValue()).getValue());
- Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
- rewriter.replaceAllUsesWith(arg, valueArg);
+ auto loadOp = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
+ if (!rewriter.getConfig().allowPatternRollback) {
+ rewriter.replaceAllUsesExcept(arg, loadOp, loadOp);
+ } else {
+ // replaceAllUsesExcept is not supported in rollback mode. The rollback
+ // mode implementation has a workaround: certain replacements that would
+ // cause a dominance violation are skipped.
+ // TODO: Remove workaround.
+ rewriter.replaceAllUsesWith(arg, loadOp);
+ }
}
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 09ad42364baaf..dd9411a1399f4 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1205,17 +1205,14 @@ void BlockTypeConversionRewrite::rollback() {
getNewBlock()->replaceAllUsesWith(getOrigBlock());
}
-/// Replace all uses of `from` with `repl`.
-static void
-performReplaceValue(RewriterBase &rewriter, Value from, Value repl,
- function_ref<bool(OpOperand &)> functor = nullptr) {
+void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
+ Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
+ if (!repl)
+ return;
+
if (isa<BlockArgument>(repl)) {
// `repl` is a block argument. Directly replace all uses.
- if (functor) {
- rewriter.replaceUsesWithIf(from, repl, functor);
- } else {
- rewriter.replaceAllUsesWith(from, repl);
- }
+ rewriter.replaceAllUsesWith(value, repl);
return;
}
@@ -1244,23 +1241,14 @@ performReplaceValue(RewriterBase &rewriter, Value from, Value repl,
// `ConversionPatternRewriter` API with the normal `RewriterBase` API.
Operation *replOp = repl.getDefiningOp();
Block *replBlock = replOp->getBlock();
- rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
+ rewriter.replaceUsesWithIf(value, repl, [&](OpOperand &operand) {
Operation *user = operand.getOwner();
bool result =
user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
- if (result && functor)
- result &= functor(operand);
return result;
});
}
-void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
- Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
- if (!repl)
- return;
- performReplaceValue(rewriter, value, repl);
-}
-
void ReplaceValueRewrite::rollback() {
rewriterImpl.mapping.erase({value});
#ifndef NDEBUG
@@ -2000,8 +1988,11 @@ void ConversionPatternRewriterImpl::replaceValueUses(
Value repl = repls.front();
if (!repl)
return;
-
- performReplaceValue(r, from, repl, functor);
+ if (functor) {
+ r.replaceUsesWithIf(from, repl, functor);
+ } else {
+ r.replaceAllUsesWith(from, repl);
+ }
return;
}
diff --git a/mlir/test/Transforms/test-convert-func-op.mlir b/mlir/test/Transforms/test-convert-func-op.mlir
index 180f16a32991b..14c15ecbe77f0 100644
--- a/mlir/test/Transforms/test-convert-func-op.mlir
+++ b/mlir/test/Transforms/test-convert-func-op.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -test-convert-func-op --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op="allow-pattern-rollback=1" --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op="allow-pattern-rollback=0" --split-input-file | FileCheck %s
// CHECK-LABEL: llvm.func @add
func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface } {
diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
index 75168dde93130..897b11b65b6f2 100644
--- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
+++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
@@ -68,6 +68,9 @@ struct TestConvertFuncOp
: public PassWrapper<TestConvertFuncOp, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp)
+ TestConvertFuncOp() = default;
+ TestConvertFuncOp(const TestConvertFuncOp &other) : PassWrapper(other) {}
+
void getDependentDialects(DialectRegistry &registry) const final {
registry.insert<LLVM::LLVMDialect>();
}
@@ -92,10 +95,16 @@ struct TestConvertFuncOp
patterns.add<ReturnOpConversion>(typeConverter);
LLVMConversionTarget target(getContext());
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
+ std::move(patterns), config)))
signalPassFailure();
}
+
+ Option<bool> allowPatternRollback{*this, "allow-pattern-rollback",
+ llvm::cl::desc("Allow pattern rollback"),
+ llvm::cl::init(true)};
};
} // namespace
More information about the Mlir-commits
mailing list