[Mlir-commits] [flang] [mlir] [mlir][Transforms] Dialect Conversion: Fix folder rollback (PR #150775)
Matthias Springer
llvmlistbot at llvm.org
Sat Jul 26 09:54:13 PDT 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/150775
From 45803b6d058941d988d595810f9f034855ca12d1 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sat, 26 Jul 2025 16:08:36 +0000
Subject: [PATCH] [mlir][Transforms] Dialect Conversion: Fix folder rollback
---
flang/test/Fir/affine-demotion.fir | 10 +++--
.../Transforms/Utils/DialectConversion.cpp | 38 ++++++++++++++-----
.../test-legalize-type-conversion.mlir | 2 +-
3 files changed, 35 insertions(+), 15 deletions(-)
diff --git a/flang/test/Fir/affine-demotion.fir b/flang/test/Fir/affine-demotion.fir
index bdb84be3624cb..a091ce4e5df9c 100644
--- a/flang/test/Fir/affine-demotion.fir
+++ b/flang/test/Fir/affine-demotion.fir
@@ -34,6 +34,8 @@ module {
}
}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 + 1)>
+// CHECK: #[[MAP1:.*]] = affine_map<(d0)[s0, s1, s2] -> (d0 * s2 - s0)
// CHECK: func @calc(%[[VAL_0:.*]]: !fir.ref<!fir.array<?xf32>>, %[[VAL_1:.*]]: !fir.ref<!fir.array<?xf32>>, %[[VAL_2:.*]]: !fir.ref<!fir.array<?xf32>>) {
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 100 : index
@@ -43,8 +45,8 @@ module {
// CHECK: %[[VAL_8:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
// CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_1]] : (!fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
// CHECK: %[[VAL_10:.*]] = fir.convert %[[VAL_7]] : (!fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
-// CHECK: affine.for %[[VAL_11:.*]] = 1 to 101 {
-// CHECK: %[[VAL_12:.*]] = affine.apply #map(%[[VAL_11]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]]
+// CHECK: affine.for %[[VAL_11:.*]] = %[[VAL_3]] to #[[MAP]]()[%[[VAL_4]]] {
+// CHECK: %[[VAL_12:.*]] = affine.apply #[[MAP1]](%[[VAL_11]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]]
// CHECK: %[[VAL_13:.*]] = fir.coordinate_of %[[VAL_8]], %[[VAL_12]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
// CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<f32>
// CHECK: %[[VAL_15:.*]] = fir.coordinate_of %[[VAL_9]], %[[VAL_12]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
@@ -54,8 +56,8 @@ module {
// CHECK: fir.store %[[VAL_17]] to %[[VAL_18]] : !fir.ref<f32>
// CHECK: }
// CHECK: %[[VAL_19:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
-// CHECK: affine.for %[[VAL_20:.*]] = 1 to 101 {
-// CHECK: %[[VAL_21:.*]] = affine.apply #map(%[[VAL_20]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]]
+// CHECK: affine.for %[[VAL_20:.*]] = %[[VAL_3]] to #[[MAP]]()[%[[VAL_4]]] {
+// CHECK: %[[VAL_21:.*]] = affine.apply #[[MAP1]](%[[VAL_20]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]]
// CHECK: %[[VAL_22:.*]] = fir.coordinate_of %[[VAL_10]], %[[VAL_21]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
// CHECK: %[[VAL_23:.*]] = fir.load %[[VAL_22]] : !fir.ref<f32>
// CHECK: %[[VAL_24:.*]] = fir.coordinate_of %[[VAL_9]], %[[VAL_21]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index df255cfcf3ec1..3ecd1fd43bf44 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
@@ -2216,22 +2217,45 @@ OperationLegalizer::legalizeWithFold(Operation *op,
rewriterImpl.logger.startLine() << "* Fold {\n";
rewriterImpl.logger.indent();
});
- (void)rewriterImpl;
+
+ // Clear pattern state, so that the next pattern application starts with a
+ // clean slate. (The op/block sets are populated by listener notifications.)
+ auto cleanup = llvm::make_scope_exit([&]() {
+ rewriterImpl.patternNewOps.clear();
+ rewriterImpl.patternModifiedOps.clear();
+ rewriterImpl.patternInsertedBlocks.clear();
+ });
+
+ // Upon failure, undo all changes made by the folder.
+ RewriterState curState = rewriterImpl.getCurrentState();
+ auto undoFolding = [&]() {
+ rewriterImpl.resetState(curState, std::string(op->getName().getStringRef()) + " folder");
+ return failure();
+ };
// Try to fold the operation.
StringRef opName = op->getName().getStringRef();
SmallVector<Value, 2> replacementValues;
SmallVector<Operation *, 2> newOps;
rewriter.setInsertionPoint(op);
+ rewriter.startOpModification(op);
if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
+ rewriter.cancelOpModification(op);
return failure();
}
+ rewriter.finalizeOpModification(op);
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
- if (replacementValues.empty())
- return legalize(op, rewriter);
+ if (replacementValues.empty()) {
+ if (succeeded(legalize(op, rewriter)))
+ return success();
+ return undoFolding();
+ }
+
+ // Insert a replacement for 'op' with the folded replacement values.
+ rewriter.replaceOp(op, replacementValues);
// Recursively legalize any new constant operations.
for (Operation *newOp : newOps) {
@@ -2245,16 +2269,10 @@ OperationLegalizer::legalizeWithFold(Operation *op,
"op '" + opName +
"' folder rollback of IR modifications requested");
}
- // Legalization failed: erase all materialized constants.
- for (Operation *op : newOps)
- rewriter.eraseOp(op);
- return failure();
+ return undoFolding();
}
}
- // Insert a replacement for 'op' with the folded replacement values.
- rewriter.replaceOp(op, replacementValues);
-
LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
return success();
}
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index db8bd0f6378d2..9bffe92b374d5 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -104,8 +104,8 @@ func.func @test_signature_conversion_no_converter() {
"test.signature_conversion_no_converter"() ({
// expected-error at below {{failed to legalize unresolved materialization from ('f64') to ('f32') that remained live after conversion}}
^bb0(%arg0: f32):
- "test.type_consumer"(%arg0) : (f32) -> ()
// expected-note at below{{see existing live user here}}
+ "test.type_consumer"(%arg0) : (f32) -> ()
"test.return"(%arg0) : (f32) -> ()
}) : () -> ()
return
More information about the Mlir-commits
mailing list