[Mlir-commits] [mlir] [mlir][Transforms] Dialect Conversion: Fix folder rollback (PR #150775)

llvmlistbot at llvm.org
Sat Jul 26 09:17:33 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Folders are almost like patterns: they can modify IR (in-place op modification) and create new IR (constant op materialization of a folded attribute). If the modified op or the newly-created op cannot be legalized, the folding must be rolled back. The previous implementation did not roll back in-place op modifications.
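The missing piece can be illustrated with a minimal C++ sketch of a journaling rewriter that can also undo in-place modifications. `ToyOp`, `ToyRewriter`, and `legalizeWithInPlaceFold` are hypothetical stand-ins, not MLIR classes. The method names only mirror the `startOpModification` / `cancelOpModification` / `finalizeOpModification` / `resetState` calls that appear in the diff, and the undo-log design is an assumption made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for an operation: a name plus one integer attribute.
struct ToyOp {
  std::string name;
  int attr = 0;
};

// Hypothetical journaling rewriter: every recorded change pushes an undo
// action, and rolling back replays the undo actions newest-first.
// Supports one in-place modification in flight at a time.
class ToyRewriter {
public:
  using State = std::size_t;

  State getCurrentState() const { return undoLog_.size(); }

  // Snapshot the op before a folder may modify it in place.
  void startOpModification(ToyOp &op) { pending_ = op; }

  // The fold failed: restore the snapshot in case the folder touched the op.
  void cancelOpModification(ToyOp &op) { op = pending_; }

  // The fold succeeded: record how to restore the pre-fold op later.
  void finalizeOpModification(ToyOp &op) {
    ToyOp saved = pending_;
    undoLog_.push_back([&op, saved]() { op = saved; });
  }

  // Undo every change recorded after `state`.
  void resetState(State state) {
    assert(state <= undoLog_.size() && "cannot roll forward");
    while (undoLog_.size() > state) {
      undoLog_.back()();
      undoLog_.pop_back();
    }
  }

private:
  ToyOp pending_;
  std::vector<std::function<void()>> undoLog_;
};

// Mirrors the in-place branch of the patch: checkpoint, fold, re-legalize,
// and roll back the in-place change if the modified op is still illegal.
// `foldInPlace` and `isLegal` are hypothetical callbacks.
bool legalizeWithInPlaceFold(ToyRewriter &rewriter, ToyOp &op,
                             const std::function<bool(ToyOp &)> &foldInPlace,
                             const std::function<bool(const ToyOp &)> &isLegal) {
  ToyRewriter::State checkpoint = rewriter.getCurrentState();
  rewriter.startOpModification(op);
  if (!foldInPlace(op)) {
    rewriter.cancelOpModification(op);
    return false;
  }
  rewriter.finalizeOpModification(op);
  if (isLegal(op))
    return true;
  rewriter.resetState(checkpoint); // undo the in-place modification
  return false;
}
```

With a folder that sets `attr` to 42 and a legality check of `attr < 10`, the call returns `false` and `op.attr` is back at its original value. Without the journaled in-place modification, that change would simply stay in the IR, which is the gap this PR describes.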

This issue became apparent when moving the `rewriter.replaceOp(op, replacementValues);` call before the loop nest that legalizes the newly created constant ops (and therefore `replacementValues`). Conceptually, the folded op must be replaced before attempting to legalize the constants, because the constant ops may themselves be replaced as part of the legalization process. The old order happened to work with the current conversion driver, but it caused a failure in the One-Shot Dialect Conversion driver, which expects to see the most recent IR at all times.
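The ordering argument can be checked with a small C++ toy model in which every replacement is applied eagerly and the replaced op is erased immediately, as a rough approximation of a driver that always shows the most recent IR. `EagerIR` and `simulateFold` are hypothetical; the names `dim`, `cst`, and `cst.legal` stand for the folded op, the materialized constant, and whatever legalization turns that constant into.

```cpp
#include <algorithm>
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical toy IR: values are named after their defining op, and a single
// user op holds a list of operand names.
struct EagerIR {
  std::set<std::string> liveOps;
  std::vector<std::string> userOperands;

  // Redirect all uses of `from` to `to`, then erase `from` right away.
  void replaceOp(const std::string &from, const std::string &to) {
    assert(liveOps.count(from) && "replacing an op that was already erased");
    std::replace(userOperands.begin(), userOperands.end(), from, to);
    liveOps.erase(from);
  }

  // True if every operand of the user is defined by a live op.
  bool allOperandsLive() const {
    return std::all_of(userOperands.begin(), userOperands.end(),
                       [&](const std::string &v) { return liveOps.count(v) > 0; });
  }
};

// Folds `dim` into constant `cst`; legalizing `cst` replaces it with
// `cst.legal`. `replaceFirst` selects the order used by the patch.
EagerIR simulateFold(bool replaceFirst) {
  EagerIR ir;
  ir.liveOps = {"dim", "cst", "cst.legal"};
  ir.userOperands = {"dim"};
  if (replaceFirst) {
    ir.replaceOp("dim", "cst");       // replace the folded op first...
    ir.replaceOp("cst", "cst.legal"); // ...then legalize the new constant
  } else {
    ir.replaceOp("cst", "cst.legal"); // legalize the constant first...
    ir.replaceOp("dim", "cst");       // ...then replace: `cst` is already gone
  }
  return ir;
}
```

`simulateFold(true)` (the order in the patch) leaves the user pointing at `cst.legal`, while `simulateFold(false)` (the old order) leaves it pointing at the already-erased `cst`.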

Various test cases started failing after the `replaceOp` call was moved, which pointed to the missing rollback functionality. A common folder-rollback scenario exercised by multiple test cases: a `memref.dim` op is folded to an `arith.constant`, but `arith.constant` is not marked as legal by the conversion target, which triggers a rollback.
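The `memref.dim` scenario can be sketched with a flat list of op names and a set of legal op names. This is a hypothetical simplification (`ToyIR`, `LegalOps`, and `legalizeDimWithFold` are not MLIR APIs): it copies the whole IR as a checkpoint where the real rewriter records individual changes, and it assumes the `memref.dim` always folds.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Hypothetical toy: the IR is a flat list of op names, and the conversion
// target is reduced to a set of legal op names.
using ToyIR = std::vector<std::string>;
using LegalOps = std::set<std::string>;

// Folds op `idx` if it is a `memref.dim`, materializing an `arith.constant`
// in its place. If the constant is illegal, the whole fold is rolled back.
bool legalizeDimWithFold(ToyIR &ir, std::size_t idx, const LegalOps &legal) {
  assert(idx < ir.size() && "index out of range");
  if (ir[idx] != "memref.dim")
    return false; // no folder applies
  ToyIR checkpoint = ir; // stands in for the rewriter's undo journal
  ir.insert(ir.begin() + idx, "arith.constant"); // materialize the constant
  ir.erase(ir.begin() + idx + 1);               // replace the folded op
  if (legal.count("arith.constant"))
    return true;
  ir = checkpoint; // constant is illegal: undo the folding
  return false;
}
```

If the target does not list `arith.constant`, the call returns `false` and the list still contains `memref.dim` at the same position.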


---
Full diff: https://github.com/llvm/llvm-project/pull/150775.diff


2 Files Affected:

- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+28-10) 
- (modified) mlir/test/Transforms/test-legalize-type-conversion.mlir (+1-1) 


``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index df255cfcf3ec1..3ecd1fd43bf44 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -2216,22 +2217,45 @@ OperationLegalizer::legalizeWithFold(Operation *op,
     rewriterImpl.logger.startLine() << "* Fold {\n";
     rewriterImpl.logger.indent();
   });
-  (void)rewriterImpl;
+
+  // Clear pattern state, so that the next pattern application starts with a
+  // clean slate. (The op/block sets are populated by listener notifications.)
+  auto cleanup = llvm::make_scope_exit([&]() {
+    rewriterImpl.patternNewOps.clear();
+    rewriterImpl.patternModifiedOps.clear();
+    rewriterImpl.patternInsertedBlocks.clear();
+  });
+
+  // Upon failure, undo all changes made by the folder.
+  RewriterState curState = rewriterImpl.getCurrentState();
+  auto undoFolding = [&]() {
+    rewriterImpl.resetState(curState, std::string(op->getName().getStringRef()) + " folder");
+    return failure();
+  };
 
   // Try to fold the operation.
   StringRef opName = op->getName().getStringRef();
   SmallVector<Value, 2> replacementValues;
   SmallVector<Operation *, 2> newOps;
   rewriter.setInsertionPoint(op);
+  rewriter.startOpModification(op);
   if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
     LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
+    rewriter.cancelOpModification(op);
     return failure();
   }
+  rewriter.finalizeOpModification(op);
 
   // An empty list of replacement values indicates that the fold was in-place.
   // As the operation changed, a new legalization needs to be attempted.
-  if (replacementValues.empty())
-    return legalize(op, rewriter);
+  if (replacementValues.empty()) {
+    if (succeeded(legalize(op, rewriter)))
+      return success();
+    return undoFolding();
+  }
+
+  // Insert a replacement for 'op' with the folded replacement values.
+  rewriter.replaceOp(op, replacementValues);
 
   // Recursively legalize any new constant operations.
   for (Operation *newOp : newOps) {
@@ -2245,16 +2269,10 @@ OperationLegalizer::legalizeWithFold(Operation *op,
             "op '" + opName +
             "' folder rollback of IR modifications requested");
       }
-      // Legalization failed: erase all materialized constants.
-      for (Operation *op : newOps)
-        rewriter.eraseOp(op);
-      return failure();
+      return undoFolding();
     }
   }
 
-  // Insert a replacement for 'op' with the folded replacement values.
-  rewriter.replaceOp(op, replacementValues);
-
   LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
   return success();
 }
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index db8bd0f6378d2..9bffe92b374d5 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -104,8 +104,8 @@ func.func @test_signature_conversion_no_converter() {
   "test.signature_conversion_no_converter"() ({
   // expected-error@below {{failed to legalize unresolved materialization from ('f64') to ('f32') that remained live after conversion}}
   ^bb0(%arg0: f32):
-    "test.type_consumer"(%arg0) : (f32) -> ()
     // expected-note@below{{see existing live user here}}
+    "test.type_consumer"(%arg0) : (f32) -> ()
     "test.return"(%arg0) : (f32) -> ()
   }) : () -> ()
   return

``````````

</details>
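The control-flow shape of the patched `legalizeWithFold` (cleanup on every exit, rollback on every failure after folding) can be reproduced outside LLVM. In this C++17 sketch, `ScopeExit` is a hand-written stand-in for `llvm::make_scope_exit`, and `ToyRewriterImpl` / `toyLegalizeWithFold` are hypothetical; the string journal is only a placeholder for the rewriter's real change log.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Minimal stand-in for llvm::make_scope_exit: runs a callback when the
// enclosing scope is left, on every return path.
template <typename F>
class ScopeExit {
public:
  explicit ScopeExit(F f) : f_(std::move(f)) {}
  ~ScopeExit() { f_(); }
  ScopeExit(const ScopeExit &) = delete;
  ScopeExit &operator=(const ScopeExit &) = delete;

private:
  F f_;
};

// Hypothetical bookkeeping loosely mirroring patternNewOps/patternModifiedOps.
struct ToyRewriterImpl {
  std::set<std::string> patternNewOps;
  std::set<std::string> patternModifiedOps;
  std::vector<std::string> journal; // one entry per recorded IR change

  std::size_t getCurrentState() const { return journal.size(); }
  void resetState(std::size_t state) {
    assert(state <= journal.size() && "cannot roll forward");
    journal.resize(state); // drop (undo) every change after `state`
  }
};

// Skeleton of the patched flow. `foldSucceeds` and `newOpsLegal` are
// hypothetical inputs standing in for tryFold and legalize results.
bool toyLegalizeWithFold(ToyRewriterImpl &impl, bool foldSucceeds,
                         bool newOpsLegal) {
  // Clear per-pattern state on every exit path.
  ScopeExit cleanup([&]() {
    impl.patternNewOps.clear();
    impl.patternModifiedOps.clear();
  });
  // Upon failure after folding, undo all changes made by the folder.
  std::size_t curState = impl.getCurrentState();
  auto undoFolding = [&]() {
    impl.resetState(curState);
    return false;
  };

  if (!foldSucceeds)
    return false; // nothing was recorded, nothing to undo
  impl.journal.push_back("replace op");
  impl.journal.push_back("create arith.constant");
  impl.patternNewOps.insert("arith.constant");
  if (!newOpsLegal)
    return undoFolding();
  return true;
}
```

Putting the cleanup in a scope guard means an early return added later cannot forget to clear the per-pattern sets, and routing every post-fold failure through one `undoFolding` lambda keeps the rollback logic in a single place.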


https://github.com/llvm/llvm-project/pull/150775

