[Mlir-commits] [flang] [mlir] [mlir][Transforms] Add support for `ConversionPatternRewriter::replaceAllUsesWith` (PR #155244)

Matthias Springer llvmlistbot at llvm.org
Sat Aug 30 09:22:30 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/155244

>From 97393ddd20863052a40c1e2e4d7029e2f4f6156b Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sat, 30 Aug 2025 09:39:31 +0000
Subject: [PATCH 1/3] [flang] Do not use dialect conversion in
 `DoConcurrentConversionPass`

---
 .../OpenMP/DoConcurrentConversion.cpp         | 71 +++++++++++--------
 .../Transforms/DoConcurrent/basic_device.mlir |  2 +-
 .../Transforms/DoConcurrent/basic_host.f90    |  7 +-
 .../Transforms/DoConcurrent/basic_host.mlir   |  6 +-
 .../locality_specifiers_simple.mlir           |  2 +-
 .../multiple_iteration_ranges.f90             | 17 +++--
 .../DoConcurrent/non_const_bounds.f90         |  3 +-
 .../Transforms/DoConcurrent/reduce_add.mlir   |  4 +-
 .../DoConcurrent/reduce_all_regions.mlir      |  2 +-
 .../Transforms/DoConcurrent/reduce_local.mlir |  4 +-
 10 files changed, 65 insertions(+), 53 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
index c928b76065ade..39d30400a47dc 100644
--- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
+++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
@@ -15,7 +15,7 @@
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/IRMapping.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/RegionUtils.h"
 
 namespace flangomp {
@@ -161,27 +161,34 @@ void collectLoopLocalValues(fir::DoConcurrentLoopOp loop,
 ///
 /// \param rewriter - builder used for updating \p allocRegion.
 static void localizeLoopLocalValue(mlir::Value local, mlir::Region &allocRegion,
-                                   mlir::ConversionPatternRewriter &rewriter) {
+                                   mlir::PatternRewriter &rewriter) {
   rewriter.moveOpBefore(local.getDefiningOp(), &allocRegion.front().front());
 }
 } // namespace looputils
 
 class DoConcurrentConversion
-    : public mlir::OpConversionPattern<fir::DoConcurrentOp> {
+    : public mlir::OpRewritePattern<fir::DoConcurrentOp> {
 public:
-  using mlir::OpConversionPattern<fir::DoConcurrentOp>::OpConversionPattern;
+  using mlir::OpRewritePattern<fir::DoConcurrentOp>::OpRewritePattern;
 
   DoConcurrentConversion(
       mlir::MLIRContext *context, bool mapToDevice,
       llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip,
       mlir::SymbolTable &moduleSymbolTable)
-      : OpConversionPattern(context), mapToDevice(mapToDevice),
+      : OpRewritePattern(context), mapToDevice(mapToDevice),
         concurrentLoopsToSkip(concurrentLoopsToSkip),
         moduleSymbolTable(moduleSymbolTable) {}
 
   mlir::LogicalResult
-  matchAndRewrite(fir::DoConcurrentOp doLoop, OpAdaptor adaptor,
-                  mlir::ConversionPatternRewriter &rewriter) const override {
+  matchAndRewrite(fir::DoConcurrentOp doLoop,
+                  mlir::PatternRewriter &rewriter) const override {
+    // TODO: This pass should use "walkAndApplyPatterns", but that driver does
+    // not support pre-order traversals yet.
+    if (doLoop->getParentOfType<fir::DoConcurrentOp>())
+      return rewriter.notifyMatchFailure(
+          doLoop, "skipping op to enforce pre-order traversal");
+    if (concurrentLoopsToSkip.contains(doLoop))
+      return rewriter.notifyMatchFailure(doLoop, "skipping concurrent loop");
     if (mapToDevice)
       return doLoop.emitError(
           "not yet implemented: Mapping `do concurrent` loops to device");
@@ -231,9 +238,8 @@ class DoConcurrentConversion
       rewriter.moveOpBefore(op, allocBlock, allocBlock->begin());
     }
 
-    // Mark `unordered` loops that are not perfectly nested to be skipped from
-    // the legality check of the `ConversionTarget` since we are not interested
-    // in mapping them to OpenMP.
+    // Mark `unordered` loops that are not perfectly nested to be skipped since
+    // we are not interested in mapping them to OpenMP.
     ompLoopNest->walk([&](fir::DoConcurrentOp doLoop) {
       concurrentLoopsToSkip.insert(doLoop);
     });
@@ -245,7 +251,7 @@ class DoConcurrentConversion
 
 private:
   mlir::omp::ParallelOp
-  genParallelOp(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
+  genParallelOp(mlir::Location loc, mlir::PatternRewriter &rewriter,
                 looputils::InductionVariableInfos &ivInfos,
                 mlir::IRMapping &mapper) const {
     auto parallelOp = mlir::omp::ParallelOp::create(rewriter, loc);
@@ -256,7 +262,7 @@ class DoConcurrentConversion
     return parallelOp;
   }
 
-  void genLoopNestIndVarAllocs(mlir::ConversionPatternRewriter &rewriter,
+  void genLoopNestIndVarAllocs(mlir::PatternRewriter &rewriter,
                                looputils::InductionVariableInfos &ivInfos,
                                mlir::IRMapping &mapper) const {
 
@@ -264,10 +270,9 @@ class DoConcurrentConversion
       genInductionVariableAlloc(rewriter, indVarInfo.iterVarMemDef, mapper);
   }
 
-  mlir::Operation *
-  genInductionVariableAlloc(mlir::ConversionPatternRewriter &rewriter,
-                            mlir::Operation *indVarMemDef,
-                            mlir::IRMapping &mapper) const {
+  mlir::Operation *genInductionVariableAlloc(mlir::PatternRewriter &rewriter,
+                                             mlir::Operation *indVarMemDef,
+                                             mlir::IRMapping &mapper) const {
     assert(
         indVarMemDef != nullptr &&
         "Induction variable memdef is expected to have a defining operation.");
@@ -285,8 +290,7 @@ class DoConcurrentConversion
   }
 
   void
-  genLoopNestClauseOps(mlir::Location loc,
-                       mlir::ConversionPatternRewriter &rewriter,
+  genLoopNestClauseOps(mlir::Location loc, mlir::PatternRewriter &rewriter,
                        fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
                        mlir::omp::LoopNestOperands &loopNestClauseOps) const {
     assert(loopNestClauseOps.loopLowerBounds.empty() &&
@@ -308,8 +312,8 @@ class DoConcurrentConversion
   }
 
   mlir::omp::LoopNestOp
-  genWsLoopOp(mlir::ConversionPatternRewriter &rewriter,
-              fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
+  genWsLoopOp(mlir::PatternRewriter &rewriter, fir::DoConcurrentLoopOp loop,
+              mlir::IRMapping &mapper,
               const mlir::omp::LoopNestOperands &clauseOps,
               bool isComposite) const {
     mlir::omp::WsloopOperands wsloopClauseOps;
@@ -472,18 +476,25 @@ class DoConcurrentConversionPass
     patterns.insert<DoConcurrentConversion>(
         context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device,
         concurrentLoopsToSkip, moduleSymbolTable);
-    mlir::ConversionTarget target(*context);
-    target.addDynamicallyLegalOp<fir::DoConcurrentOp>(
-        [&](fir::DoConcurrentOp op) {
-          return concurrentLoopsToSkip.contains(op);
-        });
-    target.markUnknownOpDynamicallyLegal(
-        [](mlir::Operation *) { return true; });
-
-    if (mlir::failed(
-            mlir::applyFullConversion(module, target, std::move(patterns)))) {
+
+    // TODO: This pass should use "walkAndApplyPatterns", but that driver does
+    // not support pre-order traversals yet.
+    if (mlir::failed(applyPatternsGreedily(module.getOperation(),
+                                           std::move(patterns)))) {
+      module.emitError("failed to apply patterns");
       signalPassFailure();
     }
+
+    // Make sure that all loops were converted.
+    mlir::WalkResult status = module->walk([&](fir::DoConcurrentOp op) {
+      if (concurrentLoopsToSkip.contains(op))
+        return mlir::WalkResult::advance();
+
+      op.emitError("failed to convert operation");
+      return mlir::WalkResult::interrupt();
+    });
+    if (status.wasInterrupted())
+      signalPassFailure();
   }
 };
 } // namespace
diff --git a/flang/test/Transforms/DoConcurrent/basic_device.mlir b/flang/test/Transforms/DoConcurrent/basic_device.mlir
index 0ca48943864c8..b88522a6307b8 100644
--- a/flang/test/Transforms/DoConcurrent/basic_device.mlir
+++ b/flang/test/Transforms/DoConcurrent/basic_device.mlir
@@ -12,7 +12,7 @@ func.func @do_concurrent_basic() attributes {fir.bindc_name = "do_concurrent_bas
     %c1 = arith.constant 1 : index
 
     // expected-error@+2 {{not yet implemented: Mapping `do concurrent` loops to device}}
-    // expected-error@below {{failed to legalize operation 'fir.do_concurrent'}}
+    // expected-error@below {{failed to convert operation}}
     fir.do_concurrent {
       %0 = fir.alloca i32 {bindc_name = "i"}
       %1:2 = hlfir.declare %0 {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
diff --git a/flang/test/Transforms/DoConcurrent/basic_host.f90 b/flang/test/Transforms/DoConcurrent/basic_host.f90
index 6f24b346e3fb9..252be38ac8fd9 100644
--- a/flang/test/Transforms/DoConcurrent/basic_host.f90
+++ b/flang/test/Transforms/DoConcurrent/basic_host.f90
@@ -7,6 +7,10 @@
  
 ! CHECK-LABEL: DO_CONCURRENT_BASIC
 program do_concurrent_basic
+    ! CHECK: %[[C1:.*]] = arith.constant 1 : i32
+    ! CHECK: %[[C10:.*]] = arith.constant 10 : i32
+    ! CHECK: %[[STEP:.*]] = arith.constant 1 : index
+
     ! CHECK: %[[ARR:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFEa"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
 
     implicit none
@@ -15,11 +19,8 @@ program do_concurrent_basic
 
     ! CHECK-NOT: fir.do_loop
 
-    ! CHECK: %[[C1:.*]] = arith.constant 1 : i32
     ! CHECK: %[[LB:.*]] = fir.convert %[[C1]] : (i32) -> index
-    ! CHECK: %[[C10:.*]] = arith.constant 10 : i32
     ! CHECK: %[[UB:.*]] = fir.convert %[[C10]] : (i32) -> index
-    ! CHECK: %[[STEP:.*]] = arith.constant 1 : index
 
     ! CHECK: omp.parallel {
 
diff --git a/flang/test/Transforms/DoConcurrent/basic_host.mlir b/flang/test/Transforms/DoConcurrent/basic_host.mlir
index 5425829404d7b..34d5c26c88d25 100644
--- a/flang/test/Transforms/DoConcurrent/basic_host.mlir
+++ b/flang/test/Transforms/DoConcurrent/basic_host.mlir
@@ -4,6 +4,9 @@
 
 // CHECK-LABEL: func.func @do_concurrent_basic
 func.func @do_concurrent_basic() attributes {fir.bindc_name = "do_concurrent_basic"} {
+    // CHECK: %[[C1:.*]] = arith.constant 1 : i32
+    // CHECK: %[[C10:.*]] = arith.constant 10 : i32
+    // CHECK: %[[STEP:.*]] = arith.constant 1 : index
     // CHECK: %[[ARR:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFEa"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
 
     %2 = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xi32>>
@@ -18,11 +21,8 @@ func.func @do_concurrent_basic() attributes {fir.bindc_name = "do_concurrent_bas
 
     // CHECK-NOT: fir.do_concurrent
 
-    // CHECK: %[[C1:.*]] = arith.constant 1 : i32
     // CHECK: %[[LB:.*]] = fir.convert %[[C1]] : (i32) -> index
-    // CHECK: %[[C10:.*]] = arith.constant 10 : i32
     // CHECK: %[[UB:.*]] = fir.convert %[[C10]] : (i32) -> index
-    // CHECK: %[[STEP:.*]] = arith.constant 1 : index
 
     // CHECK: omp.parallel {
 
diff --git a/flang/test/Transforms/DoConcurrent/locality_specifiers_simple.mlir b/flang/test/Transforms/DoConcurrent/locality_specifiers_simple.mlir
index 160c1df040680..49fa7441ab311 100644
--- a/flang/test/Transforms/DoConcurrent/locality_specifiers_simple.mlir
+++ b/flang/test/Transforms/DoConcurrent/locality_specifiers_simple.mlir
@@ -33,13 +33,13 @@ func.func @_QPlocal_spec_translation() {
 // CHECK: omp.private {type = private} @[[PRIVATIZER:.*local_spec_translationElocal_var.*.omp]] : f32
 
 // CHECK: func.func @_QPlocal_spec_translation
+// CHECK:   %[[C42:.*]] = arith.constant 4.200000e+01 : f32
 // CHECK:   %[[LOCAL_VAR:.*]] = fir.alloca f32 {bindc_name = "local_var", {{.*}}}
 // CHECK:   %[[LOCAL_VAR_DECL:.*]]:2 = hlfir.declare %[[LOCAL_VAR]]
 // CHECK:   omp.parallel {
 // CHECK:     omp.wsloop private(@[[PRIVATIZER]] %[[LOCAL_VAR_DECL]]#0 -> %[[LOCAL_ARG:.*]] : !fir.ref<f32>) {
 // CHECK:       omp.loop_nest {{.*}} {
 // CHECK:       %[[PRIV_DECL:.*]]:2 = hlfir.declare %[[LOCAL_ARG]]
-// CHECK:       %[[C42:.*]] = arith.constant
 // CHECK:       hlfir.assign %[[C42]] to %[[PRIV_DECL]]#0
 // CHECK:       omp.yield
 // CHECK:     }
diff --git a/flang/test/Transforms/DoConcurrent/multiple_iteration_ranges.f90 b/flang/test/Transforms/DoConcurrent/multiple_iteration_ranges.f90
index d0210726de83e..d40c1892820f1 100644
--- a/flang/test/Transforms/DoConcurrent/multiple_iteration_ranges.f90
+++ b/flang/test/Transforms/DoConcurrent/multiple_iteration_ranges.f90
@@ -20,22 +20,21 @@ program main
 ! CHECK: func.func @_QQmain
 
 ! CHECK: %[[C3:.*]] = arith.constant 3 : i32
-! CHECK: %[[LB_I:.*]] = fir.convert %[[C3]] : (i32) -> index
 ! CHECK: %[[C20:.*]] = arith.constant 20 : i32
+! CHECK: %[[C1:.*]] = arith.constant 1 : index
+! CHECK: %[[C5:.*]] = arith.constant 5 : i32
+! CHECK: %[[C40:.*]] = arith.constant 40 : i32
+! CHECK: %[[C7:.*]] = arith.constant 7 : i32
+! CHECK: %[[C60:.*]] = arith.constant 60 : i32
+
+! CHECK: %[[LB_I:.*]] = fir.convert %[[C3]] : (i32) -> index
 ! CHECK: %[[UB_I:.*]] = fir.convert %[[C20]] : (i32) -> index
-! CHECK: %[[STEP_I:.*]] = arith.constant 1 : index
 
-! CHECK: %[[C5:.*]] = arith.constant 5 : i32
 ! CHECK: %[[LB_J:.*]] = fir.convert %[[C5]] : (i32) -> index
-! CHECK: %[[C40:.*]] = arith.constant 40 : i32
 ! CHECK: %[[UB_J:.*]] = fir.convert %[[C40]] : (i32) -> index
-! CHECK: %[[STEP_J:.*]] = arith.constant 1 : index
 
-! CHECK: %[[C7:.*]] = arith.constant 7 : i32
 ! CHECK: %[[LB_K:.*]] = fir.convert %[[C7]] : (i32) -> index
-! CHECK: %[[C60:.*]] = arith.constant 60 : i32
 ! CHECK: %[[UB_K:.*]] = fir.convert %[[C60]] : (i32) -> index
-! CHECK: %[[STEP_K:.*]] = arith.constant 1 : index
 
 ! CHECK: omp.parallel {
 
@@ -53,7 +52,7 @@ program main
 ! CHECK-SAME:   (%[[ARG0:[^[:space:]]+]], %[[ARG1:[^[:space:]]+]], %[[ARG2:[^[:space:]]+]])
 ! CHECK-SAME:   : index = (%[[LB_I]], %[[LB_J]], %[[LB_K]])
 ! CHECK-SAME:     to (%[[UB_I]], %[[UB_J]], %[[UB_K]]) inclusive
-! CHECK-SAME:     step (%[[STEP_I]], %[[STEP_J]], %[[STEP_K]]) {
+! CHECK-SAME:     step (%[[C1]], %[[C1]], %[[C1]]) {
 
 ! CHECK-NEXT: %[[IV_IDX_I:.*]] = fir.convert %[[ARG0]]
 ! CHECK-NEXT: fir.store %[[IV_IDX_I]] to %[[BINDING_I]]#0
diff --git a/flang/test/Transforms/DoConcurrent/non_const_bounds.f90 b/flang/test/Transforms/DoConcurrent/non_const_bounds.f90
index cd1bd4f98a3f5..316ef18ca836f 100644
--- a/flang/test/Transforms/DoConcurrent/non_const_bounds.f90
+++ b/flang/test/Transforms/DoConcurrent/non_const_bounds.f90
@@ -20,6 +20,8 @@ subroutine foo(n)
 
 end program main
 
+! CHECK: %[[C1:.*]] = arith.constant 1 : index
+
 ! CHECK: %[[N_DECL:.*]]:2 = hlfir.declare %{{.*}} dummy_scope %{{.*}} {uniq_name = "_QFFfooEn"}
 
 ! CHECK: fir.load
@@ -27,7 +29,6 @@ end program main
 ! CHECK: %[[LB:.*]] = fir.convert %{{c1_.*}} : (i32) -> index
 ! CHECK: %[[N_VAL:.*]] = fir.load %[[N_DECL]]#0 : !fir.ref<i32>
 ! CHECK: %[[UB:.*]] = fir.convert %[[N_VAL]] : (i32) -> index
-! CHECK: %[[C1:.*]] = arith.constant 1 : index
 
 ! CHECK: omp.parallel {
 
diff --git a/flang/test/Transforms/DoConcurrent/reduce_add.mlir b/flang/test/Transforms/DoConcurrent/reduce_add.mlir
index 1ea3e3e527335..bf9ce75e6c978 100644
--- a/flang/test/Transforms/DoConcurrent/reduce_add.mlir
+++ b/flang/test/Transforms/DoConcurrent/reduce_add.mlir
@@ -44,11 +44,12 @@ func.func @_QPdo_concurrent_reduce() {
 // CHECK:         }
 
 // CHECK-LABEL:   func.func @_QPdo_concurrent_reduce() {
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_12:.*]] = arith.constant 1 : i32
 // CHECK:           %[[VAL_0:.*]] = fir.alloca i32 {bindc_name = "i"}
 // CHECK:           %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFdo_concurrent_reduceEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 // CHECK:           %[[VAL_2:.*]] = fir.alloca i32 {bindc_name = "s", uniq_name = "_QFdo_concurrent_reduceEs"}
 // CHECK:           %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2]] {uniq_name = "_QFdo_concurrent_reduceEs"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
 // CHECK:           omp.parallel {
 // CHECK:             %[[VAL_5:.*]] = fir.alloca i32 {bindc_name = "i"}
 // CHECK:             %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_5]] {uniq_name = "_QFdo_concurrent_reduceEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
@@ -59,7 +60,6 @@ func.func @_QPdo_concurrent_reduce() {
 // CHECK:                 fir.store %[[VAL_9]] to %[[VAL_6]]#0 : !fir.ref<i32>
 // CHECK:                 %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_7]] {uniq_name = "_QFdo_concurrent_reduceEs"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 // CHECK:                 %[[VAL_11:.*]] = fir.load %[[VAL_10]]#0 : !fir.ref<i32>
-// CHECK:                 %[[VAL_12:.*]] = arith.constant 1 : i32
 // CHECK:                 %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : i32
 // CHECK:                 hlfir.assign %[[VAL_13]] to %[[VAL_10]]#0 : i32, !fir.ref<i32>
 // CHECK:                 omp.yield
diff --git a/flang/test/Transforms/DoConcurrent/reduce_all_regions.mlir b/flang/test/Transforms/DoConcurrent/reduce_all_regions.mlir
index 3d5b8bf22af75..815c3cfdd1a24 100644
--- a/flang/test/Transforms/DoConcurrent/reduce_all_regions.mlir
+++ b/flang/test/Transforms/DoConcurrent/reduce_all_regions.mlir
@@ -49,11 +49,11 @@ func.func @_QPdo_concurrent_reduce() {
 // CHECK:         }
 
 // CHECK-LABEL:   func.func @_QPdo_concurrent_reduce() {
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_0:.*]] = fir.alloca i32 {bindc_name = "i"}
 // CHECK:           %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFdo_concurrent_reduceEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 // CHECK:           %[[VAL_2:.*]] = fir.alloca i32 {bindc_name = "s", uniq_name = "_QFdo_concurrent_reduceEs"}
 // CHECK:           %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2]] {uniq_name = "_QFdo_concurrent_reduceEs"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
 // CHECK:           omp.parallel {
 // CHECK:             %[[VAL_5:.*]] = fir.alloca i32 {bindc_name = "i"}
 // CHECK:             %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_5]] {uniq_name = "_QFdo_concurrent_reduceEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
diff --git a/flang/test/Transforms/DoConcurrent/reduce_local.mlir b/flang/test/Transforms/DoConcurrent/reduce_local.mlir
index 0f667109e6e83..588d95f957a7d 100644
--- a/flang/test/Transforms/DoConcurrent/reduce_local.mlir
+++ b/flang/test/Transforms/DoConcurrent/reduce_local.mlir
@@ -51,13 +51,14 @@ fir.declare_reduction @add_reduction_i32 : i32 init {
 // CHECK:         omp.private {type = private} @_QFdo_concurrent_reduceEl_private_i32.omp : i32
 
 // CHECK-LABEL:   func.func @_QPdo_concurrent_reduce() {
+// CHECK:           %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_15:.*]] = arith.constant 1 : i32
 // CHECK:           %[[VAL_0:.*]] = fir.alloca i32 {bindc_name = "i"}
 // CHECK:           %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFdo_concurrent_reduceEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 // CHECK:           %[[VAL_2:.*]] = fir.alloca i32 {bindc_name = "l", uniq_name = "_QFdo_concurrent_reduceEl"}
 // CHECK:           %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_2]] {uniq_name = "_QFdo_concurrent_reduceEl"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 // CHECK:           %[[VAL_4:.*]] = fir.alloca i32 {bindc_name = "s", uniq_name = "_QFdo_concurrent_reduceEs"}
 // CHECK:           %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]] {uniq_name = "_QFdo_concurrent_reduceEs"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-// CHECK:           %[[VAL_6:.*]] = arith.constant 1 : index
 // CHECK:           omp.parallel {
 // CHECK:             %[[VAL_7:.*]] = fir.alloca i32 {bindc_name = "i"}
 // CHECK:             %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_7]] {uniq_name = "_QFdo_concurrent_reduceEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
@@ -67,7 +68,6 @@ fir.declare_reduction @add_reduction_i32 : i32 init {
 // CHECK:                 fir.store %[[VAL_12]] to %[[VAL_8]]#0 : !fir.ref<i32>
 // CHECK:                 %[[VAL_13:.*]]:2 = hlfir.declare %[[VAL_9]] {uniq_name = "_QFdo_concurrent_reduceEl"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
 // CHECK:                 %[[VAL_14:.*]]:2 = hlfir.declare %[[VAL_10]] {uniq_name = "_QFdo_concurrent_reduceEs"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-// CHECK:                 %[[VAL_15:.*]] = arith.constant 1 : i32
 // CHECK:                 hlfir.assign %[[VAL_15]] to %[[VAL_13]]#0 : i32, !fir.ref<i32>
 // CHECK:                 %[[VAL_16:.*]] = fir.load %[[VAL_14]]#0 : !fir.ref<i32>
 // CHECK:                 %[[VAL_17:.*]] = fir.load %[[VAL_13]]#0 : !fir.ref<i32>

>From f4dbd0d6bdb3131ab5747691b21305ee17dbdaec Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sat, 30 Aug 2025 10:35:08 +0000
Subject: [PATCH 2/3] [flang] Do not use dialect conversion in
 `AffineDialectPromotion`

---
 .../Optimizer/Transforms/AffinePromotion.cpp  | 33 +++++++------------
 1 file changed, 11 insertions(+), 22 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
index b032767eef6f0..061a7d201edd3 100644
--- a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
+++ b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
@@ -25,7 +25,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Visitors.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/Debug.h"
 #include <optional>
@@ -451,10 +451,10 @@ static void rewriteStore(fir::StoreOp storeOp,
 }
 
 static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) {
-  for (auto &bodyOp : block->getOperations()) {
+  for (auto &bodyOp : llvm::make_early_inc_range(block->getOperations())) {
     if (isa<fir::LoadOp>(bodyOp))
       rewriteLoad(cast<fir::LoadOp>(bodyOp), rewriter);
-    if (isa<fir::StoreOp>(bodyOp))
+    else if (isa<fir::StoreOp>(bodyOp))
       rewriteStore(cast<fir::StoreOp>(bodyOp), rewriter);
   }
 }
@@ -476,6 +476,8 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
                loop.dump(););
     LLVM_ATTRIBUTE_UNUSED auto loopAnalysis =
         functionAnalysis.getChildLoopAnalysis(loop);
+    if (!loopAnalysis.canPromoteToAffine())
+      return rewriter.notifyMatchFailure(loop, "cannot promote to affine");
     auto &loopOps = loop.getBody()->getOperations();
     auto resultOp = cast<fir::ResultOp>(loop.getBody()->getTerminator());
     auto results = resultOp.getOperands();
@@ -576,12 +578,14 @@ class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
   AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
-      : OpRewritePattern(context) {}
+      : OpRewritePattern(context), functionAnalysis(afa) {}
   llvm::LogicalResult
   matchAndRewrite(fir::IfOp op,
                   mlir::PatternRewriter &rewriter) const override {
     LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n";
                op.dump(););
+    if (!functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine())
+      return rewriter.notifyMatchFailure(op, "cannot promote to affine");
     auto &ifOps = op.getThenRegion().front().getOperations();
     auto affineCondition = AffineIfCondition(op.getCondition());
     if (!affineCondition.hasIntegerSet()) {
@@ -611,6 +615,8 @@ class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
     rewriter.replaceOp(op, affineIf.getOperation()->getResults());
     return success();
   }
+
+  AffineFunctionAnalysis &functionAnalysis;
 };
 
 /// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases
@@ -627,28 +633,11 @@ class AffineDialectPromotion
     mlir::RewritePatternSet patterns(context);
     patterns.insert<AffineIfConversion>(context, functionAnalysis);
     patterns.insert<AffineLoopConversion>(context, functionAnalysis);
-    mlir::ConversionTarget target = *context;
-    target.addLegalDialect<mlir::affine::AffineDialect, FIROpsDialect,
-                           mlir::scf::SCFDialect, mlir::arith::ArithDialect,
-                           mlir::func::FuncDialect>();
-    target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) {
-      return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine());
-    });
-    target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis](
-                                               fir::DoLoopOp op) {
-      return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine());
-    });
-
     LLVM_DEBUG(llvm::dbgs()
                    << "AffineDialectPromotion: running promotion on: \n";
                function.print(llvm::dbgs()););
     // apply the patterns
-    if (mlir::failed(mlir::applyPartialConversion(function, target,
-                                                  std::move(patterns)))) {
-      mlir::emitError(mlir::UnknownLoc::get(context),
-                      "error in converting to affine dialect\n");
-      signalPassFailure();
-    }
+    walkAndApplyPatterns(function, std::move(patterns));
   }
 };
 } // namespace

>From 93ad0c49462e0cd7f9abefde244fae0eade616eb Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sat, 23 Aug 2025 10:36:37 +0000
Subject: [PATCH 3/3] [mlir][Transforms] Add support for
 `ConversionPatternRewriter::replaceAllUsesWith`

---
 mlir/include/mlir/IR/PatternMatch.h           |   2 +-
 .../mlir/Transforms/DialectConversion.h       |  17 +-
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp |   2 +-
 .../Transforms/Utils/DialectConversion.cpp    | 158 +++++++++++-------
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |   5 +-
 5 files changed, 112 insertions(+), 72 deletions(-)

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 57e73c1d8c7c1..7b0b9cef9c5bd 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -633,7 +633,7 @@ class RewriterBase : public OpBuilder {
 
   /// Find uses of `from` and replace them with `to`. Also notify the listener
   /// about every in-place op modification (for every use that was replaced).
-  void replaceAllUsesWith(Value from, Value to) {
+  virtual void replaceAllUsesWith(Value from, Value to) {
     for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
       Operation *op = operand.getOwner();
       modifyOpInPlace(op, [&]() { operand.set(to); });
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 14dfbf18836c6..1a4e4a3657e95 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -854,15 +854,18 @@ class ConversionPatternRewriter final : public PatternRewriter {
       Region *region, const TypeConverter &converter,
       TypeConverter::SignatureConversion *entryConversion = nullptr);
 
-  /// Replace all the uses of the block argument `from` with `to`. This
-  /// function supports both 1:1 and 1:N replacements.
+  /// Replace all the uses of `from` with `to`. This function supports both 1:1
+  /// and 1:N replacements.
   ///
   /// Note: If `allowPatternRollback` is set to "true", this function replaces
-  /// all current and future uses of the block argument. This same block
-  /// block argument must not be replaced multiple times. Uses are not replaced
-  /// immediately but in a delayed fashion. Patterns may still see the original
-  /// uses when inspecting IR.
-  void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
+  /// all current and future uses of the `from` value. This same value must not
+  /// be replaced multiple times. Uses are not replaced immediately but in a
+  /// delayed fashion. Patterns may still see the original uses when inspecting
+  /// IR.
+  void replaceAllUsesWith(Value from, ValueRange to);
+  void replaceAllUsesWith(Value from, Value to) override {
+    replaceAllUsesWith(from, ValueRange{to});
+  }
 
   /// Return the converted value of 'key' with a type defined by the type
   /// converter of the currently executing pattern. Return nullptr in the case
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 42c76ed475b4c..93fe2edad5274 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -284,7 +284,7 @@ static void restoreByValRefArgumentType(
         cast<TypeAttr>(byValRefAttr->getValue()).getValue());
 
     Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
-    rewriter.replaceUsesOfBlockArgument(arg, valueArg);
+    rewriter.replaceAllUsesWith(arg, valueArg);
   }
 }
 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5ba109d96cf13..d72429298754f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -277,13 +277,14 @@ class IRRewrite {
     InlineBlock,
     MoveBlock,
     BlockTypeConversion,
-    ReplaceBlockArg,
     // Operation rewrites
     MoveOperation,
     ModifyOperation,
     ReplaceOperation,
     CreateOperation,
-    UnresolvedMaterialization
+    UnresolvedMaterialization,
+    // Value rewrites
+    ReplaceValue
   };
 
   virtual ~IRRewrite() = default;
@@ -330,7 +331,7 @@ class BlockRewrite : public IRRewrite {
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() >= Kind::CreateBlock &&
-           rewrite->getKind() <= Kind::ReplaceBlockArg;
+           rewrite->getKind() <= Kind::BlockTypeConversion;
   }
 
 protected:
@@ -342,6 +343,25 @@ class BlockRewrite : public IRRewrite {
   Block *block;
 };
 
+/// A value rewrite.
+class ValueRewrite : public IRRewrite {
+public:
+  /// Return the value that this rewrite operates on.
+  Value getValue() const { return value; }
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::ReplaceValue;
+  }
+
+protected:
+  ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+               Value value)
+      : IRRewrite(kind, rewriterImpl), value(value) {}
+
+  // The value that this rewrite operates on.
+  Value value;
+};
+
 /// Creation of a block. Block creations are immediately reflected in the IR.
 /// There is no extra work to commit the rewrite. During rollback, the newly
 /// created block is erased.
@@ -548,19 +568,18 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   Block *newBlock;
 };
 
-/// Replacing a block argument. This rewrite is not immediately reflected in the
+/// Replacing a value. This rewrite is not immediately reflected in the
 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
 /// until the rewrite is committed.
-class ReplaceBlockArgRewrite : public BlockRewrite {
+class ReplaceValueRewrite : public ValueRewrite {
 public:
-  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
-                         Block *block, BlockArgument arg,
-                         const TypeConverter *converter)
-      : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
+  ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
+                      const TypeConverter *converter)
+      : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
         converter(converter) {}
 
   static bool classof(const IRRewrite *rewrite) {
-    return rewrite->getKind() == Kind::ReplaceBlockArg;
+    return rewrite->getKind() == Kind::ReplaceValue;
   }
 
   void commit(RewriterBase &rewriter) override;
@@ -568,9 +587,7 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
   void rollback() override;
 
 private:
-  BlockArgument arg;
-
-  /// The current type converter when the block argument was replaced.
+  /// The current type converter when the value was replaced.
   const TypeConverter *converter;
 };
 
@@ -940,10 +957,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// uses.
   void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
 
-  /// Replace the given block argument with the given values. The specified
+  /// Replace the uses of the given value with the given values. The specified
   /// converter is used to build materializations (if necessary).
-  void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
-                                  const TypeConverter *converter);
+  void replaceAllUsesWith(Value from, ValueRange to,
+                          const TypeConverter *converter);
 
   /// Erase the given block and its contents.
   void eraseBlock(Block *block);
@@ -1129,10 +1146,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   IRRewriter notifyingRewriter;
 
 #ifndef NDEBUG
-  /// A set of replaced block arguments. This set is for debugging purposes
-  /// only and it is maintained only if `allowPatternRollback` is set to
-  /// "true".
-  DenseSet<BlockArgument> replacedArgs;
+  /// A set of replaced values. This set is for debugging purposes only and it
+  /// is maintained only if `allowPatternRollback` is set to "true".
+  DenseSet<Value> replacedValues;
 
   /// A set of operations that have pending updates. This tracking isn't
   /// strictly necessary, and is thus only active during debug builds for extra
@@ -1169,32 +1185,54 @@ void BlockTypeConversionRewrite::rollback() {
   getNewBlock()->replaceAllUsesWith(getOrigBlock());
 }
 
-static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg,
-                                   Value repl) {
+/// Replace all uses of `from` with `repl`.
+static void performReplaceValue(RewriterBase &rewriter, Value from,
+                                Value repl) {
   if (isa<BlockArgument>(repl)) {
-    rewriter.replaceAllUsesWith(arg, repl);
+    // `repl` is a block argument. Directly replace all uses.
+    rewriter.replaceAllUsesWith(from, repl);
     return;
   }
 
-  // If the replacement value is an operation, we check to make sure that we
-  // don't replace uses that are within the parent operation of the
-  // replacement value.
-  Operation *replOp = cast<OpResult>(repl).getOwner();
+  // If the replacement value is an op result, only replace those uses that:
+  // - are in a different block than the replacement operation, or
+  // - are in the same block but after the replacement operation.
+  //
+  // Example:
+  // ^bb0(%arg0: i32):
+  // %0 = "consumer"(%arg0) : (i32) -> (i32)
+  // "another_consumer"(%arg0) : (i32) -> ()
+  //
+  // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
+  // use in "another_consumer" but not the use in "consumer". When using the
+  // normal RewriterBase API, this would typically be done with
+  // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
+  // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
+  // it cannot be supported efficiently with `allowPatternRollback` set to
+  // "true". Therefore, the conversion driver is trying to be smart and replaces
+  // only those uses that do not lead to a dominance violation. E.g., the
+  // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
+  // behavior.
+  //
+  // TODO: As we move more and more towards `allowPatternRollback` set to
+  // "false", we should remove this special handling, in order to align the
+  // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
+  Operation *replOp = repl.getDefiningOp();
   Block *replBlock = replOp->getBlock();
-  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
+  rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
     Operation *user = operand.getOwner();
     return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
   });
 }
 
-void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
-  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
+  Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
   if (!repl)
     return;
-  performReplaceBlockArg(rewriter, arg, repl);
+  performReplaceValue(rewriter, value, repl);
 }
 
-void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
+void ReplaceValueRewrite::rollback() { rewriterImpl.mapping.erase({value}); }
 
 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   auto *listener =
@@ -1584,7 +1622,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
               /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
               /*isPureTypeConversion=*/false)
               .front();
-      replaceUsesOfBlockArgument(origArg, mat, converter);
+      replaceAllUsesWith(origArg, mat, converter);
       continue;
     }
 
@@ -1593,15 +1631,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       assert(inputMap->size == 0 &&
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
-      replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
-                                 converter);
+      replaceAllUsesWith(origArg, inputMap->replacementValues, converter);
       continue;
     }
 
     // This is a 1->1+ mapping.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    replaceUsesOfBlockArgument(origArg, replArgs, converter);
+    replaceAllUsesWith(origArg, replArgs, converter);
   }
 
   if (config.allowPatternRollback)
@@ -1873,8 +1910,8 @@ void ConversionPatternRewriterImpl::replaceOp(
   op->walk([&](Operation *op) { replacedOps.insert(op); });
 }
 
-void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
-    BlockArgument from, ValueRange to, const TypeConverter *converter) {
+void ConversionPatternRewriterImpl::replaceAllUsesWith(
+    Value from, ValueRange to, const TypeConverter *converter) {
   if (!config.allowPatternRollback) {
     SmallVector<Value> toConv = llvm::to_vector(to);
     SmallVector<Value> repls =
@@ -1884,25 +1921,25 @@ void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
     if (!repl)
       return;
 
-    performReplaceBlockArg(r, from, repl);
+    performReplaceValue(r, from, repl);
     return;
   }
 
 #ifndef NDEBUG
-  // Make sure that a block argument is not replaced multiple times. In
-  // rollback mode, `replaceUsesOfBlockArgument` replaces not only all current
-  // uses of the given block argument, but also all future uses that may be
-  // introduced by future pattern applications. Therefore, it does not make
-  // sense to call `replaceUsesOfBlockArgument` multiple times with the same
-  // block argument. Doing so would overwrite the mapping and mess with the
-  // internal state of the dialect conversion driver.
-  assert(!replacedArgs.contains(from) &&
-         "attempting to replace a block argument that was already replaced");
-  replacedArgs.insert(from);
+  // Make sure that a value is not replaced multiple times. In rollback mode,
+  // `replaceAllUsesWith` replaces not only all current uses of the given value,
+  // but also all future uses that may be introduced by future pattern
+  // applications. Therefore, it does not make sense to call
+  // `replaceAllUsesWith` multiple times with the same value. Doing so would
+  // overwrite the mapping and mess with the internal state of the dialect
+  // conversion driver.
+  assert(!replacedValues.contains(from) &&
+         "attempting to replace a value that was already replaced");
+  replacedValues.insert(from);
 #endif // NDEBUG
 
-  appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
   mapping.map(from, to);
+  appendRewrite<ReplaceValueRewrite>(from, converter);
 }
 
 void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
@@ -2107,18 +2144,19 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
   return impl->convertRegionTypes(region, converter, entryConversion);
 }
 
-void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
-                                                           ValueRange to) {
+void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
   LLVM_DEBUG({
-    impl->logger.startLine() << "** Replace Argument : '" << from << "'";
-    if (Operation *parentOp = from.getOwner()->getParentOp()) {
-      impl->logger.getOStream() << " (in region of '" << parentOp->getName()
-                                << "' (" << parentOp << ")\n";
-    } else {
-      impl->logger.getOStream() << " (unlinked block)\n";
+    impl->logger.startLine() << "** Replace Value : '" << from << "'";
+    if (auto blockArg = dyn_cast<BlockArgument>(from)) {
+      if (Operation *parentOp = blockArg.getOwner()->getParentOp())
+        impl->logger.getOStream() << " (in region of '" << parentOp->getName()
+                                  << "' (" << parentOp << ")";
+      else
+        impl->logger.getOStream() << " (unlinked block)";
     }
+    impl->logger.getOStream() << "\n"; // Also ends op-result log lines.
   });
-  impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
+  impl->replaceAllUsesWith(from, to, impl->currentTypeConverter);
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
@@ -2176,7 +2214,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
 
   // Replace all uses of block arguments.
   for (auto it : llvm::zip(source->getArguments(), argValues))
-    replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
+    replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
 
   if (fastPath) {
     // Move all ops at once.
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 95f381ec471d6..5b8e4170b62ce 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -952,7 +952,7 @@ struct TestCreateIllegalBlock : public RewritePattern {
   }
 };
 
-/// A simple pattern that tests the "replaceUsesOfBlockArgument" API.
+/// A simple pattern that tests the "replaceAllUsesWith" API.
 struct TestBlockArgReplace : public ConversionPattern {
   TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter)
       : ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1,
@@ -963,8 +963,7 @@ struct TestBlockArgReplace : public ConversionPattern {
                   ConversionPatternRewriter &rewriter) const final {
     // Replace the first block argument with 2x the second block argument.
     Value repl = op->getRegion(0).getArgument(1);
-    rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
-                                        {repl, repl});
+    rewriter.replaceAllUsesWith(op->getRegion(0).getArgument(0), {repl, repl});
     rewriter.modifyOpInPlace(op, [&] {
       // If the "trigger_rollback" attribute is set, keep the op illegal, so
       // that a rollback is triggered.



More information about the Mlir-commits mailing list