[Mlir-commits] [flang] [mlir] [mlir][Transforms] Add support for `ConversionPatternRewriter::replaceAllUsesWith` (PR #155244)

Matthias Springer llvmlistbot at llvm.org
Sat Aug 30 10:55:22 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/155244

>From 2031bc554b5df42d0f93a09ba2128b70ee49c8b1 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sat, 23 Aug 2025 10:36:37 +0000
Subject: [PATCH] [mlir][Transforms] Add support for
 `ConversionPatternRewriter::replaceAllUsesWith`

---
 .../OpenMP/DoConcurrentConversion.cpp         |  23 ++-
 mlir/include/mlir/IR/PatternMatch.h           |   2 +-
 .../mlir/Transforms/DialectConversion.h       |  27 ++-
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp |   2 +-
 .../Transforms/Utils/DialectConversion.cpp    | 158 +++++++++++-------
 .../Transforms/test-legalizer-rollback.mlir   |   8 +-
 mlir/test/Transforms/test-legalizer.mlir      |  26 ++-
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  24 +--
 8 files changed, 179 insertions(+), 91 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
index c928b76065ade..de3b8d730072f 100644
--- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
+++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
@@ -444,6 +444,18 @@ class DoConcurrentConversion
   mlir::SymbolTable &moduleSymbolTable;
 };
 
+/// A listener that forwards notifyOperationErased to the given callback.
+struct CallbackListener : public mlir::RewriterBase::Listener {
+  CallbackListener(std::function<void(mlir::Operation *op)> onOperationErased)
+      : onOperationErased(onOperationErased) {}
+
+  void notifyOperationErased(mlir::Operation *op) override {
+    onOperationErased(op);
+  }
+
+  std::function<void(mlir::Operation *op)> onOperationErased;
+};
+
 class DoConcurrentConversionPass
     : public flangomp::impl::DoConcurrentConversionPassBase<
           DoConcurrentConversionPass> {
@@ -468,6 +480,10 @@ class DoConcurrentConversionPass
     }
 
     llvm::DenseSet<fir::DoConcurrentOp> concurrentLoopsToSkip;
+    CallbackListener callbackListener([&](mlir::Operation *op) {
+      if (auto loop = mlir::dyn_cast<fir::DoConcurrentOp>(op))
+        concurrentLoopsToSkip.erase(loop);
+    });
     mlir::RewritePatternSet patterns(context);
     patterns.insert<DoConcurrentConversion>(
         context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device,
@@ -480,8 +496,11 @@ class DoConcurrentConversionPass
     target.markUnknownOpDynamicallyLegal(
         [](mlir::Operation *) { return true; });
 
-    if (mlir::failed(
-            mlir::applyFullConversion(module, target, std::move(patterns)))) {
+    mlir::ConversionConfig config;
+    config.allowPatternRollback = false;
+    config.listener = &callbackListener;
+    if (mlir::failed(mlir::applyFullConversion(module, target,
+                                               std::move(patterns), config))) {
       signalPassFailure();
     }
   }
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 57e73c1d8c7c1..7b0b9cef9c5bd 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -633,7 +633,7 @@ class RewriterBase : public OpBuilder {
 
   /// Find uses of `from` and replace them with `to`. Also notify the listener
   /// about every in-place op modification (for every use that was replaced).
-  void replaceAllUsesWith(Value from, Value to) {
+  virtual void replaceAllUsesWith(Value from, Value to) {
     for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
       Operation *op = operand.getOwner();
       modifyOpInPlace(op, [&]() { operand.set(to); });
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 14dfbf18836c6..fe48e45a9b98c 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -854,15 +854,26 @@ class ConversionPatternRewriter final : public PatternRewriter {
       Region *region, const TypeConverter &converter,
       TypeConverter::SignatureConversion *entryConversion = nullptr);
 
-  /// Replace all the uses of the block argument `from` with `to`. This
-  /// function supports both 1:1 and 1:N replacements.
+  /// Replace all the uses of `from` with `to`. The types of `from` and `to`
+  /// are allowed to differ. The conversion driver will try to reconcile all
+  /// type mismatches that still exist at the end of the conversion with
+  /// materializations. This function supports both 1:1 and 1:N replacements.
   ///
-  /// Note: If `allowPatternRollback` is set to "true", this function replaces
-  /// all current and future uses of the block argument. This same block
-  /// block argument must not be replaced multiple times. Uses are not replaced
-  /// immediately but in a delayed fashion. Patterns may still see the original
-  /// uses when inspecting IR.
-  void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
+  /// Note: If `allowPatternRollback` is set to "true", this function behaves
+  /// slightly differently:
+  ///
+  /// 1. All current and future uses of `from` are replaced. The same value must
+  ///    not be replaced multiple times. That's an API violation.
+  /// 2. Uses are not replaced immediately but in a delayed fashion. Patterns
+  ///    may still see the original uses when inspecting IR.
+  /// 3. Uses within the same block that appear before the defining operation
+  ///    of the replacement value are not replaced. This allows users to
+  ///    perform certain replaceAllUsesExcept-style replacements, even though
+  ///    such API is not directly supported.
+  void replaceAllUsesWith(Value from, ValueRange to);
+  void replaceAllUsesWith(Value from, Value to) override {
+    replaceAllUsesWith(from, ValueRange{to});
+  }
 
   /// Return the converted value of 'key' with a type defined by the type
   /// converter of the currently executing pattern. Return nullptr in the case
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 42c76ed475b4c..93fe2edad5274 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -284,7 +284,7 @@ static void restoreByValRefArgumentType(
         cast<TypeAttr>(byValRefAttr->getValue()).getValue());
 
     Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
-    rewriter.replaceUsesOfBlockArgument(arg, valueArg);
+    rewriter.replaceAllUsesWith(arg, valueArg);
   }
 }
 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5ba109d96cf13..d72429298754f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -277,13 +277,14 @@ class IRRewrite {
     InlineBlock,
     MoveBlock,
     BlockTypeConversion,
-    ReplaceBlockArg,
     // Operation rewrites
     MoveOperation,
     ModifyOperation,
     ReplaceOperation,
     CreateOperation,
-    UnresolvedMaterialization
+    UnresolvedMaterialization,
+    // Value rewrites
+    ReplaceValue
   };
 
   virtual ~IRRewrite() = default;
@@ -330,7 +331,7 @@ class BlockRewrite : public IRRewrite {
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() >= Kind::CreateBlock &&
-           rewrite->getKind() <= Kind::ReplaceBlockArg;
+           rewrite->getKind() <= Kind::BlockTypeConversion;
   }
 
 protected:
@@ -342,6 +343,25 @@ class BlockRewrite : public IRRewrite {
   Block *block;
 };
 
+/// A value rewrite.
+class ValueRewrite : public IRRewrite {
+public:
+  /// Return the value that this rewrite operates on.
+  Value getValue() const { return value; }
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::ReplaceValue;
+  }
+
+protected:
+  ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+               Value value)
+      : IRRewrite(kind, rewriterImpl), value(value) {}
+
+  // The value that this rewrite operates on.
+  Value value;
+};
+
 /// Creation of a block. Block creations are immediately reflected in the IR.
 /// There is no extra work to commit the rewrite. During rollback, the newly
 /// created block is erased.
@@ -548,19 +568,18 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   Block *newBlock;
 };
 
-/// Replacing a block argument. This rewrite is not immediately reflected in the
+/// Replacing a value. This rewrite is not immediately reflected in the
 /// IR. An internal IR mapping is updated, but the actual replacement is delayed
 /// until the rewrite is committed.
-class ReplaceBlockArgRewrite : public BlockRewrite {
+class ReplaceValueRewrite : public ValueRewrite {
 public:
-  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
-                         Block *block, BlockArgument arg,
-                         const TypeConverter *converter)
-      : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
+  ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
+                      const TypeConverter *converter)
+      : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
         converter(converter) {}
 
   static bool classof(const IRRewrite *rewrite) {
-    return rewrite->getKind() == Kind::ReplaceBlockArg;
+    return rewrite->getKind() == Kind::ReplaceValue;
   }
 
   void commit(RewriterBase &rewriter) override;
@@ -568,9 +587,7 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
   void rollback() override;
 
 private:
-  BlockArgument arg;
-
-  /// The current type converter when the block argument was replaced.
+  /// The current type converter when the value was replaced.
   const TypeConverter *converter;
 };
 
@@ -940,10 +957,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// uses.
   void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
 
-  /// Replace the given block argument with the given values. The specified
+  /// Replace the uses of the given value with the given values. The specified
   /// converter is used to build materializations (if necessary).
-  void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
-                                  const TypeConverter *converter);
+  void replaceAllUsesWith(Value from, ValueRange to,
+                          const TypeConverter *converter);
 
   /// Erase the given block and its contents.
   void eraseBlock(Block *block);
@@ -1129,10 +1146,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   IRRewriter notifyingRewriter;
 
 #ifndef NDEBUG
-  /// A set of replaced block arguments. This set is for debugging purposes
-  /// only and it is maintained only if `allowPatternRollback` is set to
-  /// "true".
-  DenseSet<BlockArgument> replacedArgs;
+  /// A set of replaced values. This set is for debugging purposes only and it
+  /// is maintained only if `allowPatternRollback` is set to "true".
+  DenseSet<Value> replacedValues;
 
   /// A set of operations that have pending updates. This tracking isn't
   /// strictly necessary, and is thus only active during debug builds for extra
@@ -1169,32 +1185,54 @@ void BlockTypeConversionRewrite::rollback() {
   getNewBlock()->replaceAllUsesWith(getOrigBlock());
 }
 
-static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg,
-                                   Value repl) {
+/// Replace all uses of `from` with `repl`.
+static void performReplaceValue(RewriterBase &rewriter, Value from,
+                                Value repl) {
   if (isa<BlockArgument>(repl)) {
-    rewriter.replaceAllUsesWith(arg, repl);
+    // `repl` is a block argument. Directly replace all uses.
+    rewriter.replaceAllUsesWith(from, repl);
     return;
   }
 
-  // If the replacement value is an operation, we check to make sure that we
-  // don't replace uses that are within the parent operation of the
-  // replacement value.
-  Operation *replOp = cast<OpResult>(repl).getOwner();
+  // If the replacement value is an op result, only replace those uses that:
+  // - are in a different block than the replacement operation, or
+  // - are in the same block but after the replacement operation.
+  //
+  // Example:
+  // ^bb0(%arg0: i32):
+  // %0 = "consumer"(%arg0) : (i32) -> (i32)
+  // "another_consumer"(%arg0) : (i32) -> ()
+  //
+  // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
+  // use in "another_consumer" but not the use in "consumer". When using the
+  // normal RewriterBase API, this would typically be done with
+  // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
+  // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
+  // it cannot be supported efficiently with `allowPatternRollback` set to
+  // "true". Therefore, the conversion driver tries to be smart and replaces
+  // only those uses that do not lead to a dominance violation. E.g., the
+  // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
+  // behavior.
+  //
+  // TODO: As we move more and more towards `allowPatternRollback` set to
+  // "false", we should remove this special handling, in order to align the
+  // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
+  Operation *replOp = repl.getDefiningOp();
   Block *replBlock = replOp->getBlock();
-  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
+  rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
     Operation *user = operand.getOwner();
     return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
   });
 }
 
-void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
-  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
+  Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
   if (!repl)
     return;
-  performReplaceBlockArg(rewriter, arg, repl);
+  performReplaceValue(rewriter, value, repl);
 }
 
-void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
+void ReplaceValueRewrite::rollback() { rewriterImpl.mapping.erase({value}); }
 
 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   auto *listener =
@@ -1584,7 +1622,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
               /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
               /*isPureTypeConversion=*/false)
               .front();
-      replaceUsesOfBlockArgument(origArg, mat, converter);
+      replaceAllUsesWith(origArg, mat, converter);
       continue;
     }
 
@@ -1593,15 +1631,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       assert(inputMap->size == 0 &&
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
-      replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
-                                 converter);
+      replaceAllUsesWith(origArg, inputMap->replacementValues, converter);
       continue;
     }
 
     // This is a 1->1+ mapping.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    replaceUsesOfBlockArgument(origArg, replArgs, converter);
+    replaceAllUsesWith(origArg, replArgs, converter);
   }
 
   if (config.allowPatternRollback)
@@ -1873,8 +1910,8 @@ void ConversionPatternRewriterImpl::replaceOp(
   op->walk([&](Operation *op) { replacedOps.insert(op); });
 }
 
-void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
-    BlockArgument from, ValueRange to, const TypeConverter *converter) {
+void ConversionPatternRewriterImpl::replaceAllUsesWith(
+    Value from, ValueRange to, const TypeConverter *converter) {
   if (!config.allowPatternRollback) {
     SmallVector<Value> toConv = llvm::to_vector(to);
     SmallVector<Value> repls =
@@ -1884,25 +1921,25 @@ void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
     if (!repl)
       return;
 
-    performReplaceBlockArg(r, from, repl);
+    performReplaceValue(r, from, repl);
     return;
   }
 
 #ifndef NDEBUG
-  // Make sure that a block argument is not replaced multiple times. In
-  // rollback mode, `replaceUsesOfBlockArgument` replaces not only all current
-  // uses of the given block argument, but also all future uses that may be
-  // introduced by future pattern applications. Therefore, it does not make
-  // sense to call `replaceUsesOfBlockArgument` multiple times with the same
-  // block argument. Doing so would overwrite the mapping and mess with the
-  // internal state of the dialect conversion driver.
-  assert(!replacedArgs.contains(from) &&
-         "attempting to replace a block argument that was already replaced");
-  replacedArgs.insert(from);
+  // Make sure that a value is not replaced multiple times. In rollback mode,
+  // `replaceAllUsesWith` replaces not only all current uses of the given value,
+  // but also all future uses that may be introduced by future pattern
+  // applications. Therefore, it does not make sense to call
+  // `replaceAllUsesWith` multiple times with the same value. Doing so would
+  // overwrite the mapping and mess with the internal state of the dialect
+  // conversion driver.
+  assert(!replacedValues.contains(from) &&
+         "attempting to replace a value that was already replaced");
+  replacedValues.insert(from);
 #endif // NDEBUG
 
-  appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
   mapping.map(from, to);
+  appendRewrite<ReplaceValueRewrite>(from, converter);
 }
 
 void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
@@ -2107,18 +2144,19 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
   return impl->convertRegionTypes(region, converter, entryConversion);
 }
 
-void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
-                                                           ValueRange to) {
+void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
   LLVM_DEBUG({
-    impl->logger.startLine() << "** Replace Argument : '" << from << "'";
-    if (Operation *parentOp = from.getOwner()->getParentOp()) {
-      impl->logger.getOStream() << " (in region of '" << parentOp->getName()
-                                << "' (" << parentOp << ")\n";
-    } else {
-      impl->logger.getOStream() << " (unlinked block)\n";
+    impl->logger.startLine() << "** Replace Value : '" << from << "'";
+    if (auto blockArg = dyn_cast<BlockArgument>(from)) {
+      if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
+        impl->logger.getOStream() << " (in region of '" << parentOp->getName()
+                                  << "' (" << parentOp << ")\n";
+      } else {
+        impl->logger.getOStream() << " (unlinked block)\n";
+      }
     }
   });
-  impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
+  impl->replaceAllUsesWith(from, to, impl->currentTypeConverter);
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
@@ -2176,7 +2214,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
 
   // Replace all uses of block arguments.
   for (auto it : llvm::zip(source->getArguments(), argValues))
-    replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
+    replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
 
   if (fastPath) {
     // Move all ops at once.
diff --git a/mlir/test/Transforms/test-legalizer-rollback.mlir b/mlir/test/Transforms/test-legalizer-rollback.mlir
index 460911fd88ad1..71e11782e14b0 100644
--- a/mlir/test/Transforms/test-legalizer-rollback.mlir
+++ b/mlir/test/Transforms/test-legalizer-rollback.mlir
@@ -49,14 +49,16 @@ func.func @create_illegal_block() {
 // expected-remark at +1{{applyPartialConversion failed}}
 module {
 func.func @undo_block_arg_replace() {
-  // expected-error at +1{{failed to legalize operation 'test.block_arg_replace' that was explicitly marked illegal}}
-  "test.block_arg_replace"() ({
+  "test.legal_op"() ({
   ^bb0(%arg0: i32, %arg1: i16):
     // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
+    // CHECK-NEXT: "test.value_replace"(%[[ARG0]], %[[ARG1]]) {trigger_rollback}
     // CHECK-NEXT: "test.return"(%[[ARG0]]) : (i32)
 
+    // expected-error at +1{{failed to legalize operation 'test.value_replace' that was explicitly marked illegal}}
+    "test.value_replace"(%arg0, %arg1) {trigger_rollback} : (i32, i16) -> ()
     "test.return"(%arg0) : (i32) -> ()
-  }) {trigger_rollback} : () -> ()
+  }) : () -> ()
   return
 }
 }
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 3fa42ff6b2757..94c5bb4e93b06 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -269,12 +269,14 @@ builtin.module {
 
 // CHECK-LABEL: @replace_block_arg_1_to_n
 func.func @replace_block_arg_1_to_n() {
-  // CHECK: "test.block_arg_replace"
-  "test.block_arg_replace"() ({
+  // CHECK: "test.legal_op"
+  "test.legal_op"() ({
   ^bb0(%arg0: i32, %arg1: i16):
-    // CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
-    // CHECK: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32
+    // CHECK-NEXT: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16):
+    // CHECK-NEXT: %[[cast:.*]] = "test.cast"(%[[ARG1]], %[[ARG1]]) : (i16, i16) -> i32
+    // CHECK-NEXT: "test.value_replace"(%[[cast]], %[[ARG1]]) {is_legal} : (i32, i16) -> ()
     // CHECK-NEXT: "test.return"(%[[cast]]) : (i32)
+    "test.value_replace"(%arg0, %arg1) : (i32, i16) -> ()
     "test.return"(%arg0) : (i32) -> ()
   }) : () -> ()
   "test.return"() : () -> ()
@@ -282,6 +284,22 @@ func.func @replace_block_arg_1_to_n() {
 
 // -----
 
+// CHECK-LABEL: @replace_op_result_1_to_n
+func.func @replace_op_result_1_to_n() {
+  // CHECK: %[[orig:.*]] = "test.legal_op"() : () -> i32
+  // CHECK: %[[repl:.*]] = "test.legal_op"() : () -> i16
+  %0 = "test.legal_op"() : () -> i32
+  %1 = "test.legal_op"() : () -> i16
+
+  // CHECK-NEXT: %[[cast:.*]] = "test.cast"(%[[repl]], %[[repl]]) : (i16, i16) -> i32
+  // CHECK-NEXT: "test.value_replace"(%[[cast]], %[[repl]]) {is_legal} : (i32, i16) -> ()
+  // CHECK-NEXT: "test.return"(%[[cast]]) : (i32)
+  "test.value_replace"(%0, %1) : (i32, i16) -> ()
+  "test.return"(%0) : (i32) -> ()
+}
+
+// -----
+
 // Check that a conversion pattern on `test.blackhole` can mark the producer
 // for deletion.
 // CHECK-LABEL: @blackhole
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 95f381ec471d6..93b007c792ad9 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -952,19 +952,19 @@ struct TestCreateIllegalBlock : public RewritePattern {
   }
 };
 
-/// A simple pattern that tests the "replaceUsesOfBlockArgument" API.
-struct TestBlockArgReplace : public ConversionPattern {
-  TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter)
-      : ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1,
-                          ctx) {}
+/// A simple pattern that tests the "replaceAllUsesWith" API.
+struct TestValueReplace : public ConversionPattern {
+  TestValueReplace(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.value_replace", /*benefit=*/1, ctx) {
+  }
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    // Replace the first block argument with 2x the second block argument.
-    Value repl = op->getRegion(0).getArgument(1);
-    rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
-                                        {repl, repl});
+    // Replace the first operand with 2x the second operand.
+    Value from = op->getOperand(0);
+    Value repl = op->getOperand(1);
+    rewriter.replaceAllUsesWith(from, {repl, repl});
     rewriter.modifyOpInPlace(op, [&] {
       // If the "trigger_rollback" attribute is set, keep the op illegal, so
       // that a rollback is triggered.
@@ -1515,7 +1515,7 @@ struct TestLegalizePatternDriver
         TestRepetitive1ToNConsumer>(&getContext());
     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
                  TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
-                 TestBlockArgReplace, TestReplaceWithValidConsumer,
+                 TestValueReplace, TestReplaceWithValidConsumer,
                  TestTypeConsumerOpPattern>(&getContext(), converter);
     patterns.add<TestConvertBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
@@ -1543,7 +1543,7 @@ struct TestLegalizePatternDriver
     target.addDynamicallyLegalOp<func::CallOp>(
         [&](func::CallOp op) { return converter.isLegal(op); });
     target.addDynamicallyLegalOp(
-        OperationName("test.block_arg_replace", &getContext()),
+        OperationName("test.value_replace", &getContext()),
         [](Operation *op) { return op->hasAttr("is_legal"); });
 
     // TestCreateUnregisteredOp creates `arith.constant` operation,



More information about the Mlir-commits mailing list