[Mlir-commits] [mlir] [mlir][Transforms] Deactivate `replaceAllUsesWith` in dialect conversion (PR #154112)
Matthias Springer
llvmlistbot at llvm.org
Mon Aug 18 05:59:49 PDT 2025
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/154112
`RewriterBase` exposes `replaceAllUsesWith` (and its variants), even though these functions are not supported in a dialect conversion. This commit makes them virtual so that `ConversionPatternRewriter` can override and deactivate them: `ConversionPatternRewriter::replaceAllUsesWith` (as well as `replaceUsesWithIf`) now triggers an LLVM fatal error.
`replaceAllUsesWith` is not supported in a dialect conversion because it bypasses the mapping infrastructure and immediately modifies the IR. This can cause subtle crashes like the one described in https://github.com/llvm/llvm-project/pull/154075#issuecomment-3196361718.
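`replaceAllUsesWith` could be supported safely when `allowPatternRollback = false`, but that requires more work. For now, these functions are simply deactivated entirely for safety. This commit also migrates the SCF-to-OpenMP lowering from a dialect conversion (`applyPartialConversion`) to `walkAndApplyPatterns`. That pass now fails if any `scf.parallel`, `scf.reduce`, or `scf.reduce.return` op remains after the patterns have been applied.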
Note for LLVM integration: if this commit breaks your code, consider rewriting the affected patterns without `replaceAllUsesWith`. Alternatively, you can call `value.replaceAllUsesWith` instead of `rewriter.replaceAllUsesWith`, but be aware that this also violates the rewriter API contract.
>From 779412bbe6173f52859f6112657f51e33391bece Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 18 Aug 2025 12:55:10 +0000
Subject: [PATCH] [mlir][Transforms] Deactivate `replaceAllUsesWith` in dialect
conversion
---
mlir/include/mlir/IR/PatternMatch.h | 10 ++++-----
.../mlir/Transforms/DialectConversion.h | 21 +++++++++++++++++++
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 15 ++++++-------
3 files changed, 34 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 57e73c1d8c7c1..b7291653b70bb 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -633,13 +633,13 @@ class RewriterBase : public OpBuilder {
/// Find uses of `from` and replace them with `to`. Also notify the listener
/// about every in-place op modification (for every use that was replaced).
- void replaceAllUsesWith(Value from, Value to) {
+ virtual void replaceAllUsesWith(Value from, Value to) {
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
Operation *op = operand.getOwner();
modifyOpInPlace(op, [&]() { operand.set(to); });
}
}
- void replaceAllUsesWith(Block *from, Block *to) {
+ virtual void replaceAllUsesWith(Block *from, Block *to) {
for (BlockOperand &operand : llvm::make_early_inc_range(from->getUses())) {
Operation *op = operand.getOwner();
modifyOpInPlace(op, [&]() { operand.set(to); });
@@ -665,9 +665,9 @@ class RewriterBase : public OpBuilder {
/// true. Also notify the listener about every in-place op modification (for
/// every use that was replaced). The optional `allUsesReplaced` flag is set
/// to "true" if all uses were replaced.
- void replaceUsesWithIf(Value from, Value to,
- function_ref<bool(OpOperand &)> functor,
- bool *allUsesReplaced = nullptr);
+ virtual void replaceUsesWithIf(Value from, Value to,
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced = nullptr);
void replaceUsesWithIf(ValueRange from, ValueRange to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 220431e6ee2f1..9341da19905ab 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -784,6 +784,27 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// function supports both 1:1 and 1:N replacements.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
+ /// Replace all the uses of the value `from` with `to`.
+ /// TODO: Currently not supported in a dialect conversion.
+ void replaceAllUsesWith(Value from, Value to) override {
+ llvm::report_fatal_error("replaceAllUsesWith is not supported yet");
+ }
+
+ /// Replace all the uses of the block `from` with `to`.
+ /// TODO: Currently not supported in a dialect conversion.
+ void replaceAllUsesWith(Block *from, Block *to) override {
+ llvm::report_fatal_error("replaceAllUsesWith is not supported yet");
+ }
+
+ /// Replace all the uses of the value `from` with `to` if the `functor`
+ /// returns "true".
+ /// TODO: Currently not supported in a dialect conversion.
+ void replaceUsesWithIf(Value from, Value to,
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced = nullptr) override {
+ llvm::report_fatal_error("replaceUsesWithIf is not supported yet");
+ }
+
/// Return the converted value of 'key' with a type defined by the type
/// converter of the currently executing pattern. Return nullptr in the case
/// of failure, the remapped value otherwise.
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 34f372af1e4b5..c903016611422 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -22,7 +22,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
@@ -538,15 +538,16 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/// Applies the conversion patterns in the given function.
static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
- ConversionTarget target(*module.getContext());
- target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
- target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
- memref::MemRefDialect>();
-
RewritePatternSet patterns(module.getContext());
patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
FrozenRewritePatternSet frozen(std::move(patterns));
- return applyPartialConversion(module, target, frozen);
+ walkAndApplyPatterns(module, frozen);
+ auto status = module.walk([](Operation *op) {
+ if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op))
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ });
+ return failure(status.wasInterrupted());
}
/// A pass converting SCF operations to OpenMP operations.
More information about the Mlir-commits
mailing list