[Mlir-commits] [flang] [mlir] [mlir][Transforms] Dialect Conversion: Fix folder rollback (PR #150775)

Matthias Springer llvmlistbot at llvm.org
Sat Jul 26 09:54:13 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/150775

>From 45803b6d058941d988d595810f9f034855ca12d1 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sat, 26 Jul 2025 16:08:36 +0000
Subject: [PATCH] [mlir][Transforms] Dialect Conversion: Fix folder rollback

---
 flang/test/Fir/affine-demotion.fir            | 10 +++--
 .../Transforms/Utils/DialectConversion.cpp    | 38 ++++++++++++++-----
 .../test-legalize-type-conversion.mlir        |  2 +-
 3 files changed, 35 insertions(+), 15 deletions(-)

diff --git a/flang/test/Fir/affine-demotion.fir b/flang/test/Fir/affine-demotion.fir
index bdb84be3624cb..a091ce4e5df9c 100644
--- a/flang/test/Fir/affine-demotion.fir
+++ b/flang/test/Fir/affine-demotion.fir
@@ -34,6 +34,8 @@ module  {
   }
 }
 
+// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 + 1)>
+// CHECK: #[[MAP1:.*]] = affine_map<(d0)[s0, s1, s2] -> (d0 * s2 - s0)
 // CHECK:  func @calc(%[[VAL_0:.*]]: !fir.ref<!fir.array<?xf32>>, %[[VAL_1:.*]]: !fir.ref<!fir.array<?xf32>>, %[[VAL_2:.*]]: !fir.ref<!fir.array<?xf32>>) {
 // CHECK:    %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK:    %[[VAL_4:.*]] = arith.constant 100 : index
@@ -43,8 +45,8 @@ module  {
 // CHECK:    %[[VAL_8:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
 // CHECK:    %[[VAL_9:.*]] = fir.convert %[[VAL_1]] : (!fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
 // CHECK:    %[[VAL_10:.*]] = fir.convert %[[VAL_7]] : (!fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
-// CHECK:    affine.for %[[VAL_11:.*]] = 1 to 101 {
-// CHECK:      %[[VAL_12:.*]] = affine.apply #map(%[[VAL_11]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]]
+// CHECK:    affine.for %[[VAL_11:.*]] = %[[VAL_3]] to #[[MAP]]()[%[[VAL_4]]] {
+// CHECK:      %[[VAL_12:.*]] = affine.apply #[[MAP1]](%[[VAL_11]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]]
 // CHECK:      %[[VAL_13:.*]] = fir.coordinate_of %[[VAL_8]], %[[VAL_12]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
 // CHECK:      %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<f32>
 // CHECK:      %[[VAL_15:.*]] = fir.coordinate_of %[[VAL_9]], %[[VAL_12]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
@@ -54,8 +56,8 @@ module  {
 // CHECK:      fir.store %[[VAL_17]] to %[[VAL_18]] : !fir.ref<f32>
 // CHECK:    }
 // CHECK:    %[[VAL_19:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
-// CHECK:    affine.for %[[VAL_20:.*]] = 1 to 101 {
-// CHECK:      %[[VAL_21:.*]] = affine.apply #map(%[[VAL_20]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]]
+// CHECK:    affine.for %[[VAL_20:.*]] = %[[VAL_3]] to #[[MAP]]()[%[[VAL_4]]] {
+// CHECK:      %[[VAL_21:.*]] = affine.apply #[[MAP1]](%[[VAL_20]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]]
 // CHECK:      %[[VAL_22:.*]] = fir.coordinate_of %[[VAL_10]], %[[VAL_21]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
 // CHECK:      %[[VAL_23:.*]] = fir.load %[[VAL_22]] : !fir.ref<f32>
 // CHECK:      %[[VAL_24:.*]] = fir.coordinate_of %[[VAL_9]], %[[VAL_21]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index df255cfcf3ec1..3ecd1fd43bf44 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -2216,22 +2217,45 @@ OperationLegalizer::legalizeWithFold(Operation *op,
     rewriterImpl.logger.startLine() << "* Fold {\n";
     rewriterImpl.logger.indent();
   });
-  (void)rewriterImpl;
+
+  // Clear pattern state, so that the next pattern application starts with a
+  // clean slate. (The op/block sets are populated by listener notifications.)
+  auto cleanup = llvm::make_scope_exit([&]() {
+    rewriterImpl.patternNewOps.clear();
+    rewriterImpl.patternModifiedOps.clear();
+    rewriterImpl.patternInsertedBlocks.clear();
+  });
+
+  // Upon failure, undo all changes made by the folder.
+  RewriterState curState = rewriterImpl.getCurrentState();
+  auto undoFolding = [&]() {
+    rewriterImpl.resetState(curState, std::string(op->getName().getStringRef()) + " folder");
+    return failure();
+  };
 
   // Try to fold the operation.
   StringRef opName = op->getName().getStringRef();
   SmallVector<Value, 2> replacementValues;
   SmallVector<Operation *, 2> newOps;
   rewriter.setInsertionPoint(op);
+  rewriter.startOpModification(op);
   if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
     LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
+    rewriter.cancelOpModification(op);
     return failure();
   }
+  rewriter.finalizeOpModification(op);
 
   // An empty list of replacement values indicates that the fold was in-place.
   // As the operation changed, a new legalization needs to be attempted.
-  if (replacementValues.empty())
-    return legalize(op, rewriter);
+  if (replacementValues.empty()) {
+    if (succeeded(legalize(op, rewriter)))
+      return success();
+    return undoFolding();
+  }
+
+  // Insert a replacement for 'op' with the folded replacement values.
+  rewriter.replaceOp(op, replacementValues);
 
   // Recursively legalize any new constant operations.
   for (Operation *newOp : newOps) {
@@ -2245,16 +2269,10 @@ OperationLegalizer::legalizeWithFold(Operation *op,
             "op '" + opName +
             "' folder rollback of IR modifications requested");
       }
-      // Legalization failed: erase all materialized constants.
-      for (Operation *op : newOps)
-        rewriter.eraseOp(op);
-      return failure();
+      return undoFolding();
     }
   }
 
-  // Insert a replacement for 'op' with the folded replacement values.
-  rewriter.replaceOp(op, replacementValues);
-
   LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
   return success();
 }
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index db8bd0f6378d2..9bffe92b374d5 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -104,8 +104,8 @@ func.func @test_signature_conversion_no_converter() {
   "test.signature_conversion_no_converter"() ({
   // expected-error@below {{failed to legalize unresolved materialization from ('f64') to ('f32') that remained live after conversion}}
   ^bb0(%arg0: f32):
-    "test.type_consumer"(%arg0) : (f32) -> ()
     // expected-note@below{{see existing live user here}}
+    "test.type_consumer"(%arg0) : (f32) -> ()
     "test.return"(%arg0) : (f32) -> ()
   }) : () -> ()
   return



More information about the Mlir-commits mailing list