[llvm-branch-commits] [mlir] [mlir][Transforms] Remove `replaceAllUsesWith` workaround (PR #169609)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Nov 25 22:04:40 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Remove a workaround from the implementation of `replaceAllUsesWith` in the no-rollback dialect conversion driver. The workaround (skipping replacements that would cause a dominance violation) was necessary for `restoreByValRefArgumentType` in the `func-to-llvm` lowering because there was no support for `replaceAllUsesExcept`. The no-rollback driver now supports this API, so `restoreByValRefArgumentType` uses `replaceAllUsesExcept` there and the workaround can be dropped from that driver. The workaround is still in place for the rollback driver.
Depends on #169606.
---
Full diff: https://github.com/llvm/llvm-project/pull/169609.diff
4 Files Affected:
- (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+10-2)
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+12-21)
- (modified) mlir/test/Transforms/test-convert-func-op.mlir (+2-1)
- (modified) mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp (+10-1)
``````````diff
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 2220f61ed8a07..ddd94f5d03042 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -283,8 +283,16 @@ static void restoreByValRefArgumentType(
Type resTy = typeConverter.convertType(
cast<TypeAttr>(byValRefAttr->getValue()).getValue());
- Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
- rewriter.replaceAllUsesWith(arg, valueArg);
+ auto loadOp = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
+ if (!rewriter.getConfig().allowPatternRollback) {
+ rewriter.replaceAllUsesExcept(arg, loadOp, loadOp);
+ } else {
+ // replaceAllUsesExcept is not supported in rollback mode. The rollback
+ // mode implementation has a workaround: certain replacements that would
+ // cause a dominance violation are skipped.
+ // TODO: Remove workaround.
+ rewriter.replaceAllUsesWith(arg, loadOp);
+ }
}
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index c9f1596c07cbe..ccc5b7cb6f229 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1205,17 +1205,14 @@ void BlockTypeConversionRewrite::rollback() {
getNewBlock()->replaceAllUsesWith(getOrigBlock());
}
-/// Replace all uses of `from` with `repl`.
-static void
-performReplaceValue(RewriterBase &rewriter, Value from, Value repl,
- function_ref<bool(OpOperand &)> functor = nullptr) {
+void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
+ Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
+ if (!repl)
+ return;
+
if (isa<BlockArgument>(repl)) {
// `repl` is a block argument. Directly replace all uses.
- if (functor) {
- rewriter.replaceUsesWithIf(from, repl, functor);
- } else {
- rewriter.replaceAllUsesWith(from, repl);
- }
+ rewriter.replaceAllUsesWith(value, repl);
return;
}
@@ -1244,23 +1241,14 @@ performReplaceValue(RewriterBase &rewriter, Value from, Value repl,
// `ConversionPatternRewriter` API with the normal `RewriterBase` API.
Operation *replOp = repl.getDefiningOp();
Block *replBlock = replOp->getBlock();
- rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
+ rewriter.replaceUsesWithIf(value, repl, [&](OpOperand &operand) {
Operation *user = operand.getOwner();
bool result =
user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
- if (functor)
- result &= functor(operand);
return result;
});
}
-void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
- Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
- if (!repl)
- return;
- performReplaceValue(rewriter, value, repl);
-}
-
void ReplaceValueRewrite::rollback() {
rewriterImpl.mapping.erase({value});
#ifndef NDEBUG
@@ -2000,8 +1988,11 @@ void ConversionPatternRewriterImpl::replaceValueUses(
Value repl = repls.front();
if (!repl)
return;
-
- performReplaceValue(r, from, repl, functor);
+ if (functor) {
+ r.replaceUsesWithIf(from, repl, functor);
+ } else {
+ r.replaceAllUsesWith(from, repl);
+ }
return;
}
diff --git a/mlir/test/Transforms/test-convert-func-op.mlir b/mlir/test/Transforms/test-convert-func-op.mlir
index 180f16a32991b..14c15ecbe77f0 100644
--- a/mlir/test/Transforms/test-convert-func-op.mlir
+++ b/mlir/test/Transforms/test-convert-func-op.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -test-convert-func-op --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op="allow-pattern-rollback=1" --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op="allow-pattern-rollback=0" --split-input-file | FileCheck %s
// CHECK-LABEL: llvm.func @add
func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface } {
diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
index 75168dde93130..897b11b65b6f2 100644
--- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
+++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
@@ -68,6 +68,9 @@ struct TestConvertFuncOp
: public PassWrapper<TestConvertFuncOp, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp)
+ TestConvertFuncOp() = default;
+ TestConvertFuncOp(const TestConvertFuncOp &other) : PassWrapper(other) {}
+
void getDependentDialects(DialectRegistry ®istry) const final {
registry.insert<LLVM::LLVMDialect>();
}
@@ -92,10 +95,16 @@ struct TestConvertFuncOp
patterns.add<ReturnOpConversion>(typeConverter);
LLVMConversionTarget target(getContext());
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
+ std::move(patterns), config)))
signalPassFailure();
}
+
+ Option<bool> allowPatternRollback{*this, "allow-pattern-rollback",
+ llvm::cl::desc("Allow pattern rollback"),
+ llvm::cl::init(true)};
};
} // namespace
``````````
</details>
https://github.com/llvm/llvm-project/pull/169609
More information about the llvm-branch-commits
mailing list