[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: Add flag to disable rollback (PR #136490)

Matthias Springer llvmlistbot at llvm.org
Tue Apr 22 00:22:30 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/136490

>From c494a9dbaab7170cd7f260a134cfb324c5ce0c67 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 20 Apr 2025 13:39:20 +0200
Subject: [PATCH 1/2] no rollback flag

---
 .../mlir/Transforms/DialectConversion.h       | 20 +++++++
 .../Transforms/Utils/DialectConversion.cpp    | 57 ++++++++++++++-----
 2 files changed, 63 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index b6ab252456e70..b65b3ea971f91 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1219,6 +1219,26 @@ struct ConversionConfig {
   /// materializations and instead inserts "builtin.unrealized_conversion_cast"
   /// ops to ensure that the resulting IR is valid.
   bool buildMaterializations = true;
+
+  /// If set to "true", pattern rollback is allowed. The conversion driver
+  /// rolls back IR modifications in the following situations.
+  ///
+  /// 1. Pattern implementation returns "failure" after modifying IR.
+  /// 2. Pattern produces IR (in-place modification or new IR) that is illegal
+  ///    and cannot be legalized by subsequent foldings / pattern applications.
+  ///
+  /// If set to "false", the conversion driver will produce an LLVM fatal error
+  /// instead of rolling back IR modifications. Moreover, in case of a failed
+  /// conversion, the original IR is not restored. The resulting IR may be a
+  /// mix of original and rewritten IR. (Same as a failed greedy pattern
+  /// rewrite.)
+  ///
+  /// Note: This flag was added in preparation for the One-Shot Dialect
+  /// Conversion refactoring, which will remove the ability to roll back IR
+  /// modifications from the conversion driver. Use this flag to ensure that
+  /// your patterns do not trigger any IR rollbacks. For details, see
+  /// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
+  bool allowPatternRollback = true;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 4d250329c6f45..6deedd41bb9ea 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -861,8 +861,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// conversion process succeeds.
   void applyRewrites();
 
-  /// Reset the state of the rewriter to a previously saved point.
-  void resetState(RewriterState state);
+  /// Reset the state of the rewriter to a previously saved point. Optionally,
+  /// the name of the pattern that triggered the rollback can be specified for
+  /// debugging purposes.
+  void resetState(RewriterState state, StringRef patternName = "");
 
   /// Append a rewrite. Rewrites are committed upon success and rolled back upon
   /// failure.
@@ -873,8 +875,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   }
 
   /// Undo the rewrites (motions, splits) one by one in reverse order until
-  /// "numRewritesToKeep" rewrites remains.
-  void undoRewrites(unsigned numRewritesToKeep = 0);
+  /// "numRewritesToKeep" rewrites remain. Optionally, the name of the pattern
+  /// that triggered the rollback can be specified for debugging purposes.
+  void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = "");
 
   /// Remap the given values to those with potentially different types. Returns
   /// success if the values could be remapped, failure otherwise. `valueDiagTag`
@@ -1204,9 +1207,10 @@ RewriterState ConversionPatternRewriterImpl::getCurrentState() {
   return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
 }
 
-void ConversionPatternRewriterImpl::resetState(RewriterState state) {
+void ConversionPatternRewriterImpl::resetState(RewriterState state,
+                                               StringRef patternName) {
   // Undo any rewrites.
-  undoRewrites(state.numRewrites);
+  undoRewrites(state.numRewrites, patternName);
 
   // Pop all of the recorded ignored operations that are no longer valid.
   while (ignoredOps.size() != state.numIgnoredOperations)
@@ -1216,10 +1220,19 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
     replacedOps.pop_back();
 }
 
-void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
+void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
+                                                 StringRef patternName) {
   for (auto &rewrite :
-       llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
+       llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) {
+    if (!config.allowPatternRollback &&
+        !isa<UnresolvedMaterializationRewrite>(rewrite)) {
+      // Unresolved materializations can always be rolled back (erased).
+      std::string errorMessage = "pattern '" + std::string(patternName) +
+                                 "' rollback of IR modifications requested";
+      llvm_unreachable(errorMessage.c_str());
+    }
     rewrite->rollback();
+  }
   rewrites.resize(numRewritesToKeep);
 }
 
@@ -2158,7 +2171,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
     });
     if (config.listener)
       config.listener->notifyPatternEnd(pattern, failure());
-    rewriterImpl.resetState(curState);
+    rewriterImpl.resetState(curState, pattern.getDebugName());
     appliedPatterns.erase(&pattern);
   };
 
@@ -2168,8 +2181,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
     assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
     auto result = legalizePatternResult(op, pattern, rewriter, curState);
     appliedPatterns.erase(&pattern);
-    if (failed(result))
-      rewriterImpl.resetState(curState);
+    if (failed(result)) {
+      if (!rewriterImpl.config.allowPatternRollback)
+        op->emitError("pattern '")
+            << pattern.getDebugName()
+            << "' produced IR that could not be legalized";
+      rewriterImpl.resetState(curState, pattern.getDebugName());
+    }
     if (config.listener)
       config.listener->notifyPatternEnd(pattern, result);
     return result;
@@ -2674,9 +2692,20 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
 
-  for (auto *op : toConvert)
-    if (failed(convert(rewriter, op)))
-      return rewriterImpl.undoRewrites(), failure();
+  for (auto *op : toConvert) {
+    if (failed(convert(rewriter, op))) {
+      // Dialect conversion failed.
+      if (rewriterImpl.config.allowPatternRollback) {
+        // Rollback is allowed: restore the original IR.
+        rewriterImpl.undoRewrites();
+      } else {
+        // Rollback is not allowed: apply all modifications that have been
+        // performed so far.
+        rewriterImpl.applyRewrites();
+      }
+      return failure();
+    }
+  }
 
   // After a successful conversion, apply rewrites.
   rewriterImpl.applyRewrites();

>From 7b4cb9d23853e828e8ff323ea6f1b273eabde433 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 22 Apr 2025 09:22:09 +0200
Subject: [PATCH 2/2] address comments

---
 mlir/lib/Transforms/Utils/DialectConversion.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 6deedd41bb9ea..9ace360d79ed9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1227,9 +1227,8 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
     if (!config.allowPatternRollback &&
         !isa<UnresolvedMaterializationRewrite>(rewrite)) {
       // Unresolved materializations can always be rolled back (erased).
-      std::string errorMessage = "pattern '" + std::string(patternName) +
-                                 "' rollback of IR modifications requested";
-      llvm_unreachable(errorMessage.c_str());
+      llvm::report_fatal_error("pattern '" + std::string(patternName) +
+                               "' rollback of IR modifications requested");
     }
     rewrite->rollback();
   }



More information about the Mlir-commits mailing list