[Mlir-commits] [mlir] [mlir][arith] Fix crash in IntRangeOptimizations due to stale solver state (PR #186187)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 12 10:15:18 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-arith
Author: Mehdi Amini (joker-eph)
<details>
<summary>Changes</summary>
When the IntRangeOptimizationsPass runs applyPatternsGreedily with constant folding enabled, constant folding can restructure blocks — for example, by removing a block argument. The integer range solver, which precomputed range information for the original block arguments, then falls out of sync with the IR: subsequent range queries for the new (reused) argument positions return stale information computed for the old arguments, causing crashes.
Fix this by calling enableConstantFolding(false) on the GreedyRewriteConfig passed to applyPatternsGreedily, so that the solver's state remains consistent with the IR throughout the rewrite.
Fixes #122076
Assisted-by: Claude Code
---
Full diff: https://github.com/llvm/llvm-project/pull/186187.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (+2-1)
- (modified) mlir/test/Dialect/Arith/int-range-opts.mlir (+104)
``````````diff
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index fefbba989b996..9a71650bcb31d 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -678,7 +678,8 @@ struct IntRangeOptimizationsPass final
if (failed(applyPatternsGreedily(
op, std::move(patterns),
- GreedyRewriteConfig().setListener(&listener))))
+ GreedyRewriteConfig().enableConstantFolding(false).setListener(
+ &listener))))
signalPassFailure();
}
};
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index e6e48d30cece5..543223a51599c 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -148,3 +148,107 @@ func.func @analysis_crash(%arg0: i32, %arg1: tensor<128xi1>) -> tensor<128xi64>
%2 = arith.extsi %1 : tensor<128xi32> to tensor<128xi64>
return %2 : tensor<128xi64>
}
+
+// -----
+
+// Make sure int-range-optimizations does not crash when potential control-flow folding
+// is possible. Such folding does not keep the solver state up-to-date.
+// CHECK-LABEL: @no_crash_llvm_ops_with_folding
+module {
+ llvm.mlir.global private constant @assert_msg_23(dense<0x0> : tensor<1xi8>) : !llvm.array<3 x i8> llvm.func @abort()
+ llvm.mlir.global private constant @assert_msg(dense<0x1> : tensor<1xi8>) : !llvm.array<3 x i8>
+ llvm.func @abort()
+ llvm.func @puts(!llvm.ptr)
+ llvm.func @malloc(i64) -> !llvm.ptr
+ func.func @no_crash_llvm_ops_with_folding(%arg0: memref<1x24x29xi32, strided<[?, ?, ?], offset: ?>>, %arg1: memref<1x24x29xi32, strided<[?, ?, ?], offset: ?>>, %arg2: memref<1x24x29xi32, strided<[?, ?, ?], offset: ?>>) -> memref<1x24x29xi32> {
+ %c29 = arith.constant 29 : index
+ %c1 = arith.constant 1 : index
+ %c24 = arith.constant 24 : index
+ %c0 = arith.constant 0 : index
+ %0 = llvm.mlir.constant(true) : i1
+ %1 = llvm.mlir.constant(64 : index) : i64
+ %6 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %14 = llvm.mlir.constant(0 : index) : i64
+ %20 = llvm.mlir.constant(24 : index) : i64
+ %21 = llvm.mlir.constant(29 : index) : i64
+ %22 = llvm.mlir.constant(1 : index) : i64
+ %23 = llvm.mlir.constant(696 : index) : i64
+ %24 = llvm.mlir.zero : !llvm.ptr
+ %96 = llvm.getelementptr %24[696] : (!llvm.ptr) -> !llvm.ptr, i32
+ %97 = llvm.ptrtoint %96 : !llvm.ptr to i64
+ %98 = llvm.add %97, %1 : i64
+ %99 = llvm.call @malloc(%98) : (i64) -> !llvm.ptr
+ %100 = llvm.ptrtoint %99 : !llvm.ptr to i64
+ %101 = llvm.sub %1, %22 : i64
+ %102 = llvm.add %100, %101 : i64
+ %103 = llvm.urem %102, %1 : i64
+ %104 = llvm.sub %102, %103 : i64
+ %105 = llvm.inttoptr %104 : i64 to !llvm.ptr
+ %106 = llvm.insertvalue %99, %6[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %107 = llvm.insertvalue %105, %106[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %108 = llvm.insertvalue %14, %107[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %109 = llvm.insertvalue %22, %108[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %110 = llvm.insertvalue %20, %109[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %111 = llvm.insertvalue %21, %110[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %112 = llvm.insertvalue %23, %111[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %113 = llvm.insertvalue %21, %112[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %114 = llvm.insertvalue %22, %113[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %119 = builtin.unrealized_conversion_cast %114 : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x24x29xi32>
+ %collapse_shape = memref.collapse_shape %119 [[0, 1], [2]] : memref<1x24x29xi32> into memref<24x29xi32>
+ %120 = builtin.unrealized_conversion_cast %collapse_shape : memref<24x29xi32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+ llvm.cond_br %0, ^bb72, ^bb81
+ ^bb72: // pred: ^bb71
+ %collapse_shape_3 = memref.collapse_shape %119 [[0, 1], [2]] : memref<1x24x29xi32> into memref<24x29xi32>
+ %211 = builtin.unrealized_conversion_cast %collapse_shape_3 : memref<24x29xi32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+ llvm.cond_br %0, ^bb80, ^bb105
+ ^bb80: // pred: ^bb79
+ %213 = scf.while (%arg3 = %c0) : (index) -> index {
+ %374 = arith.cmpi slt, %arg3, %c24 : index
+ scf.condition(%374) %arg3 : index
+ } do {
+ ^bb0(%arg3: index):
+ %374 = arith.addi %arg3, %c1 : index
+ %375 = builtin.unrealized_conversion_cast %arg3 : index to i64
+ %376 = scf.while (%arg4 = %c0) : (index) -> index {
+ %true = arith.constant true
+ scf.condition(%true) %arg4 : index
+ } do {
+ ^bb0(%arg4: index):
+ scf.yield %arg4 : index
+ }
+ %377 = scf.while (%arg4 = %c0) : (index) -> index {
+ %378 = arith.cmpi slt, %arg4, %c29 : index
+ scf.condition(%378) %arg4 : index
+ } do {
+ ^bb0(%arg4: index):
+ %378 = arith.addi %arg4, %c1 : index
+ %379 = builtin.unrealized_conversion_cast %arg4 : index to i64
+ %380 = llvm.mul %375, %21 : i64
+ %381 = llvm.add %380, %379 : i64
+ %382 = llvm.getelementptr %105[%381] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ %383 = llvm.load %382 : !llvm.ptr -> f32
+ %384 = llvm.intr.exp(%383) : (f32) -> f32
+ %385 = llvm.extractvalue %211[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+ %386 = llvm.mul %375, %21 : i64
+ %387 = llvm.add %386, %379 : i64
+ %388 = llvm.getelementptr %385[%387] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ llvm.store %384, %388 : f32, !llvm.ptr
+ scf.yield %378 : index
+ }
+ scf.yield %374 : index
+ }
+ return %119 : memref<1x24x29xi32>
+ ^bb81: // pred: ^bb0
+ %214 = llvm.mlir.addressof @assert_msg : !llvm.ptr
+ %215 = llvm.getelementptr %214[0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<558 x i8>
+ llvm.call @puts(%215) : (!llvm.ptr) -> ()
+ llvm.call @abort() : () -> ()
+ llvm.unreachable
+ ^bb105: // pred: ^bb24
+ %262 = llvm.mlir.addressof @assert_msg_23 : !llvm.ptr
+ %263 = llvm.getelementptr %262[0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<674 x i8>
+ llvm.call @puts(%263) : (!llvm.ptr) -> ()
+ llvm.call @abort() : () -> ()
+ llvm.unreachable
+ }
+}
\ No newline at end of file
``````````
</details>
https://github.com/llvm/llvm-project/pull/186187
More information about the Mlir-commits
mailing list