[Mlir-commits] [mlir] [mlir][Transforms] Dialect Conversion: Fix folder rollback (PR #150775)

Matthias Springer llvmlistbot at llvm.org
Sat Jul 26 09:17:05 PDT 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/150775

Folders are almost like patterns: they can modify IR (in-place op modification) and create new IR (constant op materialization of a folded attribute). If the modified op or the newly-created op cannot be legalized, the folding must be rolled back. The previous implementation did not roll back in-place op modifications.
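Here is a minimal sketch of why the in-place part of a fold needs a start/finalize modification bracket before it can be rolled back. It assumes a toy journaling rewriter: `ToyOp`, `ToyRewriter` and `foldInPlace` are hypothetical stand-ins, not MLIR APIs, and the integer journal index only loosely mirrors `RewriterState`. A change made behind the rewriter's back is never journaled, so resetting to an earlier state leaves it in place. A bracketed change is journaled and gets undone.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical toy model (not the MLIR API): an op is a name plus one
// attribute that a folder may rewrite in place.
struct ToyOp {
  std::string name;
  int attr = 0;
};

// Hypothetical journaling rewriter. Only changes reported to it are recorded,
// so only those can be undone by resetState().
class ToyRewriter {
public:
  std::size_t getCurrentState() const { return journal.size(); }

  // Undo all recorded changes made after `state`, newest first.
  void resetState(std::size_t state) {
    while (journal.size() > state) {
      journal.back()();
      journal.pop_back();
    }
  }

  // Snapshot the op before an in-place modification.
  void startOpModification(ToyOp &op) { snapshot = op; }
  // Restore the snapshot immediately (e.g. when the fold failed).
  void cancelOpModification(ToyOp &op) { op = snapshot; }
  // Keep the change, but journal an undo entry that restores the snapshot.
  void finalizeOpModification(ToyOp &op) {
    ToyOp saved = snapshot;
    ToyOp *target = &op;
    journal.push_back([target, saved] { *target = saved; });
  }

private:
  ToyOp snapshot;
  std::vector<std::function<void()>> journal;
};

// Hypothetical in-place folder: rewrites the attribute directly.
void foldInPlace(ToyOp &op) { op.attr = 42; }
```

In the toy, the unbracketed call mirrors the bug described above: the folder's in-place change survives the rollback.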

This issue became apparent when moving the `rewriter.replaceOp(op, replacementValues);` call before the loop nest that legalizes the newly created constant ops (and therefore `replacementValues`). Conceptually, the folded op must be replaced before attempting to legalize the constants, because the constant ops may themselves be replaced as part of the legalization process. The old order happened to work in the current conversion driver, but caused a failure in the One-Shot Dialect Conversion driver, which expects to see the most recent IR at all times.
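The ordering argument can be illustrated with a toy IR in which replacements take effect immediately, roughly the "most recent IR at all times" view described above. `ToyIR`, `legalizeConstant` and the `target.constant` op name are hypothetical. The sketch only shows one effect of the old order: if the constant is legalized before it has any users, the folded op's uses end up pointing at an op that was already replaced and erased.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy IR (not the MLIR API): each use stores the name of the op
// producing its operand, and replacements take effect immediately.
struct ToyIR {
  std::vector<std::string> uses;
  std::vector<std::string> erased;

  // Like replaceOp: redirect every use of `from` to `to`, then erase `from`.
  void replaceOp(const std::string &from, const std::string &to) {
    std::replace(uses.begin(), uses.end(), from, to);
    erased.push_back(from);
  }

  bool isErased(const std::string &name) const {
    return std::find(erased.begin(), erased.end(), name) != erased.end();
  }
};

// Legalizing the materialized constant may itself replace it; here it is
// replaced by a hypothetical "target.constant" op.
void legalizeConstant(ToyIR &ir, const std::string &cst) {
  ir.replaceOp(cst, "target.constant");
}

// New order (as in this patch): replace the folded op, then legalize.
ToyIR replaceThenLegalize() {
  ToyIR ir{{"memref.dim", "memref.dim"}, {}};
  ir.replaceOp("memref.dim", "arith.constant");
  legalizeConstant(ir, "arith.constant");
  return ir;
}

// Old order: legalize first, then replace with the now-stale constant.
ToyIR legalizeThenReplace() {
  ToyIR ir{{"memref.dim", "memref.dim"}, {}};
  legalizeConstant(ir, "arith.constant");
  ir.replaceOp("memref.dim", "arith.constant");
  return ir;
}
```

With the new order, both uses end up on the live `target.constant`. With the old order, they point at the erased `arith.constant`.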

Various test cases started failing after moving the `replaceOp` call, which pointed to the missing rollback functionality. A common folder-rollback scenario exercised by multiple test cases: a `memref.dim` op is folded to an `arith.constant`, but `arith.constant` is not legal according to the conversion target, which triggers a rollback.
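The patch also clears per-pattern bookkeeping with `llvm::make_scope_exit` and routes every failure through a single `undoFolding` helper. Below is a standard-C++ sketch of that shape for the `memref.dim` to `arith.constant` case. It assumes a hand-written `ScopeExit` in place of the LLVM utility and a toy `ToyState` (both hypothetical, not MLIR APIs), and it collapses folding, replacement and legality checking into a few lines. When `arith.constant` is not in the legal set, the journal is unwound and the original op is restored. The bookkeeping is cleared on both the success and the failure path.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for llvm::make_scope_exit: runs the callback when the
// guard goes out of scope, on every return path.
class ScopeExit {
public:
  explicit ScopeExit(std::function<void()> f) : fn(std::move(f)) {}
  ~ScopeExit() { fn(); }
  ScopeExit(const ScopeExit &) = delete;
  ScopeExit &operator=(const ScopeExit &) = delete;

private:
  std::function<void()> fn;
};

// Hypothetical toy IR: op names, an undo journal, and per-pattern bookkeeping.
struct ToyState {
  std::vector<std::string> ops{"memref.dim"};
  std::vector<std::function<void()>> journal;
  std::set<std::string> patternNewOps;
};

// Toy fold-then-undo flow: "fold" memref.dim into arith.constant and roll
// everything back if arith.constant is not legal.
bool legalizeWithFold(ToyState &s, const std::set<std::string> &legalOps) {
  ScopeExit cleanup([&s] { s.patternNewOps.clear(); });
  std::size_t saved = s.journal.size();
  auto undoFolding = [&] {
    while (s.journal.size() > saved) {
      s.journal.back()();
      s.journal.pop_back();
    }
    return false;
  };

  // Replace the folded op and journal how to undo the replacement.
  s.ops[0] = "arith.constant";
  s.patternNewOps.insert("arith.constant");
  s.journal.push_back([&s] { s.ops[0] = "memref.dim"; });

  if (legalOps.count("arith.constant") == 0)
    return undoFolding();
  return true;
}
```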


From 1edd3fdd27645422cde8330da4d010c5d2c35fba Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sat, 26 Jul 2025 16:08:36 +0000
Subject: [PATCH] [mlir][Transforms] Dialect Conversion: Fix folder rollback

---
 .../Transforms/Utils/DialectConversion.cpp    | 38 ++++++++++++++-----
 .../test-legalize-type-conversion.mlir        |  2 +-
 2 files changed, 29 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index df255cfcf3ec1..3ecd1fd43bf44 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -2216,22 +2217,45 @@ OperationLegalizer::legalizeWithFold(Operation *op,
     rewriterImpl.logger.startLine() << "* Fold {\n";
     rewriterImpl.logger.indent();
   });
-  (void)rewriterImpl;
+
+  // Clear pattern state, so that the next pattern application starts with a
+  // clean slate. (The op/block sets are populated by listener notifications.)
+  auto cleanup = llvm::make_scope_exit([&]() {
+    rewriterImpl.patternNewOps.clear();
+    rewriterImpl.patternModifiedOps.clear();
+    rewriterImpl.patternInsertedBlocks.clear();
+  });
+
+  // Upon failure, undo all changes made by the folder.
+  RewriterState curState = rewriterImpl.getCurrentState();
+  auto undoFolding = [&]() {
+    rewriterImpl.resetState(curState, std::string(op->getName().getStringRef()) + " folder");
+    return failure();
+  };
 
   // Try to fold the operation.
   StringRef opName = op->getName().getStringRef();
   SmallVector<Value, 2> replacementValues;
   SmallVector<Operation *, 2> newOps;
   rewriter.setInsertionPoint(op);
+  rewriter.startOpModification(op);
   if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
     LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
+    rewriter.cancelOpModification(op);
     return failure();
   }
+  rewriter.finalizeOpModification(op);
 
   // An empty list of replacement values indicates that the fold was in-place.
   // As the operation changed, a new legalization needs to be attempted.
-  if (replacementValues.empty())
-    return legalize(op, rewriter);
+  if (replacementValues.empty()) {
+    if (succeeded(legalize(op, rewriter)))
+      return success();
+    return undoFolding();
+  }
+
+  // Insert a replacement for 'op' with the folded replacement values.
+  rewriter.replaceOp(op, replacementValues);
 
   // Recursively legalize any new constant operations.
   for (Operation *newOp : newOps) {
@@ -2245,16 +2269,10 @@ OperationLegalizer::legalizeWithFold(Operation *op,
             "op '" + opName +
             "' folder rollback of IR modifications requested");
       }
-      // Legalization failed: erase all materialized constants.
-      for (Operation *op : newOps)
-        rewriter.eraseOp(op);
-      return failure();
+      return undoFolding();
     }
   }
 
-  // Insert a replacement for 'op' with the folded replacement values.
-  rewriter.replaceOp(op, replacementValues);
-
   LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
   return success();
 }
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index db8bd0f6378d2..9bffe92b374d5 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -104,8 +104,8 @@ func.func @test_signature_conversion_no_converter() {
   "test.signature_conversion_no_converter"() ({
   // expected-error @below {{failed to legalize unresolved materialization from ('f64') to ('f32') that remained live after conversion}}
   ^bb0(%arg0: f32):
-    "test.type_consumer"(%arg0) : (f32) -> ()
     // expected-note @below{{see existing live user here}}
+    "test.type_consumer"(%arg0) : (f32) -> ()
     "test.return"(%arg0) : (f32) -> ()
   }) : () -> ()
   return
