[Mlir-commits] [mlir] [mlir][IR] Make `RewriterBase::replaceOp` non-virtual (PR #160529)

Matthias Springer llvmlistbot at llvm.org
Wed Sep 24 06:56:21 PDT 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/160529

`RewriterBase::replaceOp` used to be virtual so that `ConversionPatternRewriter` could override it. This is no longer necessary: both `replaceAllUsesWith` and `eraseOp` are virtual, and `replaceOp` is a combination of these two functions.

Implementation details in the dialect conversion:
* `ReplaceOperationRewrite` is now just a placeholder IRRewrite, which notifies the listener when it is committed. (It used to be created also when erasing an op; erasing was treated as replacing an op with only null Values.)
* A new `EraseOperationRewrite` was added, which is used when erasing an operation. In rollback mode, an operation can be erased while it still has uses; in that case, source materializations must be created out of thin air. This is done in `EraseOperationRewrite::commit`. We used to do this in `ReplaceOperationRewrite::commit`.
* In "no rollback" mode, erasing an operation that still has uses is no longer allowed. That's why the `circular_mapping` test case is moved to a different file.
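* The `getReplacementValues` helper function is removed and its logic is inlined into `ConversionPatternRewriterImpl::replaceAllUsesWith`. (The loop over multiple replaced values is no longer needed, because only a single value is replaced at a time.)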


>From 98a9d90bcd9f9d634c95d0e604b4f2a5e2fd4e60 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 12 Sep 2025 09:35:51 +0000
Subject: [PATCH] proto

---
 mlir/include/mlir/IR/PatternMatch.h           |   4 +-
 .../mlir/Transforms/DialectConversion.h       |  16 +-
 .../Transforms/Utils/DialectConversion.cpp    | 306 +++++++++---------
 .../Transforms/test-legalizer-rollback.mlir   |  27 ++
 mlir/test/Transforms/test-legalizer.mlir      |  27 +-
 5 files changed, 175 insertions(+), 205 deletions(-)

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 576481a6e7215..208923cc24224 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -515,12 +515,12 @@ class RewriterBase : public OpBuilder {
   /// Replace the results of the given (original) operation with the specified
   /// list of values (replacements). The result types of the given op and the
   /// replacements must match. The original op is erased.
-  virtual void replaceOp(Operation *op, ValueRange newValues);
+  void replaceOp(Operation *op, ValueRange newValues);
 
   /// Replace the results of the given (original) operation with the specified
   /// new op (replacement). The result types of the two ops must match. The
   /// original op is erased.
-  virtual void replaceOp(Operation *op, Operation *newOp);
+  void replaceOp(Operation *op, Operation *newOp);
 
   /// Replace the results of the given (original) op with a new op that is
   /// created without verification (replacement). The result values of the two
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index ed7e2a08ebfd9..6c5e212058d17 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -900,7 +900,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// RewriterBase APIs, (3) may be removed in the future.
   void replaceAllUsesWith(Value from, ValueRange to);
   void replaceAllUsesWith(Value from, Value to) override {
-    replaceAllUsesWith(from, ValueRange{to});
+    replaceAllUsesWith(from, to ? ValueRange{to} : ValueRange{});
   }
 
   /// Return the converted value of 'key' with a type defined by the type
@@ -923,20 +923,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// patterns even if a failure is encountered during the rewrite step.
   bool canRecoverFromRewriteFailure() const override { return true; }
 
-  /// Replace the given operation with the new values. The number of op results
-  /// and replacement values must match. The types may differ: the dialect
-  /// conversion driver will reconcile any surviving type mismatches at the end
-  /// of the conversion process with source materializations. The given
-  /// operation is erased.
-  void replaceOp(Operation *op, ValueRange newValues) override;
-
-  /// Replace the given operation with the results of the new op. The number of
-  /// op results must match. The types may differ: the dialect conversion
-  /// driver will reconcile any surviving type mismatches at the end of the
-  /// conversion process with source materializations. The original operation
-  /// is erased.
-  void replaceOp(Operation *op, Operation *newOp) override;
-
   /// Replace the given operation with the new value ranges. The number of op
   /// results and value ranges must match. The given  operation is erased.
   void replaceOpWithMultiple(Operation *op,
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index bf0136b39e03c..1f3a8132f47be 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -67,7 +67,7 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) {
 
 /// Helper function that computes an insertion point where the given values are
 /// defined and can be used without a dominance violation.
-static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
+static OpBuilder::InsertPoint computeInsertPoint(ValueRange vals) {
   assert(!vals.empty() && "expected at least one value");
   DominanceInfo domInfo;
   OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
@@ -281,6 +281,7 @@ class IRRewrite {
     MoveOperation,
     ModifyOperation,
     ReplaceOperation,
+    EraseOperation,
     CreateOperation,
     UnresolvedMaterialization,
     // Value rewrites
@@ -720,16 +721,13 @@ class ModifyOperationRewrite : public OperationRewrite {
   void *propertiesStorage = nullptr;
 };
 
-/// Replacing an operation. Erasing an operation is treated as a special case
-/// with "null" replacements. This rewrite is not immediately reflected in the
-/// IR. An internal IR mapping is updated, but values are not replaced and the
-/// original op is not erased until the rewrite is committed.
 class ReplaceOperationRewrite : public OperationRewrite {
 public:
   ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
-                          Operation *op, const TypeConverter *converter)
+                          Operation *op, Operation *replOp,
+                          ValueRange replValues)
       : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
-        converter(converter) {}
+        replOp(replOp), replValues(replValues) {}
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::ReplaceOperation;
@@ -737,6 +735,28 @@ class ReplaceOperationRewrite : public OperationRewrite {
 
   void commit(RewriterBase &rewriter) override;
 
+  void rollback() override {}
+
+private:
+  Operation *replOp;
+  ValueRange replValues;
+};
+
+/// Erasing an operation. This rewrite is not immediately reflected in the
+/// IR. The original op is not erased until the rewrite is committed.
+class EraseOperationRewrite : public OperationRewrite {
+public:
+  EraseOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                        Operation *op, const TypeConverter *converter)
+      : OperationRewrite(Kind::EraseOperation, rewriterImpl, op),
+        converter(converter) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::EraseOperation;
+  }
+
+  void commit(RewriterBase &rewriter) override;
+
   void rollback() override;
 
   void cleanup(RewriterBase &rewriter) override;
@@ -948,15 +968,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
       Block *block, const TypeConverter *converter,
       TypeConverter::SignatureConversion &signatureConversion);
 
-  /// Replace the results of the given operation with the given values and
-  /// erase the operation.
-  ///
-  /// There can be multiple replacement values for each result (1:N
-  /// replacement). If the replacement values are empty, the respective result
-  /// is dropped and a source materialization is built if the result still has
-  /// uses.
-  void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
-
   /// Replace the uses of the given value with the given values. The specified
   /// converter is used to build materializations (if necessary).
   void replaceAllUsesWith(Value from, ValueRange to,
@@ -965,6 +976,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// Erase the given block and its contents.
   void eraseBlock(Block *block);
 
+  /// Erase the given operation and its contents.
+  void eraseOp(Operation *op);
+
   /// Inline the source block into the destination block before the given
   /// iterator.
   void inlineBlockBefore(Block *source, Block *dest, Block::iterator before);
@@ -1016,6 +1030,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   void notifyBlockInserted(Block *block, Region *previous,
                            Region::iterator previousIt) override;
 
+  /// Notifies that an operation is about to be replaced with another operation.
+  void notifyOperationReplaced(Operation *op, Operation *replacement) override;
+
+  /// Notifies that an operation is about to be replaced with a range of values.
+  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
+
   /// Notifies that a pattern match failed for the given reason.
   void
   notifyMatchFailure(Location loc,
@@ -1242,18 +1262,25 @@ void ReplaceValueRewrite::rollback() {
 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   auto *listener =
       dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
+  if (!listener)
+    return;
 
+  // Notify the listener that the operation is about to be replaced.
+  if (replOp) {
+    listener->notifyOperationReplaced(op, replOp);
+  } else {
+    listener->notifyOperationReplaced(op, replValues);
+  }
+}
+
+void EraseOperationRewrite::commit(RewriterBase &rewriter) {
   // Compute replacement values.
   SmallVector<Value> replacements =
       llvm::map_to_vector(op->getResults(), [&](OpResult result) {
         return rewriterImpl.findOrBuildReplacementValue(result, converter);
       });
 
-  // Notify the listener that the operation is about to be replaced.
-  if (listener)
-    listener->notifyOperationReplaced(op, replacements);
-
-  // Replace all uses with the new values.
+  // Replace all uses with the new values (if any).
   for (auto [result, newValue] :
        llvm::zip_equal(op->getResults(), replacements))
     if (newValue)
@@ -1265,7 +1292,8 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
     getConfig().unlegalizedOps->erase(op);
 
   // Notify the listener that the operation and its contents are being erased.
-  if (listener)
+  if (auto *listener =
+          dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
     notifyIRErased(listener, *op);
 
   // Do not erase the operation yet. It may still be referenced in `mapping`.
@@ -1273,12 +1301,12 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   op->getBlock()->getOperations().remove(op);
 }
 
-void ReplaceOperationRewrite::rollback() {
+void EraseOperationRewrite::rollback() {
   for (auto result : op->getResults())
     rewriterImpl.mapping.erase({result});
 }
 
-void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
+void EraseOperationRewrite::cleanup(RewriterBase &rewriter) {
   rewriter.eraseOp(op);
 }
 
@@ -1803,122 +1831,41 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
     appendRewrite<MoveOperationRewrite>(op, previous);
 }
 
-/// Given that `fromRange` is about to be replaced with `toRange`, compute
-/// replacement values with the types of `fromRange`.
-static SmallVector<Value>
-getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange,
-                     const SmallVector<SmallVector<Value>> &toRange,
-                     const TypeConverter *converter) {
-  assert(!impl.config.allowPatternRollback &&
-         "this code path is valid only in 'no rollback' mode");
-  SmallVector<Value> repls;
-  for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
+void ConversionPatternRewriterImpl::replaceAllUsesWith(
+    Value from, ValueRange to, const TypeConverter *converter) {
+  if (!config.allowPatternRollback) {
+    // In "no rollback" mode, IR changes are materialized immediately.
+
     if (from.use_empty()) {
       // The replaced value is dead. No replacement value is needed.
-      repls.push_back(Value());
-      continue;
+      return;
     }
 
+    Value repl;
     if (to.empty()) {
       // The replaced value is dropped. Materialize a replacement value "out of
       // thin air".
-      Value srcMat = impl.buildUnresolvedMaterialization(
+      repl = buildUnresolvedMaterialization(
           MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
           /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
           /*outputTypes=*/from.getType(), /*originalType=*/Type(),
           converter)[0];
-      repls.push_back(srcMat);
-      continue;
-    }
-
-    if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
+    } else if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
       // The replacement value already has the correct type. Use it directly.
-      repls.push_back(to[0]);
-      continue;
+      repl = to[0];
+    } else {
+      // The replacement value has the wrong type. Build a source
+      // materialization to the original type.
+      // TODO: This is a bit inefficient. We should try to reuse existing
+      // materializations if possible. This would require an extension of the
+      // `lookupOrDefault` API.
+      repl = buildUnresolvedMaterialization(
+          MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
+          /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
+          /*originalType=*/Type(), converter)[0];
     }
 
-    // The replacement value has the wrong type. Build a source materialization
-    // to the original type.
-    // TODO: This is a bit inefficient. We should try to reuse existing
-    // materializations if possible. This would require an extension of the
-    // `lookupOrDefault` API.
-    Value srcMat = impl.buildUnresolvedMaterialization(
-        MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
-        /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
-        /*originalType=*/Type(), converter)[0];
-    repls.push_back(srcMat);
-  }
-
-  return repls;
-}
-
-void ConversionPatternRewriterImpl::replaceOp(
-    Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
-  assert(newValues.size() == op->getNumResults() &&
-         "incorrect number of replacement values");
-
-  if (!config.allowPatternRollback) {
-    // Pattern rollback is not allowed: materialize all IR changes immediately.
-    SmallVector<Value> repls = getReplacementValues(
-        *this, op->getResults(), newValues, currentTypeConverter);
-    // Update internal data structures, so that there are no dangling pointers
-    // to erased IR.
-    op->walk([&](Operation *op) {
-      erasedOps.insert(op);
-      ignoredOps.remove(op);
-      if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
-        unresolvedMaterializations.erase(castOp);
-        patternMaterializations.erase(castOp);
-      }
-      // The original op will be erased, so remove it from the set of
-      // unlegalized ops.
-      if (config.unlegalizedOps)
-        config.unlegalizedOps->erase(op);
-    });
-    op->walk([&](Block *block) { erasedBlocks.insert(block); });
-    // Replace the op with the replacement values and notify the listener.
-    notifyingRewriter.replaceOp(op, repls);
-    return;
-  }
-
-  assert(!ignoredOps.contains(op) && "operation was already replaced");
-#ifndef NDEBUG
-  for (Value v : op->getResults())
-    assert(!replacedValues.contains(v) &&
-           "attempting to replace a value that was already replaced");
-#endif // NDEBUG
-
-  // Check if replaced op is an unresolved materialization, i.e., an
-  // unrealized_conversion_cast op that was created by the conversion driver.
-  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
-    // Make sure that the user does not mess with unresolved materializations
-    // that were inserted by the conversion driver. We keep track of these
-    // ops in internal data structures.
-    assert(!unresolvedMaterializations.contains(castOp) &&
-           "attempting to replace/erase an unresolved materialization");
-  }
-
-  // Create mappings for each of the new result values.
-  for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults()))
-    mapping.map(static_cast<Value>(result), std::move(repl));
-
-  appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
-  // Mark this operation and all nested ops as replaced.
-  op->walk([&](Operation *op) { replacedOps.insert(op); });
-}
-
-void ConversionPatternRewriterImpl::replaceAllUsesWith(
-    Value from, ValueRange to, const TypeConverter *converter) {
-  if (!config.allowPatternRollback) {
-    SmallVector<Value> toConv = llvm::to_vector(to);
-    SmallVector<Value> repls =
-        getReplacementValues(*this, from, {toConv}, converter);
-    IRRewriter r(from.getContext());
-    Value repl = repls.front();
-    if (!repl)
-      return;
-
-    performReplaceValue(r, from, repl);
+    performReplaceValue(notifyingRewriter, from, repl);
     return;
   }
 
@@ -1978,6 +1925,39 @@ void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
   block->walk([&](Operation *op) { replacedOps.insert(op); });
 }
 
+void ConversionPatternRewriterImpl::eraseOp(Operation *op) {
+  LLVM_DEBUG({
+    logger.startLine() << "** Erase   : '" << op->getName() << "'(" << op
+                       << ")\n";
+  });
+
+  if (!config.allowPatternRollback) {
+    // Pattern rollback is not allowed: materialize all IR changes immediately.
+    // Update internal data structures, so that there are no dangling pointers
+    // to erased IR.
+    op->walk([&](Operation *op) {
+      erasedOps.insert(op);
+      ignoredOps.remove(op);
+      if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
+        unresolvedMaterializations.erase(castOp);
+        patternMaterializations.erase(castOp);
+      }
+      // The original op will be erased, so remove it from the set of
+      // unlegalized ops.
+      if (config.unlegalizedOps)
+        config.unlegalizedOps->erase(op);
+    });
+    op->walk([&](Block *block) { erasedBlocks.insert(block); });
+    // Replace the op with the replacement values and notify the listener.
+    notifyingRewriter.eraseOp(op);
+    return;
+  }
+
+  appendRewrite<EraseOperationRewrite>(op, currentTypeConverter);
+  // Mark this operation and all nested ops as replaced.
+  op->walk([&](Operation *op) { replacedOps.insert(op); });
+}
+
 void ConversionPatternRewriterImpl::notifyBlockInserted(
     Block *block, Region *previous, Region::iterator previousIt) {
   // If no previous insertion point is provided, the block used to be detached.
@@ -2030,6 +2010,28 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
     appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
 }
 
+void ConversionPatternRewriterImpl::notifyOperationReplaced(
+    Operation *op, Operation *replacement) {
+  if (config.allowPatternRollback) {
+    // In rollback mode, the listener is notified when the rewrite is applied.
+    appendRewrite<ReplaceOperationRewrite>(op, replacement, ValueRange());
+  } else if (config.listener) {
+    // In "no rollback" mode, the listener is always notified immediately.
+    config.listener->notifyOperationReplaced(op, replacement);
+  }
+}
+
+void ConversionPatternRewriterImpl::notifyOperationReplaced(
+    Operation *op, ValueRange replacement) {
+  if (config.allowPatternRollback) {
+    // In rollback mode, the listener is notified when the rewrite is applied.
+    appendRewrite<ReplaceOperationRewrite>(op, nullptr, replacement);
+  } else if (config.listener) {
+    // In "no rollback" mode, the listener is always notified immediately.
+    config.listener->notifyOperationReplaced(op, replacement);
+  }
+}
+
 void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source,
                                                       Block *dest,
                                                       Block::iterator before) {
@@ -2064,29 +2066,12 @@ const ConversionConfig &ConversionPatternRewriter::getConfig() const {
   return impl->config;
 }
 
-void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
-  assert(op && newOp && "expected non-null op");
-  replaceOp(op, newOp->getResults());
-}
-
-void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
-  assert(op->getNumResults() == newValues.size() &&
-         "incorrect # of replacement values");
-  LLVM_DEBUG({
-    impl->logger.startLine()
-        << "** Replace : '" << op->getName() << "'(" << op << ")\n";
-  });
-
-  // If the current insertion point is before the erased operation, we adjust
-  // the insertion point to be after the operation.
-  if (getInsertionPoint() == op->getIterator())
-    setInsertionPointAfter(op);
-
-  SmallVector<SmallVector<Value>> newVals =
-      llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
-        return v ? SmallVector<Value>{v} : SmallVector<Value>();
-      });
-  impl->replaceOp(op, std::move(newVals));
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<SmallVector<Value>> values) {
+  SmallVector<Value> result;
+  for (const auto &vals : values)
+    llvm::append_range(result, vals);
+  return result;
 }
 
 void ConversionPatternRewriter::replaceOpWithMultiple(
@@ -2098,27 +2083,24 @@ void ConversionPatternRewriter::replaceOpWithMultiple(
         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
   });
 
-  // If the current insertion point is before the erased operation, we adjust
-  // the insertion point to be after the operation.
-  if (getInsertionPoint() == op->getIterator())
-    setInsertionPointAfter(op);
+  // Notify the listener that the operation is about to be replaced.
+  impl->notifyOperationReplaced(op, flattenValues(newValues));
+
+  // Replace all uses of the operation's results with the new values.
+  for (auto [result, newValue] : llvm::zip(op->getResults(), newValues))
+    replaceAllUsesWith(result, newValue);
 
-  impl->replaceOp(op, std::move(newValues));
+  // Erase the operation.
+  eraseOp(op);
 }
 
 void ConversionPatternRewriter::eraseOp(Operation *op) {
-  LLVM_DEBUG({
-    impl->logger.startLine()
-        << "** Erase   : '" << op->getName() << "'(" << op << ")\n";
-  });
-
   // If the current insertion point is before the erased operation, we adjust
   // the insertion point to be after the operation.
   if (getInsertionPoint() == op->getIterator())
     setInsertionPointAfter(op);
 
-  SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
-  impl->replaceOp(op, std::move(nullRepls));
+  impl->eraseOp(op);
 }
 
 void ConversionPatternRewriter::eraseBlock(Block *block) {
@@ -2802,7 +2784,7 @@ LogicalResult OperationLegalizer::legalizePatternResult(
   // Check that the root was either replaced or updated in place.
   auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
   auto replacedRoot = [&] {
-    return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
+    return hasRewrite<EraseOperationRewrite>(newRewrites, op);
   };
   auto updatedRootInPlace = [&] {
     return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
diff --git a/mlir/test/Transforms/test-legalizer-rollback.mlir b/mlir/test/Transforms/test-legalizer-rollback.mlir
index 71e11782e14b0..077de16fe7f77 100644
--- a/mlir/test/Transforms/test-legalizer-rollback.mlir
+++ b/mlir/test/Transforms/test-legalizer-rollback.mlir
@@ -163,3 +163,30 @@ func.func @create_unregistered_op_in_pattern() -> i32 {
   "test.return"(%0) : (i32) -> ()
 }
 }
+
+// -----
+
+// This test cannot run in "no rollback" mode because test.erase_op is
+// erased while it still has uses.
+
+// CHECK: notifyOperationErased: test.dummy_op_lvl_2
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationErased: test.dummy_op_lvl_1
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationErased: test.erase_op
+// CHECK: notifyOperationInserted: test.valid, was unlinked
+// CHECK: notifyOperationReplaced: test.drop_operands_and_replace_with_valid
+// CHECK: notifyOperationErased: test.drop_operands_and_replace_with_valid
+
+// CHECK-LABEL: func @circular_mapping()
+//  CHECK-NEXT:   "test.valid"() : () -> ()
+func.func @circular_mapping() {
+  // Regression test that used to crash due to circular
+  // unrealized_conversion_cast ops. 
+  %0 = "test.erase_op"() ({
+    "test.dummy_op_lvl_1"() ({
+      "test.dummy_op_lvl_2"() : () -> ()
+    }) : () -> ()
+  }): () -> (i64)
+  "test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
+}
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 94c5bb4e93b06..0c909b18153e7 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -79,7 +79,7 @@ func.func @remap_call_1_to_1(%arg0: i64) {
 // CHECK-NEXT: notifyOperationInserted: test.return
 
 // The old block is erased.
-// CHECK-NEXT: notifyBlockErased
+// CHECK: notifyBlockErased
 
 // The function op gets a new type attribute.
 // CHECK-NEXT: notifyOperationModified: func.func
@@ -371,31 +371,6 @@ func.func @convert_detached_signature() {
 
 // -----
 
-// CHECK: notifyOperationReplaced: test.erase_op
-// CHECK: notifyOperationErased: test.dummy_op_lvl_2
-// CHECK: notifyBlockErased
-// CHECK: notifyOperationErased: test.dummy_op_lvl_1
-// CHECK: notifyBlockErased
-// CHECK: notifyOperationErased: test.erase_op
-// CHECK: notifyOperationInserted: test.valid, was unlinked
-// CHECK: notifyOperationReplaced: test.drop_operands_and_replace_with_valid
-// CHECK: notifyOperationErased: test.drop_operands_and_replace_with_valid
-
-// CHECK-LABEL: func @circular_mapping()
-//  CHECK-NEXT:   "test.valid"() : () -> ()
-func.func @circular_mapping() {
-  // Regression test that used to crash due to circular
-  // unrealized_conversion_cast ops. 
-  %0 = "test.erase_op"() ({
-    "test.dummy_op_lvl_1"() ({
-      "test.dummy_op_lvl_2"() : () -> ()
-    }) : () -> ()
-  }): () -> (i64)
-  "test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
-}
-
-// -----
-
 // CHECK-LABEL: func @test_duplicate_block_arg()
 //       CHECK:   test.convert_block_args  is_legal duplicate {
 //       CHECK:   ^{{.*}}(%[[arg0:.*]]: i64, %[[arg1:.*]]: i64):



More information about the Mlir-commits mailing list