[Mlir-commits] [mlir] [mlir][Transforms] Dialect Conversion: Fix folder rollback (PR #150775)

Matthias Springer llvmlistbot at llvm.org
Sat Jul 26 10:01:16 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/150775

From 5d6505301d6299f5edb310759e8b48c02b11f42b Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sat, 26 Jul 2025 16:08:36 +0000
Subject: [PATCH] [mlir][Transforms] Dialect Conversion: Fix folder rollback

---
 .../Transforms/Utils/DialectConversion.cpp    | 26 ++++++++++++++-----
 .../test-legalize-type-conversion.mlir        |  2 +-
 2 files changed, 20 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index df255cfcf3ec1..0ad48ca884f0d 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -2216,23 +2217,39 @@ OperationLegalizer::legalizeWithFold(Operation *op,
     rewriterImpl.logger.startLine() << "* Fold {\n";
     rewriterImpl.logger.indent();
   });
-  (void)rewriterImpl;
+
+  // Clear pattern state, so that the next pattern application starts with a
+  // clean slate. (The op/block sets are populated by listener notifications.)
+  auto cleanup = llvm::make_scope_exit([&]() {
+    rewriterImpl.patternNewOps.clear();
+    rewriterImpl.patternModifiedOps.clear();
+    rewriterImpl.patternInsertedBlocks.clear();
+  });
+
+  // Upon failure, undo all changes made by the folder.
+  RewriterState curState = rewriterImpl.getCurrentState();
 
   // Try to fold the operation.
   StringRef opName = op->getName().getStringRef();
   SmallVector<Value, 2> replacementValues;
   SmallVector<Operation *, 2> newOps;
   rewriter.setInsertionPoint(op);
+  rewriter.startOpModification(op);
   if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
     LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
+    rewriter.cancelOpModification(op);
     return failure();
   }
+  rewriter.finalizeOpModification(op);
 
   // An empty list of replacement values indicates that the fold was in-place.
   // As the operation changed, a new legalization needs to be attempted.
   if (replacementValues.empty())
     return legalize(op, rewriter);
 
+  // Insert a replacement for 'op' with the folded replacement values.
+  rewriter.replaceOp(op, replacementValues);
+
   // Recursively legalize any new constant operations.
   for (Operation *newOp : newOps) {
     if (failed(legalize(newOp, rewriter))) {
@@ -2245,16 +2262,11 @@ OperationLegalizer::legalizeWithFold(Operation *op,
             "op '" + opName +
             "' folder rollback of IR modifications requested");
       }
-      // Legalization failed: erase all materialized constants.
-      for (Operation *op : newOps)
-        rewriter.eraseOp(op);
+      rewriterImpl.resetState(curState, std::string(op->getName().getStringRef()) + " folder");
       return failure();
     }
   }
 
-  // Insert a replacement for 'op' with the folded replacement values.
-  rewriter.replaceOp(op, replacementValues);
-
   LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
   return success();
 }
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index db8bd0f6378d2..9bffe92b374d5 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -104,8 +104,8 @@ func.func @test_signature_conversion_no_converter() {
   "test.signature_conversion_no_converter"() ({
   // expected-error @below {{failed to legalize unresolved materialization from ('f64') to ('f32') that remained live after conversion}}
   ^bb0(%arg0: f32):
-    "test.type_consumer"(%arg0) : (f32) -> ()
     // expected-note @below{{see existing live user here}}
+    "test.type_consumer"(%arg0) : (f32) -> ()
     "test.return"(%arg0) : (f32) -> ()
   }) : () -> ()
   return



More information about the Mlir-commits mailing list