[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: Add flag to disable rollback (PR #136490)
Matthias Springer
llvmlistbot at llvm.org
Tue Apr 22 00:22:30 PDT 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/136490
>From c494a9dbaab7170cd7f260a134cfb324c5ce0c67 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 20 Apr 2025 13:39:20 +0200
Subject: [PATCH 1/2] no rollback flag
---
.../mlir/Transforms/DialectConversion.h | 20 +++++++
.../Transforms/Utils/DialectConversion.cpp | 57 ++++++++++++++-----
2 files changed, 63 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index b6ab252456e70..b65b3ea971f91 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1219,6 +1219,26 @@ struct ConversionConfig {
/// materializations and instead inserts "builtin.unrealized_conversion_cast"
/// ops to ensure that the resulting IR is valid.
bool buildMaterializations = true;
+
+ /// If set to "true", pattern rollback is allowed. The conversion driver
+ /// rolls back IR modifications in the following situations.
+ ///
+ /// 1. Pattern implementation returns "failure" after modifying IR.
+ /// 2. Pattern produces IR (in-place modification or new IR) that is illegal
+ /// and cannot be legalized by subsequent foldings / pattern applications.
+ ///
+ /// If set to "false", the conversion driver will produce an LLVM fatal error
+ /// instead of rolling back IR modifications. Moreover, in case of a failed
+ /// conversion, the original IR is not restored. The resulting IR may be a
+ /// mix of original and rewritten IR. (Same as a failed greedy pattern
+ /// rewrite.)
+ ///
+ /// Note: This flag was added in preparation for the One-Shot Dialect
+ /// Conversion refactoring, which will remove the ability to roll back IR
+ /// modifications from the conversion driver. Use this flag to ensure that
+ /// your patterns do not trigger any IR rollbacks. For details, see
+ /// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
+ bool allowPatternRollback = true;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 4d250329c6f45..6deedd41bb9ea 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -861,8 +861,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// conversion process succeeds.
void applyRewrites();
- /// Reset the state of the rewriter to a previously saved point.
- void resetState(RewriterState state);
+ /// Reset the state of the rewriter to a previously saved point. Optionally,
+ /// the name of the pattern that triggered the rollback can be specified for
+ /// debugging purposes.
+ void resetState(RewriterState state, StringRef patternName = "");
/// Append a rewrite. Rewrites are committed upon success and rolled back upon
/// failure.
@@ -873,8 +875,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
}
/// Undo the rewrites (motions, splits) one by one in reverse order until
- /// "numRewritesToKeep" rewrites remains.
- void undoRewrites(unsigned numRewritesToKeep = 0);
+ /// "numRewritesToKeep" rewrites remain. Optionally, the name of the pattern
+ /// that triggered the rollback can be specified for debugging purposes.
+ void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
/// Remap the given values to those with potentially different types. Returns
/// success if the values could be remapped, failure otherwise. `valueDiagTag`
@@ -1204,9 +1207,10 @@ RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
}
-void ConversionPatternRewriterImpl::resetState(RewriterState state) {
+void ConversionPatternRewriterImpl::resetState(RewriterState state,
+ StringRef patternName) {
// Undo any rewrites.
- undoRewrites(state.numRewrites);
+ undoRewrites(state.numRewrites, patternName);
// Pop all of the recorded ignored operations that are no longer valid.
while (ignoredOps.size() != state.numIgnoredOperations)
@@ -1216,10 +1220,19 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
replacedOps.pop_back();
}
-void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
+void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
+ StringRef patternName) {
for (auto &rewrite :
- llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
+ llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) {
+ if (!config.allowPatternRollback &&
+ !isa<UnresolvedMaterializationRewrite>(rewrite)) {
+ // Unresolved materializations can always be rolled back (erased).
+ std::string errorMessage = "pattern '" + std::string(patternName) +
+ "' rollback of IR modifications requested";
+ llvm_unreachable(errorMessage.c_str());
+ }
rewrite->rollback();
+ }
rewrites.resize(numRewritesToKeep);
}
@@ -2158,7 +2171,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
});
if (config.listener)
config.listener->notifyPatternEnd(pattern, failure());
- rewriterImpl.resetState(curState);
+ rewriterImpl.resetState(curState, pattern.getDebugName());
appliedPatterns.erase(&pattern);
};
@@ -2168,8 +2181,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
auto result = legalizePatternResult(op, pattern, rewriter, curState);
appliedPatterns.erase(&pattern);
- if (failed(result))
- rewriterImpl.resetState(curState);
+ if (failed(result)) {
+ if (!rewriterImpl.config.allowPatternRollback)
+ op->emitError("pattern '")
+ << pattern.getDebugName()
+ << "' produced IR that could not be legalized";
+ rewriterImpl.resetState(curState, pattern.getDebugName());
+ }
if (config.listener)
config.listener->notifyPatternEnd(pattern, result);
return result;
@@ -2674,9 +2692,20 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
- for (auto *op : toConvert)
- if (failed(convert(rewriter, op)))
- return rewriterImpl.undoRewrites(), failure();
+ for (auto *op : toConvert) {
+ if (failed(convert(rewriter, op))) {
+ // Dialect conversion failed.
+ if (rewriterImpl.config.allowPatternRollback) {
+ // Rollback is allowed: restore the original IR.
+ rewriterImpl.undoRewrites();
+ } else {
+ // Rollback is not allowed: apply all modifications that have been
+ // performed so far.
+ rewriterImpl.applyRewrites();
+ }
+ return failure();
+ }
+ }
// After a successful conversion, apply rewrites.
rewriterImpl.applyRewrites();
>From 7b4cb9d23853e828e8ff323ea6f1b273eabde433 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 22 Apr 2025 09:22:09 +0200
Subject: [PATCH 2/2] address comments
---
mlir/lib/Transforms/Utils/DialectConversion.cpp | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 6deedd41bb9ea..9ace360d79ed9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1227,9 +1227,8 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
if (!config.allowPatternRollback &&
!isa<UnresolvedMaterializationRewrite>(rewrite)) {
// Unresolved materializations can always be rolled back (erased).
- std::string errorMessage = "pattern '" + std::string(patternName) +
- "' rollback of IR modifications requested";
- llvm_unreachable(errorMessage.c_str());
+ llvm::report_fatal_error("pattern '" + std::string(patternName) +
+ "' rollback of IR modifications requested");
}
rewrite->rollback();
}
More information about the Mlir-commits
mailing list