[llvm-branch-commits] [mlir] [mlir][Transforms] Remove `replaceAllUsesWith` workaround (PR #169609)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Nov 25 22:04:10 PST 2025
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/169609
Remove a workaround from the implementation of `replaceAllUsesWith` in the no-rollback dialect conversion driver. The workaround was needed by `restoreByValRefArgumentType` in the `func-to-llvm` lowering because the no-rollback driver did not support `replaceAllUsesExcept`. Support for this API has since been added to the no-rollback driver, so the workaround can be dropped from that driver. The workaround remains in place for the rollback driver.
Depends on #169606.
From cf32ec20e509cbe10ead86e8911c9c476b96a9c6 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 26 Nov 2025 06:01:37 +0000
Subject: [PATCH] [mlir][Transforms] Remove `replaceAllUsesWith` workaround
---
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 12 +++++--
.../Transforms/Utils/DialectConversion.cpp | 33 +++++++------------
.../test/Transforms/test-convert-func-op.mlir | 3 +-
.../FuncToLLVM/TestConvertFuncOp.cpp | 11 ++++++-
4 files changed, 34 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 2220f61ed8a07..ddd94f5d03042 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -283,8 +283,16 @@ static void restoreByValRefArgumentType(
Type resTy = typeConverter.convertType(
cast<TypeAttr>(byValRefAttr->getValue()).getValue());
- Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
- rewriter.replaceAllUsesWith(arg, valueArg);
+ auto loadOp = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
+ if (!rewriter.getConfig().allowPatternRollback) {
+ rewriter.replaceAllUsesExcept(arg, loadOp, loadOp);
+ } else {
+ // replaceAllUsesExcept is not supported in rollback mode. The rollback
+ // mode implementation has a workaround: certain replacements that would
+ // cause a dominance violation are skipped.
+ // TODO: Remove workaround.
+ rewriter.replaceAllUsesWith(arg, loadOp);
+ }
}
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index c9f1596c07cbe..ccc5b7cb6f229 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1205,17 +1205,14 @@ void BlockTypeConversionRewrite::rollback() {
getNewBlock()->replaceAllUsesWith(getOrigBlock());
}
-/// Replace all uses of `from` with `repl`.
-static void
-performReplaceValue(RewriterBase &rewriter, Value from, Value repl,
- function_ref<bool(OpOperand &)> functor = nullptr) {
+void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
+ Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
+ if (!repl)
+ return;
+
if (isa<BlockArgument>(repl)) {
// `repl` is a block argument. Directly replace all uses.
- if (functor) {
- rewriter.replaceUsesWithIf(from, repl, functor);
- } else {
- rewriter.replaceAllUsesWith(from, repl);
- }
+ rewriter.replaceAllUsesWith(value, repl);
return;
}
@@ -1244,23 +1241,14 @@ performReplaceValue(RewriterBase &rewriter, Value from, Value repl,
// `ConversionPatternRewriter` API with the normal `RewriterBase` API.
Operation *replOp = repl.getDefiningOp();
Block *replBlock = replOp->getBlock();
- rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
+ rewriter.replaceUsesWithIf(value, repl, [&](OpOperand &operand) {
Operation *user = operand.getOwner();
bool result =
user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
- if (functor)
- result &= functor(operand);
return result;
});
}
-void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
- Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
- if (!repl)
- return;
- performReplaceValue(rewriter, value, repl);
-}
-
void ReplaceValueRewrite::rollback() {
rewriterImpl.mapping.erase({value});
#ifndef NDEBUG
@@ -2000,8 +1988,11 @@ void ConversionPatternRewriterImpl::replaceValueUses(
Value repl = repls.front();
if (!repl)
return;
-
- performReplaceValue(r, from, repl, functor);
+ if (functor) {
+ r.replaceUsesWithIf(from, repl, functor);
+ } else {
+ r.replaceAllUsesWith(from, repl);
+ }
return;
}
diff --git a/mlir/test/Transforms/test-convert-func-op.mlir b/mlir/test/Transforms/test-convert-func-op.mlir
index 180f16a32991b..14c15ecbe77f0 100644
--- a/mlir/test/Transforms/test-convert-func-op.mlir
+++ b/mlir/test/Transforms/test-convert-func-op.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -test-convert-func-op --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op="allow-pattern-rollback=1" --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op="allow-pattern-rollback=0" --split-input-file | FileCheck %s
// CHECK-LABEL: llvm.func @add
func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface } {
diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
index 75168dde93130..897b11b65b6f2 100644
--- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
+++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
@@ -68,6 +68,9 @@ struct TestConvertFuncOp
: public PassWrapper<TestConvertFuncOp, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp)
+ TestConvertFuncOp() = default;
+ TestConvertFuncOp(const TestConvertFuncOp &other) : PassWrapper(other) {}
+
void getDependentDialects(DialectRegistry &registry) const final {
registry.insert<LLVM::LLVMDialect>();
}
@@ -92,10 +95,16 @@ struct TestConvertFuncOp
patterns.add<ReturnOpConversion>(typeConverter);
LLVMConversionTarget target(getContext());
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
+ std::move(patterns), config)))
signalPassFailure();
}
+
+ Option<bool> allowPatternRollback{*this, "allow-pattern-rollback",
+ llvm::cl::desc("Allow pattern rollback"),
+ llvm::cl::init(true)};
};
} // namespace
More information about the llvm-branch-commits
mailing list