[Mlir-commits] [mlir] [mlir][Transforms] Deactivate `replaceAllUsesWith` in dialect conversion (PR #154112)

Matthias Springer llvmlistbot at llvm.org
Mon Aug 18 05:59:49 PDT 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/154112

`RewriterBase` exposes `replaceAllUsesWith` (and its variants), even though these functions are not supported in a dialect conversion. This commit makes them virtual so that `ConversionPatternRewriter` can deactivate them: calling `ConversionPatternRewriter::replaceAllUsesWith` (or `replaceUsesWithIf`) now triggers an LLVM fatal error. In addition, the SCF-to-OpenMP lowering now uses `walkAndApplyPatterns` instead of `applyPartialConversion`.

`replaceAllUsesWith` is not supported in a dialect conversion because it bypasses the mapping infrastructure and immediately modifies the IR. This can cause subtle crashes like the one described in https://github.com/llvm/llvm-project/pull/154075#issuecomment-3196361718.

`replaceAllUsesWith` could be supported safely when `allowPatternRollback = false`, but that requires more work. For now, these functions are deactivated entirely for safety.

Note for LLVM integration: If this commit breaks your code, consider rewriting the affected patterns without `replaceAllUsesWith`. Alternatively, you can call `value.replaceAllUsesWith` instead of `rewriter.replaceAllUsesWith`, but be aware that doing so is also an API violation.


>From 779412bbe6173f52859f6112657f51e33391bece Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 18 Aug 2025 12:55:10 +0000
Subject: [PATCH] [mlir][Transforms] Deactivate `replaceAllUsesWith` in dialect
 conversion

---
 mlir/include/mlir/IR/PatternMatch.h           | 10 ++++-----
 .../mlir/Transforms/DialectConversion.h       | 21 +++++++++++++++++++
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    | 15 ++++++-------
 3 files changed, 34 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 57e73c1d8c7c1..b7291653b70bb 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -633,13 +633,13 @@ class RewriterBase : public OpBuilder {
 
   /// Find uses of `from` and replace them with `to`. Also notify the listener
   /// about every in-place op modification (for every use that was replaced).
-  void replaceAllUsesWith(Value from, Value to) {
+  virtual void replaceAllUsesWith(Value from, Value to) {
     for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
       Operation *op = operand.getOwner();
       modifyOpInPlace(op, [&]() { operand.set(to); });
     }
   }
-  void replaceAllUsesWith(Block *from, Block *to) {
+  virtual void replaceAllUsesWith(Block *from, Block *to) {
     for (BlockOperand &operand : llvm::make_early_inc_range(from->getUses())) {
       Operation *op = operand.getOwner();
       modifyOpInPlace(op, [&]() { operand.set(to); });
@@ -665,9 +665,9 @@ class RewriterBase : public OpBuilder {
   /// true. Also notify the listener about every in-place op modification (for
   /// every use that was replaced). The optional `allUsesReplaced` flag is set
   /// to "true" if all uses were replaced.
-  void replaceUsesWithIf(Value from, Value to,
-                         function_ref<bool(OpOperand &)> functor,
-                         bool *allUsesReplaced = nullptr);
+  virtual void replaceUsesWithIf(Value from, Value to,
+                                 function_ref<bool(OpOperand &)> functor,
+                                 bool *allUsesReplaced = nullptr);
   void replaceUsesWithIf(ValueRange from, ValueRange to,
                          function_ref<bool(OpOperand &)> functor,
                          bool *allUsesReplaced = nullptr);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 220431e6ee2f1..9341da19905ab 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -784,6 +784,33 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// function supports both 1:1 and 1:N replacements.
   void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
 
+  // Without these using-declarations, the overrides below would hide the
+  // other base-class overloads of the same names (such as the `ValueRange`
+  // variant of `replaceUsesWithIf`) when called on a ConversionPatternRewriter.
+  using RewriterBase::replaceAllUsesWith;
+  using RewriterBase::replaceUsesWithIf;
+
+  /// Replace all the uses of the value `from` with `to`.
+  /// TODO: Currently not supported in a dialect conversion.
+  void replaceAllUsesWith(Value from, Value to) override {
+    llvm::report_fatal_error("replaceAllUsesWith is not supported yet");
+  }
+
+  /// Replace all the uses of the block `from` with `to`.
+  /// TODO: Currently not supported in a dialect conversion.
+  void replaceAllUsesWith(Block *from, Block *to) override {
+    llvm::report_fatal_error("replaceAllUsesWith is not supported yet");
+  }
+
+  /// Replace all the uses of the value `from` with `to` if the `functor`
+  /// returns "true".
+  /// TODO: Currently not supported in a dialect conversion.
+  void replaceUsesWithIf(Value from, Value to,
+                         function_ref<bool(OpOperand &)> functor,
+                         bool *allUsesReplaced = nullptr) override {
+    llvm::report_fatal_error("replaceUsesWithIf is not supported yet");
+  }
+
   /// Return the converted value of 'key' with a type defined by the type
   /// converter of the currently executing pattern. Return nullptr in the case
   /// of failure, the remapped value otherwise.
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 34f372af1e4b5..c903016611422 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -22,7 +22,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
@@ -538,15 +538,21 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
 
 /// Applies the conversion patterns in the given function.
 static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
-  ConversionTarget target(*module.getContext());
-  target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
-  target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
-                         memref::MemRefDialect>();
-
   RewritePatternSet patterns(module.getContext());
   patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
   FrozenRewritePatternSet frozen(std::move(patterns));
-  return applyPartialConversion(module, target, frozen);
+  walkAndApplyPatterns(module, frozen);
+  // Any of these ops that is still present could not be lowered. Emit a
+  // diagnostic so that the pass does not fail silently.
+  auto status = module.walk<WalkOrder::PreOrder>([](Operation *op) {
+    if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op)) {
+      op->emitError() << "failed to legalize operation '" << op->getName()
+                      << "'";
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  return failure(status.wasInterrupted());
 }

 /// A pass converting SCF operations to OpenMP operations.



More information about the Mlir-commits mailing list