[Mlir-commits] [mlir] [mlir][arith] Fix crash in IntRangeOptimizations due to stale solver state (PR #186187)
Mehdi Amini
llvmlistbot at llvm.org
Thu Mar 12 10:14:38 PDT 2026
https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/186187
When the IntRangeOptimizationsPass runs applyPatternsGreedily with constant folding enabled, constant folding can restructure blocks — for example by removing a block argument. The integer range solver, which pre-computed range information for the original block arguments, is now out of sync: subsequent range queries about the new (reused) argument positions return stale information computed for the old arguments, causing crashes.
Fix by passing enableConstantFolding(false) in GreedyRewriteConfig so that the solver's state remains consistent with the IR throughout the rewrite.
Fixes #122076
Assisted-by: Claude Code
From 792131322b2b47bb01345965b001586b420d6a7f Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Feb 2026 16:37:46 -0800
Subject: [PATCH] [mlir][arith] Fix crash in IntRangeOptimizations due to stale solver state
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
When the IntRangeOptimizationsPass runs applyPatternsGreedily with constant
folding enabled, constant folding can restructure blocks — for example by
removing a block argument. The integer range solver, which pre-computed range
information for the original block arguments, is now out of sync: subsequent
range queries about the new (reused) argument positions return stale
information computed for the old arguments, causing crashes.
Fix by passing enableConstantFolding(false) in GreedyRewriteConfig so that
the solver's state remains consistent with the IR throughout the rewrite.
Adds a regression test that reproduces the crash.
Fixes #122076
Assisted-by: Claude Code
---
.../Transforms/IntRangeOptimizations.cpp | 3 +-
mlir/test/Dialect/Arith/int-range-opts.mlir | 104 ++++++++++++++++++
2 files changed, 106 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index fefbba989b996..9a71650bcb31d 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -678,7 +678,8 @@ struct IntRangeOptimizationsPass final
if (failed(applyPatternsGreedily(
op, std::move(patterns),
- GreedyRewriteConfig().setListener(&listener))))
+ GreedyRewriteConfig().enableConstantFolding(false).setListener(
+ &listener))))
signalPassFailure();
}
};
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index e6e48d30cece5..543223a51599c 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -148,3 +148,107 @@ func.func @analysis_crash(%arg0: i32, %arg1: tensor<128xi1>) -> tensor<128xi64>
%2 = arith.extsi %1 : tensor<128xi32> to tensor<128xi64>
return %2 : tensor<128xi64>
}
+
+// -----
+
+// Make sure int-range-optimizations does not crash on IR where control-flow folding
+// is possible: such folding would leave the solver state out of sync with the IR.
+// CHECK-LABEL: @no_crash_llvm_ops_with_folding
+module {
+ llvm.mlir.global private constant @assert_msg_23(dense<0x0> : tensor<1xi8>) : !llvm.array<3 x i8> llvm.func @abort()
+ llvm.mlir.global private constant @assert_msg(dense<0x1> : tensor<1xi8>) : !llvm.array<3 x i8>
+ llvm.func @abort()
+ llvm.func @puts(!llvm.ptr)
+ llvm.func @malloc(i64) -> !llvm.ptr
+ func.func @no_crash_llvm_ops_with_folding(%arg0: memref<1x24x29xi32, strided<[?, ?, ?], offset: ?>>, %arg1: memref<1x24x29xi32, strided<[?, ?, ?], offset: ?>>, %arg2: memref<1x24x29xi32, strided<[?, ?, ?], offset: ?>>) -> memref<1x24x29xi32> {
+ %c29 = arith.constant 29 : index
+ %c1 = arith.constant 1 : index
+ %c24 = arith.constant 24 : index
+ %c0 = arith.constant 0 : index
+ %0 = llvm.mlir.constant(true) : i1
+ %1 = llvm.mlir.constant(64 : index) : i64
+ %6 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %14 = llvm.mlir.constant(0 : index) : i64
+ %20 = llvm.mlir.constant(24 : index) : i64
+ %21 = llvm.mlir.constant(29 : index) : i64
+ %22 = llvm.mlir.constant(1 : index) : i64
+ %23 = llvm.mlir.constant(696 : index) : i64
+ %24 = llvm.mlir.zero : !llvm.ptr
+ %96 = llvm.getelementptr %24[696] : (!llvm.ptr) -> !llvm.ptr, i32
+ %97 = llvm.ptrtoint %96 : !llvm.ptr to i64
+ %98 = llvm.add %97, %1 : i64
+ %99 = llvm.call @malloc(%98) : (i64) -> !llvm.ptr
+ %100 = llvm.ptrtoint %99 : !llvm.ptr to i64
+ %101 = llvm.sub %1, %22 : i64
+ %102 = llvm.add %100, %101 : i64
+ %103 = llvm.urem %102, %1 : i64
+ %104 = llvm.sub %102, %103 : i64
+ %105 = llvm.inttoptr %104 : i64 to !llvm.ptr
+ %106 = llvm.insertvalue %99, %6[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %107 = llvm.insertvalue %105, %106[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %108 = llvm.insertvalue %14, %107[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %109 = llvm.insertvalue %22, %108[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %110 = llvm.insertvalue %20, %109[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %111 = llvm.insertvalue %21, %110[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %112 = llvm.insertvalue %23, %111[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %113 = llvm.insertvalue %21, %112[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %114 = llvm.insertvalue %22, %113[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+ %119 = builtin.unrealized_conversion_cast %114 : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x24x29xi32>
+ %collapse_shape = memref.collapse_shape %119 [[0, 1], [2]] : memref<1x24x29xi32> into memref<24x29xi32>
+ %120 = builtin.unrealized_conversion_cast %collapse_shape : memref<24x29xi32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+ llvm.cond_br %0, ^bb72, ^bb81
+ ^bb72: // pred: ^bb71
+ %collapse_shape_3 = memref.collapse_shape %119 [[0, 1], [2]] : memref<1x24x29xi32> into memref<24x29xi32>
+ %211 = builtin.unrealized_conversion_cast %collapse_shape_3 : memref<24x29xi32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+ llvm.cond_br %0, ^bb80, ^bb105
+ ^bb80: // pred: ^bb79
+ %213 = scf.while (%arg3 = %c0) : (index) -> index {
+ %374 = arith.cmpi slt, %arg3, %c24 : index
+ scf.condition(%374) %arg3 : index
+ } do {
+ ^bb0(%arg3: index):
+ %374 = arith.addi %arg3, %c1 : index
+ %375 = builtin.unrealized_conversion_cast %arg3 : index to i64
+ %376 = scf.while (%arg4 = %c0) : (index) -> index {
+ %true = arith.constant true
+ scf.condition(%true) %arg4 : index
+ } do {
+ ^bb0(%arg4: index):
+ scf.yield %arg4 : index
+ }
+ %377 = scf.while (%arg4 = %c0) : (index) -> index {
+ %378 = arith.cmpi slt, %arg4, %c29 : index
+ scf.condition(%378) %arg4 : index
+ } do {
+ ^bb0(%arg4: index):
+ %378 = arith.addi %arg4, %c1 : index
+ %379 = builtin.unrealized_conversion_cast %arg4 : index to i64
+ %380 = llvm.mul %375, %21 : i64
+ %381 = llvm.add %380, %379 : i64
+ %382 = llvm.getelementptr %105[%381] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ %383 = llvm.load %382 : !llvm.ptr -> f32
+ %384 = llvm.intr.exp(%383) : (f32) -> f32
+ %385 = llvm.extractvalue %211[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+ %386 = llvm.mul %375, %21 : i64
+ %387 = llvm.add %386, %379 : i64
+ %388 = llvm.getelementptr %385[%387] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ llvm.store %384, %388 : f32, !llvm.ptr
+ scf.yield %378 : index
+ }
+ scf.yield %374 : index
+ }
+ return %119 : memref<1x24x29xi32>
+ ^bb81: // pred: ^bb0
+ %214 = llvm.mlir.addressof @assert_msg : !llvm.ptr
+ %215 = llvm.getelementptr %214[0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<558 x i8>
+ llvm.call @puts(%215) : (!llvm.ptr) -> ()
+ llvm.call @abort() : () -> ()
+ llvm.unreachable
+ ^bb105: // pred: ^bb24
+ %262 = llvm.mlir.addressof @assert_msg_23 : !llvm.ptr
+ %263 = llvm.getelementptr %262[0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<674 x i8>
+ llvm.call @puts(%263) : (!llvm.ptr) -> ()
+ llvm.call @abort() : () -> ()
+ llvm.unreachable
+ }
+}
\ No newline at end of file
More information about the Mlir-commits mailing list