[Mlir-commits] [mlir] [mlir][arith] Fix crash in IntRangeOptimizations due to stale solver state (PR #186187)

Mehdi Amini llvmlistbot at llvm.org
Thu Mar 12 10:14:38 PDT 2026


https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/186187

When the IntRangeOptimizationsPass runs applyPatternsGreedily with constant folding enabled, the driver's folding can restructure blocks — for example, by removing a block argument. The integer range solver, which pre-computed range information for the original block arguments, then falls out of sync with the IR: later range queries for the arguments that now occupy the reused positions return stale information computed for the old arguments, causing crashes.

Fix by passing enableConstantFolding(false) in GreedyRewriteConfig so that the solver's state remains consistent with the IR throughout the rewrite.

Fixes #122076

Assisted-by: Claude Code

>From 792131322b2b47bb01345965b001586b420d6a7f Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Feb 2026 16:37:46 -0800
Subject: [PATCH] [mlir][arith] Fix crash in IntRangeOptimizations due to stale
 solver state
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

When the IntRangeOptimizationsPass runs applyPatternsGreedily with constant
folding enabled, the driver's folding can restructure blocks — for example,
by removing a block argument. The integer range solver, which pre-computed
range information for the original block arguments, then falls out of sync
with the IR: later range queries for the arguments that now occupy the reused
positions return stale information computed for the old arguments, causing
crashes.

Fix by passing enableConstantFolding(false) in GreedyRewriteConfig so that
the solver's state remains consistent with the IR throughout the rewrite.

Adds a regression test that reproduces the crash.

Fixes #122076

Assisted-by: Claude Code
---
 .../Transforms/IntRangeOptimizations.cpp      |   3 +-
 mlir/test/Dialect/Arith/int-range-opts.mlir   | 104 ++++++++++++++++++
 2 files changed, 106 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index fefbba989b996..9a71650bcb31d 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -678,7 +678,8 @@ struct IntRangeOptimizationsPass final
 
     if (failed(applyPatternsGreedily(
             op, std::move(patterns),
-            GreedyRewriteConfig().setListener(&listener))))
+            GreedyRewriteConfig().enableConstantFolding(false).setListener(
+                &listener))))
       signalPassFailure();
   }
 };
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index e6e48d30cece5..543223a51599c 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -148,3 +148,107 @@ func.func @analysis_crash(%arg0: i32, %arg1: tensor<128xi1>) -> tensor<128xi64>
   %2 = arith.extsi %1 : tensor<128xi32> to tensor<128xi64>
   return %2 : tensor<128xi64>
 }
+
+// -----
+
+// Make sure int-range-optimizations does not crash when constant folding could
+// restructure control flow: such folding does not keep the solver state up-to-date.
+// CHECK-LABEL: @no_crash_llvm_ops_with_folding
+module {
+  llvm.mlir.global private constant @assert_msg_23(dense<0x0> : tensor<1xi8>) : !llvm.array<3 x i8>
+  llvm.mlir.global private constant @assert_msg(dense<0x1> : tensor<1xi8>) : !llvm.array<3 x i8>
+  llvm.func @abort()
+  llvm.func @puts(!llvm.ptr)
+  llvm.func @malloc(i64) -> !llvm.ptr
+  func.func @no_crash_llvm_ops_with_folding(%arg0: memref<1x24x29xi32, strided<[?, ?, ?], offset: ?>>, %arg1: memref<1x24x29xi32, strided<[?, ?, ?], offset: ?>>, %arg2: memref<1x24x29xi32, strided<[?, ?, ?], offset: ?>>) -> memref<1x24x29xi32> {
+    %c29 = arith.constant 29 : index
+    %c1 = arith.constant 1 : index
+    %c24 = arith.constant 24 : index
+    %c0 = arith.constant 0 : index
+    %0 = llvm.mlir.constant(true) : i1
+    %1 = llvm.mlir.constant(64 : index) : i64
+    %6 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+    %14 = llvm.mlir.constant(0 : index) : i64
+    %20 = llvm.mlir.constant(24 : index) : i64
+    %21 = llvm.mlir.constant(29 : index) : i64
+    %22 = llvm.mlir.constant(1 : index) : i64
+    %23 = llvm.mlir.constant(696 : index) : i64
+    %24 = llvm.mlir.zero : !llvm.ptr
+    %96 = llvm.getelementptr %24[696] : (!llvm.ptr) -> !llvm.ptr, i32
+    %97 = llvm.ptrtoint %96 : !llvm.ptr to i64
+    %98 = llvm.add %97, %1 : i64
+    %99 = llvm.call @malloc(%98) : (i64) -> !llvm.ptr
+    %100 = llvm.ptrtoint %99 : !llvm.ptr to i64
+    %101 = llvm.sub %1, %22 : i64
+    %102 = llvm.add %100, %101 : i64
+    %103 = llvm.urem %102, %1 : i64
+    %104 = llvm.sub %102, %103 : i64
+    %105 = llvm.inttoptr %104 : i64 to !llvm.ptr
+    %106 = llvm.insertvalue %99, %6[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+    %107 = llvm.insertvalue %105, %106[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+    %108 = llvm.insertvalue %14, %107[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+    %109 = llvm.insertvalue %22, %108[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+    %110 = llvm.insertvalue %20, %109[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+    %111 = llvm.insertvalue %21, %110[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+    %112 = llvm.insertvalue %23, %111[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+    %113 = llvm.insertvalue %21, %112[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+    %114 = llvm.insertvalue %22, %113[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+    %119 = builtin.unrealized_conversion_cast %114 : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x24x29xi32>
+    %collapse_shape = memref.collapse_shape %119 [[0, 1], [2]] : memref<1x24x29xi32> into memref<24x29xi32>
+    %120 = builtin.unrealized_conversion_cast %collapse_shape : memref<24x29xi32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    llvm.cond_br %0, ^bb72, ^bb81
+  ^bb72:  // pred: ^bb0
+    %collapse_shape_3 = memref.collapse_shape %119 [[0, 1], [2]] : memref<1x24x29xi32> into memref<24x29xi32>
+    %211 = builtin.unrealized_conversion_cast %collapse_shape_3 : memref<24x29xi32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    llvm.cond_br %0, ^bb80, ^bb105
+  ^bb80:  // pred: ^bb72
+    %213 = scf.while (%arg3 = %c0) : (index) -> index {
+      %374 = arith.cmpi slt, %arg3, %c24 : index
+      scf.condition(%374) %arg3 : index
+    } do {
+    ^bb0(%arg3: index):
+      %374 = arith.addi %arg3, %c1 : index
+      %375 = builtin.unrealized_conversion_cast %arg3 : index to i64
+      %376 = scf.while (%arg4 = %c0) : (index) -> index {
+        %true = arith.constant true
+        scf.condition(%true) %arg4 : index
+      } do {
+      ^bb0(%arg4: index):
+        scf.yield %arg4 : index
+      }
+      %377 = scf.while (%arg4 = %c0) : (index) -> index {
+        %378 = arith.cmpi slt, %arg4, %c29 : index
+        scf.condition(%378) %arg4 : index
+      } do {
+      ^bb0(%arg4: index):
+        %378 = arith.addi %arg4, %c1 : index
+        %379 = builtin.unrealized_conversion_cast %arg4 : index to i64
+        %380 = llvm.mul %375, %21 : i64
+        %381 = llvm.add %380, %379 : i64
+        %382 = llvm.getelementptr %105[%381] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+        %383 = llvm.load %382 : !llvm.ptr -> f32
+        %384 = llvm.intr.exp(%383) : (f32) -> f32
+        %385 = llvm.extractvalue %211[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+        %386 = llvm.mul %375, %21 : i64
+        %387 = llvm.add %386, %379 : i64
+        %388 = llvm.getelementptr %385[%387] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+        llvm.store %384, %388 : f32, !llvm.ptr
+        scf.yield %378 : index
+      }
+      scf.yield %374 : index
+    }
+    return %119 : memref<1x24x29xi32>
+  ^bb81:  // pred: ^bb0
+    %214 = llvm.mlir.addressof @assert_msg : !llvm.ptr
+    %215 = llvm.getelementptr %214[0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<558 x i8>
+    llvm.call @puts(%215) : (!llvm.ptr) -> ()
+    llvm.call @abort() : () -> ()
+    llvm.unreachable
+  ^bb105:  // pred: ^bb72
+    %262 = llvm.mlir.addressof @assert_msg_23 : !llvm.ptr
+    %263 = llvm.getelementptr %262[0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<674 x i8>
+    llvm.call @puts(%263) : (!llvm.ptr) -> ()
+    llvm.call @abort() : () -> ()
+    llvm.unreachable
+  }
+}
\ No newline at end of file



More information about the Mlir-commits mailing list