[Mlir-commits] [mlir] [mlir][Transforms][NFC] Store per-pattern IR modifications in separate state (PR #145319)

Matthias Springer llvmlistbot at llvm.org
Tue Jun 24 00:15:55 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/145319

From 46bf825dddc524101c5eea3bd03f2b898ed713e8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 23 Jun 2025 12:36:02 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Store per-pattern IR modifications in
 separate state

---
 .../Transforms/Utils/DialectConversion.cpp    | 139 ++++++++++--------
 1 file changed, 75 insertions(+), 64 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 7cfe7250d02c3..955a106c21941 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1080,6 +1080,16 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// to modify/access them is invalid rewriter API usage.
   SetVector<Operation *> replacedOps;
 
+  /// A set of operations that were created by the current pattern.
+  SetVector<Operation *> patternNewOps;
+
+  /// A set of operations that were modified by the current pattern.
+  SetVector<Operation *> patternModifiedOps;
+
+  /// A set of blocks that were inserted (newly-created blocks or moved blocks)
+  /// by the current pattern.
+  SetVector<Block *> patternInsertedBlocks;
+
   /// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
   /// to the corresponding rewrite objects.
   DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
@@ -1571,6 +1581,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
   if (!previous.isSet()) {
     // This is a newly created op.
     appendRewrite<CreateOperationRewrite>(op);
+    patternNewOps.insert(op);
     return;
   }
   Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
@@ -1655,6 +1666,8 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
         }
       });
 
+  patternInsertedBlocks.insert(block);
+
   if (!previous) {
     // This is a newly created block.
     appendRewrite<CreateBlockRewrite>(block);
@@ -1852,6 +1865,8 @@ void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
   assert(!impl->wasOpReplaced(op) &&
          "attempting to modify a replaced/erased op");
   PatternRewriter::finalizeOpModification(op);
+  impl->patternModifiedOps.insert(op);
+
   // There is nothing to do here, we only need to track the operation at the
   // start of the update.
 #ifndef NDEBUG
@@ -1964,21 +1979,25 @@ class OperationLegalizer {
   /// Legalize the resultant IR after successfully applying the given pattern.
   LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
                                       ConversionPatternRewriter &rewriter,
-                                      RewriterState &curState);
+                                      const SetVector<Operation *> &newOps,
+                                      const SetVector<Operation *> &modifiedOps,
+                                      const SetVector<Block *> &insertedBlocks);
 
   /// Legalizes the actions registered during the execution of a pattern.
   LogicalResult
   legalizePatternBlockRewrites(Operation *op,
                                ConversionPatternRewriter &rewriter,
                                ConversionPatternRewriterImpl &impl,
-                               RewriterState &state, RewriterState &newState);
-  LogicalResult legalizePatternCreatedOperations(
-      ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
-      RewriterState &state, RewriterState &newState);
-  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
-                                           ConversionPatternRewriterImpl &impl,
-                                           RewriterState &state,
-                                           RewriterState &newState);
+                               const SetVector<Block *> &insertedBlocks,
+                               const SetVector<Operation *> &newOps);
+  LogicalResult
+  legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter,
+                                   ConversionPatternRewriterImpl &impl,
+                                   const SetVector<Operation *> &newOps);
+  LogicalResult
+  legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
+                             ConversionPatternRewriterImpl &impl,
+                             const SetVector<Operation *> &modifiedOps);
 
   //===--------------------------------------------------------------------===//
   // Cost Model
@@ -2131,6 +2150,15 @@ OperationLegalizer::legalize(Operation *op,
   return failure();
 }
 
+/// Moves `obj` out and returns it, then assigns `obj = T()` so that `obj` is
+/// left empty rather than in an unspecified moved-from state.
+template <typename T>
+static T moveAndReset(T &obj) {
+  T result = std::move(obj);
+  obj = T();
+  return result;
+}
+
 LogicalResult
 OperationLegalizer::legalizeWithFold(Operation *op,
                                      ConversionPatternRewriter &rewriter) {
@@ -2192,6 +2220,9 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
   RewriterState curState = rewriterImpl.getCurrentState();
   auto onFailure = [&](const Pattern &pattern) {
     assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
+    rewriterImpl.patternNewOps.clear();
+    rewriterImpl.patternModifiedOps.clear();
+    rewriterImpl.patternInsertedBlocks.clear();
     LLVM_DEBUG({
       logFailure(rewriterImpl.logger, "pattern failed to match");
       if (rewriterImpl.config.notifyCallback) {
@@ -2212,7 +2243,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
   // successfully applied.
   auto onSuccess = [&](const Pattern &pattern) {
     assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
-    auto result = legalizePatternResult(op, pattern, rewriter, curState);
+    SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
+    SetVector<Operation *> modifiedOps =
+        moveAndReset(rewriterImpl.patternModifiedOps);
+    SetVector<Block *> insertedBlocks =
+        moveAndReset(rewriterImpl.patternInsertedBlocks);
+    auto result = legalizePatternResult(op, pattern, rewriter, newOps,
+                                        modifiedOps, insertedBlocks);
     appliedPatterns.erase(&pattern);
     if (failed(result)) {
       if (!rewriterImpl.config.allowPatternRollback)
@@ -2253,10 +2290,11 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
   return true;
 }
 
-LogicalResult
-OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
-                                          ConversionPatternRewriter &rewriter,
-                                          RewriterState &curState) {
+LogicalResult OperationLegalizer::legalizePatternResult(
+    Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
+    const SetVector<Operation *> &newOps,
+    const SetVector<Operation *> &modifiedOps,
+    const SetVector<Block *> &insertedBlocks) {
   auto &impl = rewriter.getImpl();
   assert(impl.pendingRootUpdates.empty() && "dangling root updates");
 
@@ -2274,12 +2312,10 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
   // Legalize each of the actions registered during application.
-  RewriterState newState = impl.getCurrentState();
-  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
-                                          newState)) ||
-      failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
-      failed(legalizePatternCreatedOperations(rewriter, impl, curState,
-                                              newState))) {
+  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks,
+                                          newOps)) ||
+      failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) ||
+      failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) {
     return failure();
   }
 
@@ -2289,20 +2325,14 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
 
 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
     Operation *op, ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &impl, RewriterState &state,
-    RewriterState &newState) {
-  SmallPtrSet<Operation *, 16> operationsToIgnore;
+    ConversionPatternRewriterImpl &impl,
+    const SetVector<Block *> &insertedBlocks,
+    const SetVector<Operation *> &newOps) {
+  SmallPtrSet<Operation *, 16> alreadyLegalized;
 
   // If the pattern moved or created any blocks, make sure the types of block
   // arguments get legalized.
-  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
-    BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
-    if (!rewrite)
-      continue;
-    Block *block = rewrite->getBlock();
-    if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
-            ReplaceBlockArgRewrite, InlineBlockRewrite>(rewrite))
-      continue;
+  for (Block *block : insertedBlocks) {
     // Only check blocks outside of the current operation.
     Operation *parentOp = block->getParentOp();
     if (!parentOp || parentOp == op || block->getNumArguments() == 0)
@@ -2322,41 +2352,26 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
       continue;
     }
 
-    // Otherwise, check that this operation isn't one generated by this pattern.
-    // This is because we will attempt to legalize the parent operation, and
-    // blocks in regions created by this pattern will already be legalized later
-    // on. If we haven't built the set yet, build it now.
-    if (operationsToIgnore.empty()) {
-      for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
-           ++i) {
-        auto *createOp =
-            dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
-        if (!createOp)
-          continue;
-        operationsToIgnore.insert(createOp->getOperation());
+    // Otherwise, try to legalize the parent operation, unless it was created
+    // by this pattern: newly created ops (including the blocks in their
+    // regions) are legalized afterwards in legalizePatternCreatedOperations,
+    // so they are skipped here. Each parent op is legalized at most once.
+    if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
+      if (failed(legalize(parentOp, rewriter))) {
+        LLVM_DEBUG(logFailure(
+            impl.logger, "operation '{0}'({1}) became illegal after rewrite",
+            parentOp->getName(), parentOp));
+        return failure();
       }
     }
-
-    // If this operation should be considered for re-legalization, try it.
-    if (operationsToIgnore.insert(parentOp).second &&
-        failed(legalize(parentOp, rewriter))) {
-      LLVM_DEBUG(logFailure(impl.logger,
-                            "operation '{0}'({1}) became illegal after rewrite",
-                            parentOp->getName(), parentOp));
-      return failure();
-    }
   }
   return success();
 }
 
 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
-    RewriterState &state, RewriterState &newState) {
-  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
-    auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
-    if (!createOp)
-      continue;
-    Operation *op = createOp->getOperation();
+    const SetVector<Operation *> &newOps) {
+  for (Operation *op : newOps) {
     if (failed(legalize(op, rewriter))) {
       LLVM_DEBUG(logFailure(impl.logger,
                             "failed to legalize generated operation '{0}'({1})",
@@ -2369,12 +2384,8 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
 
 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
-    RewriterState &state, RewriterState &newState) {
-  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
-    auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
-    if (!rewrite)
-      continue;
-    Operation *op = rewrite->getOperation();
+    const SetVector<Operation *> &modifiedOps) {
+  for (Operation *op : modifiedOps) {
     if (failed(legalize(op, rewriter))) {
       LLVM_DEBUG(logFailure(
           impl.logger, "failed to legalize operation updated in-place '{0}'",



More information about the Mlir-commits mailing list