[Mlir-commits] [mlir] [mlir][IR] Make `RewriterBase::replaceOp` non-virtual (PR #160529)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 24 06:56:55 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

`RewriterBase::replaceOp` used to be virtual so that `ConversionPatternRewriter` could override it. This is no longer necessary: both `replaceAllUsesWith` and `eraseOp` are virtual, and `replaceOp` is simply a combination of these two functions.

Implementation details in the dialect conversion:
* `ReplaceOperationRewrite` is now just a placeholder IRRewrite that notifies the listener. (Previously, it was also created when erasing an op: erasure was treated as replacing the op with only null values.)
* A new `EraseOperationRewrite` was added, which is used when erasing an operation. In rollback mode, an operation can be erased while it still has uses; in that case, source materializations must be created "out of thin air". This is now done in `EraseOperationRewrite::commit`; it used to be done in `ReplaceOperationRewrite::commit`.
* In "no rollback" mode, erasing an operation that still has uses is no longer allowed. That is why the `circular_mapping` test case has been moved to a different file.
* The `getReplacementValues` helper function has been folded into `ConversionPatternRewriterImpl::replaceAllUsesWith`. (Its loop is no longer needed, because `replaceAllUsesWith` replaces a single value at a time.)


---

Patch is 24.76 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160529.diff


5 Files Affected:

- (modified) mlir/include/mlir/IR/PatternMatch.h (+2-2) 
- (modified) mlir/include/mlir/Transforms/DialectConversion.h (+1-15) 
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+144-162) 
- (modified) mlir/test/Transforms/test-legalizer-rollback.mlir (+27) 
- (modified) mlir/test/Transforms/test-legalizer.mlir (+1-26) 


``````````diff
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 576481a6e7215..208923cc24224 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -515,12 +515,12 @@ class RewriterBase : public OpBuilder {
   /// Replace the results of the given (original) operation with the specified
   /// list of values (replacements). The result types of the given op and the
   /// replacements must match. The original op is erased.
-  virtual void replaceOp(Operation *op, ValueRange newValues);
+  void replaceOp(Operation *op, ValueRange newValues);
 
   /// Replace the results of the given (original) operation with the specified
   /// new op (replacement). The result types of the two ops must match. The
   /// original op is erased.
-  virtual void replaceOp(Operation *op, Operation *newOp);
+  void replaceOp(Operation *op, Operation *newOp);
 
   /// Replace the results of the given (original) op with a new op that is
   /// created without verification (replacement). The result values of the two
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index ed7e2a08ebfd9..6c5e212058d17 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -900,7 +900,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// RewriterBase APIs, (3) may be removed in the future.
   void replaceAllUsesWith(Value from, ValueRange to);
   void replaceAllUsesWith(Value from, Value to) override {
-    replaceAllUsesWith(from, ValueRange{to});
+    replaceAllUsesWith(from, to ? ValueRange{to} : ValueRange{});
   }
 
   /// Return the converted value of 'key' with a type defined by the type
@@ -923,20 +923,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// patterns even if a failure is encountered during the rewrite step.
   bool canRecoverFromRewriteFailure() const override { return true; }
 
-  /// Replace the given operation with the new values. The number of op results
-  /// and replacement values must match. The types may differ: the dialect
-  /// conversion driver will reconcile any surviving type mismatches at the end
-  /// of the conversion process with source materializations. The given
-  /// operation is erased.
-  void replaceOp(Operation *op, ValueRange newValues) override;
-
-  /// Replace the given operation with the results of the new op. The number of
-  /// op results must match. The types may differ: the dialect conversion
-  /// driver will reconcile any surviving type mismatches at the end of the
-  /// conversion process with source materializations. The original operation
-  /// is erased.
-  void replaceOp(Operation *op, Operation *newOp) override;
-
   /// Replace the given operation with the new value ranges. The number of op
   /// results and value ranges must match. The given  operation is erased.
   void replaceOpWithMultiple(Operation *op,
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index bf0136b39e03c..1f3a8132f47be 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -67,7 +67,7 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) {
 
 /// Helper function that computes an insertion point where the given values are
 /// defined and can be used without a dominance violation.
-static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
+static OpBuilder::InsertPoint computeInsertPoint(ValueRange vals) {
   assert(!vals.empty() && "expected at least one value");
   DominanceInfo domInfo;
   OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
@@ -281,6 +281,7 @@ class IRRewrite {
     MoveOperation,
     ModifyOperation,
     ReplaceOperation,
+    EraseOperation,
     CreateOperation,
     UnresolvedMaterialization,
     // Value rewrites
@@ -720,16 +721,13 @@ class ModifyOperationRewrite : public OperationRewrite {
   void *propertiesStorage = nullptr;
 };
 
-/// Replacing an operation. Erasing an operation is treated as a special case
-/// with "null" replacements. This rewrite is not immediately reflected in the
-/// IR. An internal IR mapping is updated, but values are not replaced and the
-/// original op is not erased until the rewrite is committed.
 class ReplaceOperationRewrite : public OperationRewrite {
 public:
   ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
-                          Operation *op, const TypeConverter *converter)
+                          Operation *op, Operation *replOp,
+                          ValueRange replValues)
       : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
-        converter(converter) {}
+        replOp(replOp), replValues(replValues) {}
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::ReplaceOperation;
@@ -737,6 +735,28 @@ class ReplaceOperationRewrite : public OperationRewrite {
 
   void commit(RewriterBase &rewriter) override;
 
+  void rollback() override {}
+
+private:
+  Operation *replOp;
+  ValueRange replValues;
+};
+
+/// Erasing an operation. This rewrite is not immediately reflected in the
+/// IR. The original op is not erased until the rewrite is committed.
+class EraseOperationRewrite : public OperationRewrite {
+public:
+  EraseOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                        Operation *op, const TypeConverter *converter)
+      : OperationRewrite(Kind::EraseOperation, rewriterImpl, op),
+        converter(converter) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::EraseOperation;
+  }
+
+  void commit(RewriterBase &rewriter) override;
+
   void rollback() override;
 
   void cleanup(RewriterBase &rewriter) override;
@@ -948,15 +968,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
       Block *block, const TypeConverter *converter,
       TypeConverter::SignatureConversion &signatureConversion);
 
-  /// Replace the results of the given operation with the given values and
-  /// erase the operation.
-  ///
-  /// There can be multiple replacement values for each result (1:N
-  /// replacement). If the replacement values are empty, the respective result
-  /// is dropped and a source materialization is built if the result still has
-  /// uses.
-  void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
-
   /// Replace the uses of the given value with the given values. The specified
   /// converter is used to build materializations (if necessary).
   void replaceAllUsesWith(Value from, ValueRange to,
@@ -965,6 +976,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// Erase the given block and its contents.
   void eraseBlock(Block *block);
 
+  /// Erase the given operation and its contents.
+  void eraseOp(Operation *op);
+
   /// Inline the source block into the destination block before the given
   /// iterator.
   void inlineBlockBefore(Block *source, Block *dest, Block::iterator before);
@@ -1016,6 +1030,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   void notifyBlockInserted(Block *block, Region *previous,
                            Region::iterator previousIt) override;
 
+  /// Notifies that an operation is about to be replaced with another operation.
+  void notifyOperationReplaced(Operation *op, Operation *replacement) override;
+
+  /// Notifies that an operation is about to be replaced with a range of values.
+  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
+
   /// Notifies that a pattern match failed for the given reason.
   void
   notifyMatchFailure(Location loc,
@@ -1242,18 +1262,25 @@ void ReplaceValueRewrite::rollback() {
 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   auto *listener =
       dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
+  if (!listener)
+    return;
 
+  // Notify the listener that the operation is about to be replaced.
+  if (replOp) {
+    listener->notifyOperationReplaced(op, replOp);
+  } else {
+    listener->notifyOperationReplaced(op, replValues);
+  }
+}
+
+void EraseOperationRewrite::commit(RewriterBase &rewriter) {
   // Compute replacement values.
   SmallVector<Value> replacements =
       llvm::map_to_vector(op->getResults(), [&](OpResult result) {
         return rewriterImpl.findOrBuildReplacementValue(result, converter);
       });
 
-  // Notify the listener that the operation is about to be replaced.
-  if (listener)
-    listener->notifyOperationReplaced(op, replacements);
-
-  // Replace all uses with the new values.
+  // Replace all uses with the new values (if any).
   for (auto [result, newValue] :
        llvm::zip_equal(op->getResults(), replacements))
     if (newValue)
@@ -1265,7 +1292,8 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
     getConfig().unlegalizedOps->erase(op);
 
   // Notify the listener that the operation and its contents are being erased.
-  if (listener)
+  if (auto *listener =
+          dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
     notifyIRErased(listener, *op);
 
   // Do not erase the operation yet. It may still be referenced in `mapping`.
@@ -1273,12 +1301,12 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   op->getBlock()->getOperations().remove(op);
 }
 
-void ReplaceOperationRewrite::rollback() {
+void EraseOperationRewrite::rollback() {
   for (auto result : op->getResults())
     rewriterImpl.mapping.erase({result});
 }
 
-void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
+void EraseOperationRewrite::cleanup(RewriterBase &rewriter) {
   rewriter.eraseOp(op);
 }
 
@@ -1803,122 +1831,41 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
     appendRewrite<MoveOperationRewrite>(op, previous);
 }
 
-/// Given that `fromRange` is about to be replaced with `toRange`, compute
-/// replacement values with the types of `fromRange`.
-static SmallVector<Value>
-getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange,
-                     const SmallVector<SmallVector<Value>> &toRange,
-                     const TypeConverter *converter) {
-  assert(!impl.config.allowPatternRollback &&
-         "this code path is valid only in 'no rollback' mode");
-  SmallVector<Value> repls;
-  for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
+void ConversionPatternRewriterImpl::replaceAllUsesWith(
+    Value from, ValueRange to, const TypeConverter *converter) {
+  if (!config.allowPatternRollback) {
+    // In "no rollback" mode, IR changes are materialized immediately.
+
     if (from.use_empty()) {
       // The replaced value is dead. No replacement value is needed.
-      repls.push_back(Value());
-      continue;
+      return;
     }
 
+    Value repl;
     if (to.empty()) {
       // The replaced value is dropped. Materialize a replacement value "out of
       // thin air".
-      Value srcMat = impl.buildUnresolvedMaterialization(
+      repl = buildUnresolvedMaterialization(
           MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
           /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
           /*outputTypes=*/from.getType(), /*originalType=*/Type(),
           converter)[0];
-      repls.push_back(srcMat);
-      continue;
-    }
-
-    if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
+    } else if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
       // The replacement value already has the correct type. Use it directly.
-      repls.push_back(to[0]);
-      continue;
+      repl = to[0];
+    } else {
+      // The replacement value has the wrong type. Build a source
+      // materialization to the original type.
+      // TODO: This is a bit inefficient. We should try to reuse existing
+      // materializations if possible. This would require an extension of the
+      // `lookupOrDefault` API.
+      repl = buildUnresolvedMaterialization(
+          MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
+          /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
+          /*originalType=*/Type(), converter)[0];
     }
 
-    // The replacement value has the wrong type. Build a source materialization
-    // to the original type.
-    // TODO: This is a bit inefficient. We should try to reuse existing
-    // materializations if possible. This would require an extension of the
-    // `lookupOrDefault` API.
-    Value srcMat = impl.buildUnresolvedMaterialization(
-        MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
-        /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
-        /*originalType=*/Type(), converter)[0];
-    repls.push_back(srcMat);
-  }
-
-  return repls;
-}
-
-void ConversionPatternRewriterImpl::replaceOp(
-    Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
-  assert(newValues.size() == op->getNumResults() &&
-         "incorrect number of replacement values");
-
-  if (!config.allowPatternRollback) {
-    // Pattern rollback is not allowed: materialize all IR changes immediately.
-    SmallVector<Value> repls = getReplacementValues(
-        *this, op->getResults(), newValues, currentTypeConverter);
-    // Update internal data structures, so that there are no dangling pointers
-    // to erased IR.
-    op->walk([&](Operation *op) {
-      erasedOps.insert(op);
-      ignoredOps.remove(op);
-      if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
-        unresolvedMaterializations.erase(castOp);
-        patternMaterializations.erase(castOp);
-      }
-      // The original op will be erased, so remove it from the set of
-      // unlegalized ops.
-      if (config.unlegalizedOps)
-        config.unlegalizedOps->erase(op);
-    });
-    op->walk([&](Block *block) { erasedBlocks.insert(block); });
-    // Replace the op with the replacement values and notify the listener.
-    notifyingRewriter.replaceOp(op, repls);
-    return;
-  }
-
-  assert(!ignoredOps.contains(op) && "operation was already replaced");
-#ifndef NDEBUG
-  for (Value v : op->getResults())
-    assert(!replacedValues.contains(v) &&
-           "attempting to replace a value that was already replaced");
-#endif // NDEBUG
-
-  // Check if replaced op is an unresolved materialization, i.e., an
-  // unrealized_conversion_cast op that was created by the conversion driver.
-  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
-    // Make sure that the user does not mess with unresolved materializations
-    // that were inserted by the conversion driver. We keep track of these
-    // ops in internal data structures.
-    assert(!unresolvedMaterializations.contains(castOp) &&
-           "attempting to replace/erase an unresolved materialization");
-  }
-
-  // Create mappings for each of the new result values.
-  for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults()))
-    mapping.map(static_cast<Value>(result), std::move(repl));
-
-  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
-  // Mark this operation and all nested ops as replaced.
-  op->walk([&](Operation *op) { replacedOps.insert(op); });
-}
-
-void ConversionPatternRewriterImpl::replaceAllUsesWith(
-    Value from, ValueRange to, const TypeConverter *converter) {
-  if (!config.allowPatternRollback) {
-    SmallVector<Value> toConv = llvm::to_vector(to);
-    SmallVector<Value> repls =
-        getReplacementValues(*this, from, {toConv}, converter);
-    IRRewriter r(from.getContext());
-    Value repl = repls.front();
-    if (!repl)
-      return;
-
-    performReplaceValue(r, from, repl);
+    performReplaceValue(notifyingRewriter, from, repl);
     return;
   }
 
@@ -1978,6 +1925,39 @@ void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
   block->walk([&](Operation *op) { replacedOps.insert(op); });
 }
 
+void ConversionPatternRewriterImpl::eraseOp(Operation *op) {
+  LLVM_DEBUG({
+    logger.startLine() << "** Erase   : '" << op->getName() << "'(" << op
+                       << ")\n";
+  });
+
+  if (!config.allowPatternRollback) {
+    // Pattern rollback is not allowed: materialize all IR changes immediately.
+    // Update internal data structures, so that there are no dangling pointers
+    // to erased IR.
+    op->walk([&](Operation *op) {
+      erasedOps.insert(op);
+      ignoredOps.remove(op);
+      if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
+        unresolvedMaterializations.erase(castOp);
+        patternMaterializations.erase(castOp);
+      }
+      // The original op will be erased, so remove it from the set of
+      // unlegalized ops.
+      if (config.unlegalizedOps)
+        config.unlegalizedOps->erase(op);
+    });
+    op->walk([&](Block *block) { erasedBlocks.insert(block); });
+    // Erase the op (and its nested IR) and notify the listener.
+    notifyingRewriter.eraseOp(op);
+    return;
+  }
+
+  appendRewrite<EraseOperationRewrite>(op, currentTypeConverter);
+  // Mark this operation and all nested ops as replaced.
+  op->walk([&](Operation *op) { replacedOps.insert(op); });
+}
+
 void ConversionPatternRewriterImpl::notifyBlockInserted(
     Block *block, Region *previous, Region::iterator previousIt) {
   // If no previous insertion point is provided, the block used to be detached.
@@ -2030,6 +2010,28 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
     appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
 }
 
+void ConversionPatternRewriterImpl::notifyOperationReplaced(
+    Operation *op, Operation *replacement) {
+  if (config.allowPatternRollback) {
+    // In rollback mode, the listener is notified when the rewrite is applied.
+    appendRewrite<ReplaceOperationRewrite>(op, replacement, ValueRange());
+  } else if (config.listener) {
+    // In "no rollback" mode, the listener is always notified immediately.
+    config.listener->notifyOperationReplaced(op, replacement);
+  }
+}
+
+void ConversionPatternRewriterImpl::notifyOperationReplaced(
+    Operation *op, ValueRange replacement) {
+  if (config.allowPatternRollback) {
+    // In rollback mode, the listener is notified when the rewrite is applied.
+    appendRewrite<ReplaceOperationRewrite>(op, nullptr, replacement);
+  } else if (config.listener) {
+    // In "no rollback" mode, the listener is always notified immediately.
+    config.listener->notifyOperationReplaced(op, replacement);
+  }
+}
+
 void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source,
                                                       Block *dest,
                                                       Block::iterator before) {
@@ -2064,29 +2066,12 @@ const ConversionConfig &ConversionPatternRewriter::getConfig() const {
   return impl->config;
 }
 
-void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
-  assert(op && newOp && "expected non-null op");
-  replaceOp(op, newOp->getResults());
-}
-
-void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
-  assert(op->getNumResults() == newValues.size() &&
-         "incorrect # of replacement values");
-  LLVM_DEBUG({
-    impl->logger.startLine()
-        << "** Replace : '" << op->getName() << "'(" << op << ")\n";
-  });
-
-  // If the current insertion point is before the erased operation, we adjust
-  // the insertion point to be after the operation.
-  if (getInsertionPoint() == op->getIterator())
-    setInsertionPointAfter(op);
-
-  SmallVector<SmallVector<Value>> newVals =
-      llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
-        return v ? SmallVector<Value>{v} : SmallVector<Value>();
-      });
-  impl->replaceOp(op, std::move(newVals));
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<SmallVector<Value>> values) {
+  SmallVector<Value> result;
+  for (const auto &vals : values)
+    llvm::append_range(result, vals);
+  return result;
 }
 
 void ConversionPatternRewriter::replaceOpWithMultiple(
@@ -2098,27 +2083,24 @@ void ConversionPatternRewriter::replaceOpWithMultiple(
         << "** Repl...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/160529


More information about the Mlir-commits mailing list