[llvm-branch-commits] [mlir] [mlir][Transforms] Support `replaceAllUsesWith` in dialect conversion (PR #84725)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Mar 11 00:33:07 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/84725

This commit adds support for `RewriterBase::replaceAllUsesWith` to the dialect conversion. Uses are not replaced immediately; instead, the replacement is delayed until the "commit" phase. No type conversions are performed, which is consistent with the behavior of `ConversionPatternRewriter::replaceUsesOfBlockArgument`.

- `RewriterBase::replaceAllUsesWith` is now virtual, so that it can be overridden in the dialect conversion. Note: `RewriterBase::replaceOp` can now be turned into a non-virtual function in a follow-up commit.
- `ConversionPatternRewriter::replaceUsesOfBlockArgument` is generalized to `ConversionPatternRewriter::replaceAllUsesWith`, following the same implementation strategy.
- A new kind of "IR rewrite" is added: `ValueRewrite` with `ReplaceAllUsesRewrite` (replacing `ReplaceBlockArgRewrite`) as the only value rewrite for now.
- `replacedOps` is renamed to `erasedOps` to better capture its meaning.


>From 15c5ef4723628eae7dbfc1f1738a69f641dd5cc8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 11 Mar 2024 07:31:49 +0000
Subject: [PATCH] [mlir][Transforms] Support `replaceAllUsesWith` in dialect
 conversion

This commit adds support for `RewriterBase::replaceAllUsesWith` to the dialect conversion. Uses are not replaced immediately; instead, the replacement is delayed until the "commit" phase. No type conversions are performed, which is consistent with the behavior of `ConversionPatternRewriter::replaceUsesOfBlockArgument`.

- `RewriterBase::replaceAllUsesWith` is now virtual, so that it can be overridden in the dialect conversion. Note: `RewriterBase::replaceOp` can now be turned into a non-virtual function in a follow-up commit.
- `ConversionPatternRewriter::replaceUsesOfBlockArgument` is generalized to `ConversionPatternRewriter::replaceAllUsesWith`, following the same implementation strategy.
- A new kind of "IR rewrite" is added: `ValueRewrite` with `ReplaceAllUsesRewrite` (replacing `ReplaceBlockArgRewrite`) as the only value rewrite for now.
- `replacedOps` is renamed to `erasedOps` to better capture its meaning.

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
 mlir/include/mlir/IR/PatternMatch.h           |   2 +-
 .../mlir/Transforms/DialectConversion.h       |   8 +-
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp |   2 +-
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   |   2 +-
 .../Transforms/Utils/DialectConversion.cpp    | 209 ++++++++++--------
 mlir/test/Transforms/test-legalizer.mlir      |  18 ++
 mlir/test/lib/Dialect/Test/TestOps.td         |   1 +
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  27 ++-
 8 files changed, 172 insertions(+), 97 deletions(-)

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2be1e2e2b40276..3e11e00b9d4b40 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -614,7 +614,7 @@ class RewriterBase : public OpBuilder {
 
   /// Find uses of `from` and replace them with `to`. Also notify the listener
   /// about every in-place op modification (for every use that was replaced).
-  void replaceAllUsesWith(Value from, Value to) {
+  virtual void replaceAllUsesWith(Value from, Value to) {
     for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
       Operation *op = operand.getOwner();
       modifyOpInPlace(op, [&]() { operand.set(to); });
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 83198c9b0db545..1797ee0876e437 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -697,9 +697,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
       Region *region, const TypeConverter &converter,
       ArrayRef<TypeConverter::SignatureConversion> blockConversions);
 
-  /// Replace all the uses of the block argument `from` with value `to`.
-  void replaceUsesOfBlockArgument(BlockArgument from, Value to);
-
   /// Return the converted value of 'key' with a type defined by the type
   /// converter of the currently executing pattern. Return nullptr in the case
   /// of failure, the remapped value otherwise.
@@ -720,6 +717,11 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// patterns even if a failure is encountered during the rewrite step.
   bool canRecoverFromRewriteFailure() const override { return true; }
 
+  /// Find uses of `from` and replace them with `to`.
+  ///
+  /// Note: This function does not convert types.
+  void replaceAllUsesWith(Value from, Value to) override;
+
   /// PatternRewriter hook for replacing an operation.
   void replaceOp(Operation *op, ValueRange newValues) override;
 
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 53b44aa3241bb1..d7ed9a196e8938 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -310,7 +310,7 @@ static void modifyFuncOpToUseBarePtrCallingConv(
     Location loc = funcOp.getLoc();
     auto placeholder = rewriter.create<LLVM::UndefOp>(
         loc, typeConverter.convertType(memrefTy));
-    rewriter.replaceUsesOfBlockArgument(arg, placeholder);
+    rewriter.replaceAllUsesWith(arg, placeholder);
 
     Value desc = MemRefDescriptor::fromStaticShape(rewriter, loc, typeConverter,
                                                    memrefTy, arg);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 73d418cb841327..c6d2ddac9dbb19 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -201,7 +201,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
           llvmFuncOp.getBody().getArgument(remapping->inputNo);
       auto placeholder = rewriter.create<LLVM::UndefOp>(
           loc, getTypeConverter()->convertType(memrefTy));
-      rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
+      rewriter.replaceAllUsesWith(newArg, placeholder);
       Value desc = MemRefDescriptor::fromStaticShape(
           rewriter, loc, *getTypeConverter(), memrefTy, newArg);
       rewriter.replaceOp(placeholder, {desc});
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index c1a261eab8487d..e4a022b7a0288b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -153,9 +153,9 @@ namespace {
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
   RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
-                unsigned numReplacedOps)
+                unsigned numErasedOps)
       : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
-        numReplacedOps(numReplacedOps) {}
+        numErasedOps(numErasedOps) {}
 
   /// The current number of rewrites performed.
   unsigned numRewrites;
@@ -163,8 +163,8 @@ struct RewriterState {
   /// The current number of ignored operations.
   unsigned numIgnoredOperations;
 
-  /// The current number of replaced ops that are scheduled for erasure.
-  unsigned numReplacedOps;
+  /// The current number of ops that are scheduled for erasure.
+  unsigned numErasedOps;
 };
 
 //===----------------------------------------------------------------------===//
@@ -190,13 +190,14 @@ class IRRewrite {
     InlineBlock,
     MoveBlock,
     BlockTypeConversion,
-    ReplaceBlockArg,
     // Operation rewrites
     MoveOperation,
     ModifyOperation,
     ReplaceOperation,
     CreateOperation,
-    UnresolvedMaterialization
+    UnresolvedMaterialization,
+    // Value rewrites
+    ReplaceAllUses
   };
 
   virtual ~IRRewrite() = default;
@@ -243,7 +244,7 @@ class BlockRewrite : public IRRewrite {
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() >= Kind::CreateBlock &&
-           rewrite->getKind() <= Kind::ReplaceBlockArg;
+           rewrite->getKind() <= Kind::BlockTypeConversion;
   }
 
 protected:
@@ -487,27 +488,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   const TypeConverter *converter;
 };
 
-/// Replacing a block argument. This rewrite is not immediately reflected in the
-/// IR. An internal IR mapping is updated, but the actual replacement is delayed
-/// until the rewrite is committed.
-class ReplaceBlockArgRewrite : public BlockRewrite {
-public:
-  ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
-                         Block *block, BlockArgument arg)
-      : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
-
-  static bool classof(const IRRewrite *rewrite) {
-    return rewrite->getKind() == Kind::ReplaceBlockArg;
-  }
-
-  void commit(RewriterBase &rewriter) override;
-
-  void rollback() override;
-
-private:
-  BlockArgument arg;
-};
-
 /// An operation rewrite.
 class OperationRewrite : public IRRewrite {
 public:
@@ -751,6 +731,44 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
   /// The original output type. This is only used for argument conversions.
   Type origOutputType;
 };
+
+/// A value rewrite.
+class ValueRewrite : public IRRewrite {
+public:
+  /// Return the value that this rewrite operates on.
+  Value getValue() const { return value; }
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() >= Kind::ReplaceAllUses &&
+           rewrite->getKind() <= Kind::ReplaceAllUses;
+  }
+
+protected:
+  ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+               Value value)
+      : IRRewrite(kind, rewriterImpl), value(value) {}
+
+  /// The value that this rewrite operates on.
+  Value value;
+};
+
+/// Replacing a value. This rewrite is not immediately reflected in the IR. An
+/// internal IR mapping is updated, but the actual replacement is delayed until
+/// the rewrite is committed.
+class ReplaceAllUsesRewrite : public ValueRewrite {
+public:
+  ReplaceAllUsesRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                        Value value)
+      : ValueRewrite(Kind::ReplaceAllUses, rewriterImpl, value) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::ReplaceAllUses;
+  }
+
+  void commit(RewriterBase &rewriter) override;
+
+  void rollback() override;
+};
 } // namespace
 
 /// Return "true" if there is an operation rewrite that matches the specified
@@ -832,8 +850,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// converted.
   bool isOpIgnored(Operation *op) const;
 
-  /// Return "true" if the given operation was replaced or erased.
-  bool wasOpReplaced(Operation *op) const;
+  /// Return "true" if the given operation is scheduled for erasure. (It may
+  /// still be visible in the IR, but should not be accessed.)
+  bool wasOpErased(Operation *op) const;
 
   //===--------------------------------------------------------------------===//
   // Type Conversion
@@ -982,11 +1001,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// tracked separately.
   SetVector<Operation *> ignoredOps;
 
-  /// A set of operations that were replaced/erased. Such ops are not erased
-  /// immediately but only when the dialect conversion succeeds. In the mean
-  /// time, they should no longer be considered for legalization and any attempt
-  /// to modify/access them is invalid rewriter API usage.
-  SetVector<Operation *> replacedOps;
+  /// A set of operations that were erased. Such ops are not erased immediately
+  /// but only when the dialect conversion succeeds. In the mean time, they
+  /// should no longer be considered for legalization and any attempt to
+  /// modify/access them is invalid rewriter API usage.
+  SetVector<Operation *> erasedOps;
 
   /// The current type converter, or nullptr if no type converter is currently
   /// active.
@@ -1099,13 +1118,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
   return success();
 }
 
-void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
-  Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
-  if (!repl)
-    return;
+void ReplaceAllUsesRewrite::commit(RewriterBase &rewriter) {
+  Value repl = rewriterImpl.mapping.lookupOrNull(value);
+  assert(repl && "expected that value is mapped");
 
   if (isa<BlockArgument>(repl)) {
-    rewriter.replaceAllUsesWith(arg, repl);
+    rewriter.replaceAllUsesWith(value, repl);
     return;
   }
 
@@ -1114,13 +1132,13 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
   // replacement value.
   Operation *replOp = cast<OpResult>(repl).getOwner();
   Block *replBlock = replOp->getBlock();
-  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
+  rewriter.replaceUsesWithIf(value, repl, [&](OpOperand &operand) {
     Operation *user = operand.getOwner();
     return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
   });
 }
 
-void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
+void ReplaceAllUsesRewrite::rollback() { rewriterImpl.mapping.erase(value); }
 
 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   auto *listener = dyn_cast_or_null<RewriterBase::ForwardingListener>(
@@ -1205,7 +1223,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 // State Management
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
-  return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
+  return RewriterState(rewrites.size(), ignoredOps.size(), erasedOps.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1216,8 +1234,8 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
   while (ignoredOps.size() != state.numIgnoredOperations)
     ignoredOps.pop_back();
 
-  while (replacedOps.size() != state.numReplacedOps)
-    replacedOps.pop_back();
+  while (erasedOps.size() != state.numErasedOps)
+    erasedOps.pop_back();
 }
 
 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
@@ -1282,13 +1300,13 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
 }
 
 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
-  // Check to see if this operation is ignored or was replaced.
-  return replacedOps.count(op) || ignoredOps.count(op);
+  // Check to see if this operation is ignored or was erased.
+  return erasedOps.count(op) || ignoredOps.count(op);
 }
 
-bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
-  // Check to see if this operation was replaced.
-  return replacedOps.count(op);
+bool ConversionPatternRewriterImpl::wasOpErased(Operation *op) const {
+  // Check to see if this operation was scheduled for erasure.
+  return erasedOps.count(op);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1434,7 +1452,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
       mapping.map(origArg, inputMap->replacementValue);
-      appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+      appendRewrite<ReplaceAllUsesRewrite>(origArg);
       continue;
     }
 
@@ -1469,7 +1487,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     }
 
     mapping.map(origArg, newArg);
-    appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+    appendRewrite<ReplaceAllUsesRewrite>(origArg);
     argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
   }
 
@@ -1535,8 +1553,8 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
                        << ")\n";
   });
-  assert(!wasOpReplaced(op->getParentOp()) &&
-         "attempting to insert into a block within a replaced/erased op");
+  assert(!wasOpErased(op->getParentOp()) &&
+         "attempting to insert into a block within an erased op");
 
   if (!previous.isSet()) {
     // This is a newly created op.
@@ -1571,8 +1589,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
   appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
                                          resultChanged);
 
-  // Mark this operation and all nested ops as replaced.
-  op->walk([&](Operation *op) { replacedOps.insert(op); });
+  // Mark this operation and all nested ops as erased.
+  op->walk([&](Operation *op) { erasedOps.insert(op); });
 }
 
 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
@@ -1583,8 +1601,8 @@ void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
 
 void ConversionPatternRewriterImpl::notifyBlockInserted(
     Block *block, Region *previous, Region::iterator previousIt) {
-  assert(!wasOpReplaced(block->getParentOp()) &&
-         "attempting to insert into a region within a replaced/erased op");
+  assert(!wasOpErased(block->getParentOp()) &&
+         "attempting to insert into a region within an erased op");
   LLVM_DEBUG(
       {
         Operation *parent = block->getParentOp();
@@ -1660,8 +1678,8 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
 }
 
 void ConversionPatternRewriter::eraseBlock(Block *block) {
-  assert(!impl->wasOpReplaced(block->getParentOp()) &&
-         "attempting to erase a block within a replaced/erased op");
+  assert(!impl->wasOpErased(block->getParentOp()) &&
+         "attempting to erase a block within an erased op");
 
   // Mark all ops for erasure.
   for (Operation &op : *block)
@@ -1678,41 +1696,59 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
 Block *ConversionPatternRewriter::applySignatureConversion(
     Region *region, TypeConverter::SignatureConversion &conversion,
     const TypeConverter *converter) {
-  assert(!impl->wasOpReplaced(region->getParentOp()) &&
-         "attempting to apply a signature conversion to a block within a "
-         "replaced/erased op");
+  assert(!impl->wasOpErased(region->getParentOp()) &&
+         "attempting to apply a signature conversion to a block within an "
+         "erased op");
   return impl->applySignatureConversion(*this, region, conversion, converter);
 }
 
 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
     Region *region, const TypeConverter &converter,
     TypeConverter::SignatureConversion *entryConversion) {
-  assert(!impl->wasOpReplaced(region->getParentOp()) &&
-         "attempting to apply a signature conversion to a block within a "
-         "replaced/erased op");
+  assert(!impl->wasOpErased(region->getParentOp()) &&
+         "attempting to apply a signature conversion to a block within an "
+         "erased op");
   return impl->convertRegionTypes(*this, region, converter, entryConversion);
 }
 
 LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
     Region *region, const TypeConverter &converter,
     ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
-  assert(!impl->wasOpReplaced(region->getParentOp()) &&
-         "attempting to apply a signature conversion to a block within a "
-         "replaced/erased op");
+  assert(!impl->wasOpErased(region->getParentOp()) &&
+         "attempting to apply a signature conversion to a block within an "
+         "erased op");
   return impl->convertNonEntryRegionTypes(*this, region, converter,
                                           blockConversions);
 }
 
-void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
-                                                           Value to) {
+void ConversionPatternRewriter::replaceAllUsesWith(Value from, Value to) {
+#ifndef NDEBUG
   LLVM_DEBUG({
-    Operation *parentOp = from.getOwner()->getParentOp();
-    impl->logger.startLine() << "** Replace Argument : '" << from
-                             << "'(in region of '" << parentOp->getName()
-                             << "'(" << from.getOwner()->getParentOp() << ")\n";
+    Block *parentBlock = from.getParentBlock();
+    Operation *parentOp = parentBlock ? parentBlock->getParentOp() : nullptr;
+    impl->logger.startLine() << "** Replace value : '" << from;
+    if (parentOp) {
+      impl->logger.getOStream() << "' (in region of '" << parentOp->getName()
+                                << "'(" << parentOp << ")\n";
+    } else {
+      impl->logger.getOStream() << "' (detached)\n";
+    }
   });
-  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
-  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
+  if (OpResult opResult = dyn_cast<OpResult>(from)) {
+    assert(!impl->wasOpErased(opResult.getDefiningOp()) &&
+           "attempting to replace an OpResult defined by an erased op");
+  }
+  if (OpResult opResult = dyn_cast<OpResult>(to)) {
+    assert(!impl->wasOpErased(opResult.getDefiningOp()) &&
+           "attempting to replace with an OpResult defined by an erased op");
+  }
+  // A value cannot be replaced multiple times. That would likely require a more
+  // fine-grained tracking of replacements (i.e., each use must be tracked).
+  assert(!impl->mapping.lookupOrNull(from) && "value was already replaced");
+#endif // NDEBUG
+
+  impl->appendRewrite<ReplaceAllUsesRewrite>(from);
+  impl->mapping.map(from, to);
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
@@ -1738,10 +1774,10 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
 #ifndef NDEBUG
   assert(argValues.size() == source->getNumArguments() &&
          "incorrect # of argument replacement values");
-  assert(!impl->wasOpReplaced(source->getParentOp()) &&
-         "attempting to inline a block from a replaced/erased op");
-  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
-         "attempting to inline a block into a replaced/erased op");
+  assert(!impl->wasOpErased(source->getParentOp()) &&
+         "attempting to inline a block from an erased op");
+  assert(!impl->wasOpErased(dest->getParentOp()) &&
+         "attempting to inline a block into an erased op");
   auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
   // The source block will be deleted, so it should not have any users (i.e.,
   // there should be no predecessors).
@@ -1762,7 +1798,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
 
   // Replace all uses of block arguments.
   for (auto it : llvm::zip(source->getArguments(), argValues))
-    replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
+    replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
 
   if (fastPath) {
     // Move all ops at once.
@@ -1778,8 +1814,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
 }
 
 void ConversionPatternRewriter::startOpModification(Operation *op) {
-  assert(!impl->wasOpReplaced(op) &&
-         "attempting to modify a replaced/erased op");
+  assert(!impl->wasOpErased(op) && "attempting to modify an erased op");
 #ifndef NDEBUG
   impl->pendingRootUpdates.insert(op);
 #endif
@@ -1787,8 +1822,7 @@ void ConversionPatternRewriter::startOpModification(Operation *op) {
 }
 
 void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
-  assert(!impl->wasOpReplaced(op) &&
-         "attempting to modify a replaced/erased op");
+  assert(!impl->wasOpErased(op) && "attempting to modify an erased op");
   PatternRewriter::finalizeOpModification(op);
   // There is nothing to do here, we only need to track the operation at the
   // start of the update.
@@ -2204,8 +2238,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
     if (!rewrite)
       continue;
     Block *block = rewrite->getBlock();
-    if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
-            ReplaceBlockArgRewrite>(rewrite))
+    if (isa<BlockTypeConversionRewrite, EraseBlockRewrite>(rewrite))
       continue;
     // Only check blocks outside of the current operation.
     Operation *parentOp = block->getParentOp();
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d552f0346644b3..78dc3f988a45ab 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -427,3 +427,21 @@ func.func @use_of_replaced_bbarg(%arg0: i64) {
   }) : (i64) -> (i64)
   "test.invalid"(%0) : (i64) -> ()
 }
+
+// -----
+
+// CHECK: notifyOperationInserted: test.legal_op_b, was unlinked
+// CHECK: notifyOperationModified: test.valid
+// CHECK: notifyOperationModified: test.illegal_op_h
+
+// CHECK-LABEL: func @replace_all_uses_with()
+func.func @replace_all_uses_with() {
+  // CHECK: %[[legal:.*]] = "test.legal_op_b"() : () -> i32
+  // CHECK: %[[illegal:.*]] = "test.illegal_op_h"() {not_illegal} : () -> i64
+  %result = "test.illegal_op_h"() : () -> (i64)
+
+  // replaceAllUsesWith does not perform any type conversion. The uses are
+  // directly updated during the commit phase.
+  // CHECK: "test.valid"(%[[legal]]) : (i32) -> ()
+  "test.valid"(%result) : (i64) -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index dfd2f21a5ea249..c19b0d2bc43c8f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1856,6 +1856,7 @@ def ILLegalOpD : TEST_Op<"illegal_op_d">, Results<(outs I32)>;
 def ILLegalOpE : TEST_Op<"illegal_op_e">, Results<(outs I32)>;
 def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32)>;
 def ILLegalOpG : TEST_Op<"illegal_op_g">, Results<(outs I32)>;
+def ILLegalOpH : TEST_Op<"illegal_op_h">, Results<(outs I64)>;
 def LegalOpA : TEST_Op<"legal_op_a">,
   Arguments<(ins StrAttr:$status)>, Results<(outs I32)>;
 def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>;
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 76dc825fe44515..718fbf10f59883 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -785,8 +785,8 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
                   ConversionPatternRewriter &rewriter) const final {
     auto illegalOp =
         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
-    rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
-                                        illegalOp->getResult(0));
+    rewriter.replaceAllUsesWith(op->getRegion(0).getArgument(0),
+                                illegalOp->getResult(0));
     rewriter.modifyOpInPlace(op, [] {});
     return success();
   }
@@ -840,6 +840,24 @@ struct TestUndoPropertiesModification : public ConversionPattern {
   }
 };
 
+/// A pattern that replaces all uses of illegal_op_h with a newly created op
+/// that has one i32 result. The old op is marked as "legal".
+struct ReplaceAllUsesOfIllegalOp : public ConversionPattern {
+  ReplaceAllUsesOfIllegalOp(MLIRContext *context)
+      : ConversionPattern("test.illegal_op_h", /*benefit=*/1, context) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    Operation *legalOp =
+        rewriter.create<LegalOpB>(op->getLoc(), rewriter.getIntegerType(32));
+    rewriter.replaceAllOpUsesWith(op, legalOp);
+    rewriter.modifyOpInPlace(
+        op, [&] { op->setAttr("not_illegal", rewriter.getUnitAttr()); });
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Type-Conversion Rewrite Testing
 
@@ -1120,7 +1138,8 @@ struct TestLegalizePatternDriver
              TestNonRootReplacement, TestBoundedRecursiveRewrite,
              TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
              TestCreateUnregisteredOp, TestUndoMoveOpBefore,
-             TestUndoPropertiesModification>(&getContext());
+             TestUndoPropertiesModification, ReplaceAllUsesOfIllegalOp>(
+            &getContext());
     patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1133,6 +1152,8 @@ struct TestLegalizePatternDriver
                       TerminatorOp, OneRegionOp>();
     target
         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
+    target.addDynamicallyLegalOp<ILLegalOpH>(
+        [](ILLegalOpH op) { return op->hasAttr("not_illegal"); });
     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
       // Don't allow F32 operands.
       return llvm::none_of(op.getOperandTypes(),



More information about the llvm-branch-commits mailing list