[Mlir-commits] [mlir] [mlir][Transforms] Fix crash in `reconcile-unrealized-casts` (PR #158067)

Matthias Springer llvmlistbot at llvm.org
Fri Sep 12 06:22:12 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/158067

>From f7686bd8c84024bb9d6e17ad0d234711f09ef3ad Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 11 Sep 2025 13:14:56 +0000
Subject: [PATCH 1/3] [mlir][Transforms] Fix crash in
 `reconcile-unrealized-casts`

---
 .../mlir/Transforms/DialectConversion.h       |   3 +
 .../Transforms/Utils/DialectConversion.cpp    | 143 +++++++++++++-----
 .../reconcile-unrealized-casts.mlir           |  50 ++++++
 ...assume-alignment-runtime-verification.mlir |   3 +-
 .../atomic-rmw-runtime-verification.mlir      |   3 +-
 .../MemRef/store-runtime-verification.mlir    |   3 +-
 6 files changed, 163 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 6949f4a14fdba..5bf73aa350534 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1421,6 +1421,9 @@ struct ConversionConfig {
 ///
 /// In the above example, %0 can be used instead of %3 and all cast ops are
 /// folded away.
+void reconcileUnrealizedCasts(
+    const DenseSet<UnrealizedConversionCastOp> &castOps,
+    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
 void reconcileUnrealizedCasts(
     ArrayRef<UnrealizedConversionCastOp> castOps,
     SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 36ee87b533b3b..627e5d50c42ac 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3100,6 +3100,7 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
 //===----------------------------------------------------------------------===//
 // OperationConverter
 //===----------------------------------------------------------------------===//
+
 namespace {
 enum OpConversionMode {
   /// In this mode, the conversion will ignore failed conversions to allow
@@ -3117,6 +3118,13 @@ enum OpConversionMode {
 } // namespace
 
 namespace mlir {
+
+// Forward declaration only.
+static void reconcileUnrealizedCasts(
+    const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
+        &castOps,
+    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps);
+
 // This class converts operations to a given conversion target via a set of
 // rewrite patterns. The conversion behaves differently depending on the
 // conversion mode.
@@ -3264,18 +3272,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   // After a successful conversion, apply rewrites.
   rewriterImpl.applyRewrites();
 
-  // Gather all unresolved materializations.
-  SmallVector<UnrealizedConversionCastOp> allCastOps;
-  const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
-      &materializations = rewriterImpl.unresolvedMaterializations;
-  for (auto it : materializations)
-    allCastOps.push_back(it.first);
-
   // Reconcile all UnrealizedConversionCastOps that were inserted by the
-  // dialect conversion frameworks. (Not the one that were inserted by
+  // dialect conversion frameworks. (Not the ones that were inserted by
   // patterns.)
+  const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
+      &materializations = rewriterImpl.unresolvedMaterializations;
   SmallVector<UnrealizedConversionCastOp> remainingCastOps;
-  reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
+  reconcileUnrealizedCasts(materializations, &remainingCastOps);
 
   // Drop markers.
   for (UnrealizedConversionCastOp castOp : remainingCastOps)
@@ -3303,20 +3306,15 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
 // Reconcile Unrealized Casts
 //===----------------------------------------------------------------------===//
 
-void mlir::reconcileUnrealizedCasts(
-    ArrayRef<UnrealizedConversionCastOp> castOps,
+/// Try to reconcile all given UnrealizedConversionCastOps and store the
+/// left-over ops in `remainingCastOps` (if provided). See documentation in
+/// DialectConversion.h for more details.
+template <typename RangeT>
+static void reconcileUnrealizedCastsImpl(
+    RangeT castOps, function_ref<bool(UnrealizedConversionCastOp)> isCastOpFn,
     SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
+  // A worklist of cast ops to process.
   SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
-  // This set is maintained only if `remainingCastOps` is provided.
-  DenseSet<Operation *> erasedOps;
-
-  // Helper function that adds all operands to the worklist that are an
-  // unrealized_conversion_cast op result.
-  auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
-    for (Value v : castOp.getInputs())
-      if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
-        worklist.insert(inputCastOp);
-  };
 
   // Helper function that return the unrealized_conversion_cast op that
   // defines all inputs of the given op (in the same order). Return "nullptr"
@@ -3337,39 +3335,106 @@ void mlir::reconcileUnrealizedCasts(
   // Process ops in the worklist bottom-to-top.
   while (!worklist.empty()) {
     UnrealizedConversionCastOp castOp = worklist.pop_back_val();
-    if (castOp->use_empty()) {
-      // DCE: If the op has no users, erase it. Add the operands to the
-      // worklist to find additional DCE opportunities.
-      enqueueOperands(castOp);
-      if (remainingCastOps)
-        erasedOps.insert(castOp.getOperation());
-      castOp->erase();
-      continue;
-    }
 
     // Traverse the chain of input cast ops to see if an op with the same
     // input types can be found.
     UnrealizedConversionCastOp nextCast = castOp;
     while (nextCast) {
       if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
+        if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
+              return v.getDefiningOp() == castOp;
+            })) {
+          // Ran into a cycle.
+          break;
+        }
+
         // Found a cast where the input types match the output types of the
-        // matched op. We can directly use those inputs and the matched op can
-        // be removed.
-        enqueueOperands(castOp);
+        // matched op. We can directly use those inputs.
         castOp.replaceAllUsesWith(nextCast.getInputs());
-        if (remainingCastOps)
-          erasedOps.insert(castOp.getOperation());
-        castOp->erase();
         break;
       }
       nextCast = getInputCast(nextCast);
     }
   }
 
-  if (remainingCastOps)
-    for (UnrealizedConversionCastOp op : castOps)
-      if (!erasedOps.contains(op.getOperation()))
+  // A set of all alive cast ops. I.e., ops whose results are (transitively)
+  // used by an op that is not a cast op.
+  DenseSet<Operation *> liveOps;
+
+  // Helper function that marks the given op and transitively reachable input
+  // cast ops as alive.
+  auto markOpLive = [&](Operation *op) {
+    SmallVector<Operation *> worklist;
+    worklist.push_back(op);
+    while (!worklist.empty()) {
+      Operation *op = worklist.pop_back_val();
+      if (liveOps.insert(op).second) {
+        // Successfully inserted: process reachable input cast ops.
+        for (Value v : op->getOperands())
+          if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
+            if (isCastOpFn(castOp))
+              worklist.push_back(castOp);
+      }
+    }
+  };
+
+  // Find all alive cast ops.
+  for (UnrealizedConversionCastOp op : castOps) {
+    // If any of the users is not a cast op, mark the current op (and its
+    // input ops) as live.
+    if (llvm::any_of(op->getUsers(), [&](Operation *user) {
+          auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
+          return !castOp || !isCastOpFn(castOp);
+        }))
+      markOpLive(op);
+  }
+
+  // Erase all dead cast ops.
+  for (UnrealizedConversionCastOp op : castOps) {
+    if (liveOps.contains(op)) {
+      // Op is alive and was not erased. Add it to the remaining cast ops.
+      if (remainingCastOps)
         remainingCastOps->push_back(op);
+      continue;
+    }
+
+    // Op is dead. Erase it.
+    op->dropAllUses();
+    op->erase();
+  }
+}
+
+void mlir::reconcileUnrealizedCasts(
+    ArrayRef<UnrealizedConversionCastOp> castOps,
+    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
+  // Set of all cast ops for faster lookups.
+  DenseSet<UnrealizedConversionCastOp> castOpSet;
+  for (UnrealizedConversionCastOp op : castOps)
+    castOpSet.insert(op);
+  reconcileUnrealizedCasts(castOpSet, remainingCastOps);
+}
+
+void mlir::reconcileUnrealizedCasts(
+    const DenseSet<UnrealizedConversionCastOp> &castOps,
+    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
+  reconcileUnrealizedCastsImpl(
+      llvm::make_range(castOps.begin(), castOps.end()),
+      [&](UnrealizedConversionCastOp castOp) {
+        return castOps.contains(castOp);
+      },
+      remainingCastOps);
+}
+
+static void mlir::reconcileUnrealizedCasts(
+    const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
+        &castOps,
+    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
+  reconcileUnrealizedCastsImpl(
+      castOps.keys(),
+      [&](UnrealizedConversionCastOp castOp) {
+        return castOps.contains(castOp);
+      },
+      remainingCastOps);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
index 3573114f5e038..ac5ca321c066f 100644
--- a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
+++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
@@ -194,3 +194,53 @@ func.func @emptyCast() -> index {
     %0 = builtin.unrealized_conversion_cast to index
     return %0 : index
 }
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+//  CHECK-NEXT:   "test.return"() : () -> ()
+test.graph_region {
+  %0 = builtin.unrealized_conversion_cast %2 : i32 to i64
+  %1 = builtin.unrealized_conversion_cast %0 : i64 to i16
+  %2 = builtin.unrealized_conversion_cast %1 : i16 to i32
+  "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+//  CHECK-NEXT:   %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[cast2:.*]] : i32 to i64
+//  CHECK-NEXT:   %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i64 to i16
+//  CHECK-NEXT:   %[[cast2]] = builtin.unrealized_conversion_cast %[[cast1]] : i16 to i32
+//  CHECK-NEXT:   "test.user"(%[[cast2]]) : (i32) -> ()
+//  CHECK-NEXT:   "test.return"() : () -> ()
+test.graph_region {
+  %0 = builtin.unrealized_conversion_cast %2 : i32 to i64
+  %1 = builtin.unrealized_conversion_cast %0 : i64 to i16
+  %2 = builtin.unrealized_conversion_cast %1 : i16 to i32
+  "test.user"(%2) : (i32) -> ()
+  "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+//  CHECK-NEXT:   "test.return"() : () -> ()
+test.graph_region {
+  %0 = builtin.unrealized_conversion_cast %0 : i32 to i32
+  "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+//  CHECK-NEXT:   %[[c0:.*]] = arith.constant
+//  CHECK-NEXT:   %[[cast:.*]]:2 = builtin.unrealized_conversion_cast %[[c0]], %[[cast]]#1 : i32, i32 to i32, i32
+//  CHECK-NEXT:   "test.user"(%[[cast]]#0) : (i32) -> ()
+//  CHECK-NEXT:   "test.return"() : () -> ()
+test.graph_region {
+  %cst = arith.constant 0 : i32
+  %0, %1 = builtin.unrealized_conversion_cast %cst, %1 : i32, i32 to i32, i32
+  "test.user"(%0) : (i32) -> ()
+  "test.return"() : () -> ()
+}
diff --git a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
index 25a338df8d790..01a826a638606 100644
--- a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
@@ -1,7 +1,8 @@
 // RUN: mlir-opt %s -generate-runtime-verification \
 // RUN:     -expand-strided-metadata \
 // RUN:     -test-cf-assert \
-// RUN:     -convert-to-llvm | \
+// RUN:     -convert-to-llvm \
+// RUN:     -reconcile-unrealized-casts | \
 // RUN: mlir-runner -e main -entry-point-result=void \
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
index 4c6a48d577a6c..1144a7caf36e8 100644
--- a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s -generate-runtime-verification \
 // RUN:     -test-cf-assert \
-// RUN:     -convert-to-llvm | \
+// RUN:     -convert-to-llvm \
+// RUN:     -reconcile-unrealized-casts | \
 // RUN: mlir-runner -e main -entry-point-result=void \
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
index dd000c6904bcb..82e63805cd027 100644
--- a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s -generate-runtime-verification \
 // RUN:     -test-cf-assert \
-// RUN:     -convert-to-llvm | \
+// RUN:     -convert-to-llvm \
+// RUN:     -reconcile-unrealized-casts | \
 // RUN: mlir-runner -e main -entry-point-result=void \
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s

>From fa3bebd3835a77898404c6516df675ed116fa249 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 12 Sep 2025 14:16:09 +0100
Subject: [PATCH 2/3] Apply suggestions from code review

Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
---
 mlir/lib/Transforms/Utils/DialectConversion.cpp | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 627e5d50c42ac..0c29b168d771a 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3309,9 +3309,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
 /// Try to reconcile all given UnrealizedConversionCastOps and store the
 /// left-over ops in `remainingCastOps` (if provided). See documentation in
 /// DialectConversion.h for more details.
+/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the algorithm may
+/// visit an operand (or user) which is a cast op, but will not try to reconcile it if not in the
+/// filtered set.
 template <typename RangeT>
 static void reconcileUnrealizedCastsImpl(
-    RangeT castOps, function_ref<bool(UnrealizedConversionCastOp)> isCastOpFn,
+    RangeT castOps, function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
     SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
   // A worklist of cast ops to process.
   SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
@@ -3363,9 +3366,9 @@ static void reconcileUnrealizedCastsImpl(
 
   // Helper function that marks the given op and transitively reachable input
   // cast ops as alive.
-  auto markOpLive = [&](Operation *op) {
+  auto markOpLive = [&](Operation *rootOp) {
     SmallVector<Operation *> worklist;
-    worklist.push_back(op);
+    worklist.push_back(rootOp);
     while (!worklist.empty()) {
       Operation *op = worklist.pop_back_val();
       if (liveOps.insert(op).second) {
@@ -3380,6 +3383,8 @@ static void reconcileUnrealizedCastsImpl(
 
   // Find all alive cast ops.
   for (UnrealizedConversionCastOp op : castOps) {
+     // The op may have been marked live already as being an operand of another live cast op.
+     if (liveOps.contains(op.getOperation()) continue;
     // If any of the users is not a cast op, mark the current op (and its
     // input ops) as live.
     if (llvm::any_of(op->getUsers(), [&](Operation *user) {

>From 1bb194a5500be9369063ab805f62285459319565 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 12 Sep 2025 13:21:47 +0000
Subject: [PATCH 3/3] fix

---
 .../lib/Transforms/Utils/DialectConversion.cpp | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 0c29b168d771a..6841d4b37d949 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3309,12 +3309,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
 /// Try to reconcile all given UnrealizedConversionCastOps and store the
 /// left-over ops in `remainingCastOps` (if provided). See documentation in
 /// DialectConversion.h for more details.
-/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the algorithm may
-/// visit an operand (or user) which is a cast op, but will not try to reconcile it if not in the
-/// filtered set.
+/// The `isCastOpOfInterestFn` is used to filter the cast ops to process: the
+/// algorithm may visit an operand (or user) that is a cast op, but will not
+/// try to reconcile it if it is not in the filtered set.
 template <typename RangeT>
 static void reconcileUnrealizedCastsImpl(
-    RangeT castOps, function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
+    RangeT castOps,
+    function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
     SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
   // A worklist of cast ops to process.
   SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
@@ -3375,7 +3376,7 @@ static void reconcileUnrealizedCastsImpl(
         // Successfully inserted: process reachable input cast ops.
         for (Value v : op->getOperands())
           if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
-            if (isCastOpFn(castOp))
+            if (isCastOpOfInterestFn(castOp))
               worklist.push_back(castOp);
       }
     }
@@ -3383,13 +3384,14 @@ static void reconcileUnrealizedCastsImpl(
 
   // Find all alive cast ops.
   for (UnrealizedConversionCastOp op : castOps) {
-     // The op may have been marked live already as being an operand of another live cast op.
+    // The op may have been marked live already because it is an operand of
+    // another live cast op.
      if (liveOps.contains(op.getOperation()) continue;
     // If any of the users is not a cast op, mark the current op (and its
     // input ops) as live.
     if (llvm::any_of(op->getUsers(), [&](Operation *user) {
-          auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
-          return !castOp || !isCastOpFn(castOp);
+      auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
+      return !castOp || !isCastOpOfInterestFn(castOp);
         }))
       markOpLive(op);
   }



More information about the Mlir-commits mailing list