[Mlir-commits] [mlir] [mlir][Transforms] Fix crash in `reconcile-unrealized-casts` (PR #158067)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 11 06:17:37 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
The `reconcile-unrealized-casts` pass used to crash when the input contained circular chains of `unrealized_conversion_cast` ops (e.g., casts in a graph region that feed into each other).
Furthermore, the `reconcileUnrealizedCasts` helper function used to erase ops that were not passed via the `castOps` argument. Such ops are now preserved. That is why some integration tests had to be changed: they now run `-reconcile-unrealized-casts` explicitly after `-convert-to-llvm`.
---
Full diff: https://github.com/llvm/llvm-project/pull/158067.diff
5 Files Affected:
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+57-19)
- (modified) mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir (+50)
- (modified) mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir (+2-1)
- (modified) mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir (+2-1)
- (modified) mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir (+2-1)
``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 36ee87b533b3b..f6a8e7e60a69c 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3306,9 +3306,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
void mlir::reconcileUnrealizedCasts(
ArrayRef<UnrealizedConversionCastOp> castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
+ // Set of all cast ops for faster lookups.
+ DenseSet<Operation *> castOpSet;
+ for (UnrealizedConversionCastOp op : castOps)
+ castOpSet.insert(op);
+
+ // A worklist of cast ops to process.
SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
- // This set is maintained only if `remainingCastOps` is provided.
- DenseSet<Operation *> erasedOps;
// Helper function that adds all operands to the worklist that are an
// unrealized_conversion_cast op result.
@@ -3337,39 +3341,73 @@ void mlir::reconcileUnrealizedCasts(
// Process ops in the worklist bottom-to-top.
while (!worklist.empty()) {
UnrealizedConversionCastOp castOp = worklist.pop_back_val();
- if (castOp->use_empty()) {
- // DCE: If the op has no users, erase it. Add the operands to the
- // worklist to find additional DCE opportunities.
- enqueueOperands(castOp);
- if (remainingCastOps)
- erasedOps.insert(castOp.getOperation());
- castOp->erase();
- continue;
- }
// Traverse the chain of input cast ops to see if an op with the same
// input types can be found.
UnrealizedConversionCastOp nextCast = castOp;
while (nextCast) {
if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
+ if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
+ return v.getDefiningOp() == castOp;
+ })) {
+ // Ran into a cycle.
+ break;
+ }
+
// Found a cast where the input types match the output types of the
- // matched op. We can directly use those inputs and the matched op can
- // be removed.
+ // matched op. We can directly use those inputs.
enqueueOperands(castOp);
castOp.replaceAllUsesWith(nextCast.getInputs());
- if (remainingCastOps)
- erasedOps.insert(castOp.getOperation());
- castOp->erase();
break;
}
nextCast = getInputCast(nextCast);
}
}
- if (remainingCastOps)
- for (UnrealizedConversionCastOp op : castOps)
- if (!erasedOps.contains(op.getOperation()))
+ // A set of all alive cast ops. I.e., ops whose results are (transitively)
+ // used by an op that is not a cast op.
+ DenseSet<Operation *> liveOps;
+
+ // Helper function that marks the given op and all ops transitively reachable
+ // input cast ops as alive.
+ auto markOpLive = [&](Operation *op) {
+ SmallVector<Operation *> worklist;
+ worklist.push_back(op);
+ while (!worklist.empty()) {
+ Operation *op = worklist.pop_back_val();
+ if (liveOps.insert(op).second) {
+ // Successfully inserted: the op is live. Add its operands to the
+ // worklist to mark them live.
+ for (Value v : op->getOperands())
+ if (castOpSet.contains(v.getDefiningOp()))
+ worklist.push_back(v.getDefiningOp());
+ }
+ }
+ };
+
+ // Find all alive cast ops.
+ for (UnrealizedConversionCastOp op : castOps) {
+ // If any of the users is not a cast op, mark the current op (and its
+ // input ops) as live.
+ if (llvm::any_of(op->getUsers(), [&](Operation *user) {
+ return !castOpSet.contains(user);
+ }))
+ markOpLive(op);
+ }
+
+ // Erase all dead cast ops.
+ for (UnrealizedConversionCastOp op : castOps) {
+ if (liveOps.contains(op)) {
+ // Op is alive and was not erased. Add it to the remaining cast ops.
+ if (remainingCastOps)
remainingCastOps->push_back(op);
+ continue;
+ }
+
+ // Op is dead. Erase it.
+ op->dropAllUses();
+ op->erase();
+ }
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
index 3573114f5e038..ac5ca321c066f 100644
--- a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
+++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
@@ -194,3 +194,53 @@ func.func @emptyCast() -> index {
%0 = builtin.unrealized_conversion_cast to index
return %0 : index
}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %0 = builtin.unrealized_conversion_cast %2 : i32 to i64
+ %1 = builtin.unrealized_conversion_cast %0 : i64 to i16
+ %2 = builtin.unrealized_conversion_cast %1 : i16 to i32
+ "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[cast2:.*]] : i32 to i64
+// CHECK-NEXT: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i64 to i16
+// CHECK-NEXT: %[[cast2]] = builtin.unrealized_conversion_cast %[[cast1]] : i16 to i32
+// CHECK-NEXT: "test.user"(%[[cast2]]) : (i32) -> ()
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %0 = builtin.unrealized_conversion_cast %2 : i32 to i64
+ %1 = builtin.unrealized_conversion_cast %0 : i64 to i16
+ %2 = builtin.unrealized_conversion_cast %1 : i16 to i32
+ "test.user"(%2) : (i32) -> ()
+ "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %0 = builtin.unrealized_conversion_cast %0 : i32 to i32
+ "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: %[[c0:.*]] = arith.constant
+// CHECK-NEXT: %[[cast:.*]]:2 = builtin.unrealized_conversion_cast %[[c0]], %[[cast]]#1 : i32, i32 to i32, i32
+// CHECK-NEXT: "test.user"(%[[cast]]#0) : (i32) -> ()
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %cst = arith.constant 0 : i32
+ %0, %1 = builtin.unrealized_conversion_cast %cst, %1 : i32, i32 to i32, i32
+ "test.user"(%0) : (i32) -> ()
+ "test.return"() : () -> ()
+}
diff --git a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
index 25a338df8d790..01a826a638606 100644
--- a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
@@ -1,7 +1,8 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -expand-strided-metadata \
// RUN: -test-cf-assert \
-// RUN: -convert-to-llvm | \
+// RUN: -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
index 4c6a48d577a6c..1144a7caf36e8 100644
--- a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -test-cf-assert \
-// RUN: -convert-to-llvm | \
+// RUN: -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
index dd000c6904bcb..82e63805cd027 100644
--- a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -test-cf-assert \
-// RUN: -convert-to-llvm | \
+// RUN: -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
``````````
</details>
https://github.com/llvm/llvm-project/pull/158067
More information about the Mlir-commits mailing list