[Mlir-commits] [mlir] [mlir][Arith] Avoid sign overflow when narrowing signed operations (PR #189676)

Artem Gindinson llvmlistbot at llvm.org
Tue Mar 31 08:54:03 PDT 2026


https://github.com/AGindinson updated https://github.com/llvm/llvm-project/pull/189676

>From c3b544f3e780f2445b87d7e880e926721298d26e Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Fri, 27 Mar 2026 11:15:09 +0000
Subject: [PATCH 1/2] [mlir][Arith] Avoid sign overflow when narrowing signed
 operations

Whether an arith operation can be truncated to a given bitwidth should also
depend on the sign semantics of the operation itself. Consider:
```
%input = /* upper bound > INT32_MAX, <= UINT32_MAX */ : index
%c0 = arith.constant 0 : index
%cmp = arith.cmpi sle, %input, %c0 : index
```

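Previously, `checkTruncatability()` would correctly judge that only an unsigned
truncation could be legal; however, the narrowing would still proceed even though
the `sle` predicate treats the MSB as the sign bit. For example, an input value of
2^31 is positive as an `index`, but truncates to INT32_MIN as an `i32`, flipping
the result of the comparison against zero.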

Ensure that signedness is respected for signed comparison predicates and for
signed elementwise operations by enforcing a `CastKind::Signed` restriction:
the narrowing patterns now bail out whenever the input ranges only allow an
unsigned truncation but the operation interprets its operands as signed.

**AI tooling usage disclaimer**
LIT tests were expanded from manual reproducer examples with LLM assistance.
Those additional test cases were verified to act as regression tests, and were
proofread and edited manually in accordance with the "Human in the loop" policy.
LLMs/generative tooling were not used for implementation/documentation purposes.

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
Co-authored-by: GPT 5.4 <codex at openai.com>
---
 .../Transforms/IntRangeOptimizations.cpp      | 28 +++++-
 .../Dialect/Arith/int-range-narrowing.mlir    | 97 +++++++++++++++++++
 2 files changed, 124 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 6acfc2c15af42..85578c22799c6 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -8,6 +8,8 @@
 
 #include <utility>
 
+#include "llvm/ADT/TypeSwitch.h"
+
 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
 #include "mlir/Analysis/DataFlow/Utils.h"
 #include "mlir/Analysis/DataFlowFramework.h"
@@ -356,6 +358,16 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
         if (castKind == CastKind::None)
           break;
       }
+      // For operations that explicitly treat the values as signed, we should
+      // only do signed casts, if those are deemed possible as such based on the
+      // value range.
+      auto castKindForOp =
+          llvm::TypeSwitch<Operation *, CastKind>(op)
+              .Case<arith::DivSIOp, arith::CeilDivSIOp, arith::FloorDivSIOp,
+                    arith::RemSIOp, arith::MaxSIOp, arith::MinSIOp,
+                    arith::ShRSIOp>([](auto) { return CastKind::Signed; })
+              .Default(CastKind::Both);
+      castKind = mergeCastKinds(castKind, castKindForOp);
       if (castKind == CastKind::None)
         continue;
       Type targetType = getTargetType(srcType, targetBitwidth);
@@ -414,12 +426,26 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
     const ConstantIntRanges &lhsRange = ranges[0];
     const ConstantIntRanges &rhsRange = ranges[1];
 
+    auto isSignedCmpPredicate = [](arith::CmpIPredicate pred) -> bool {
+      return pred == arith::CmpIPredicate::sge ||
+             pred == arith::CmpIPredicate::sgt ||
+             pred == arith::CmpIPredicate::sle ||
+             pred == arith::CmpIPredicate::slt;
+    };
+    // If we're to narrow the input values via a cast, we should preserve the
+    // sign.
+    CastKind predicateBasedCastRestriction =
+        isSignedCmpPredicate(op.getPredicate()) ? CastKind::Signed
+                                                : CastKind::Both;
+
     Type srcType = lhs.getType();
     for (unsigned targetBitwidth : targetBitwidths) {
       CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
       CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
       CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
-      // Note: this includes target width > src width.
+      castKind = mergeCastKinds(castKind, predicateBasedCastRestriction);
+      // Note: this includes target width > src width, as well as the unsigned
+      // truncatability & signed predicate scenario.
       if (castKind == CastKind::None)
         continue;
 
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index c3b0d280b1350..e64ca3b50f6e7 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -79,6 +79,52 @@ func.func @test_cmpi() -> i1 {
   return %2 : i1
 }
 
+// CHECK-LABEL: func @test_cmpi_si_pred_out_of_signed_bounds
+//       CHECK-NOT: arith.cmpi slt, {{.*}} : i32
+//       CHECK-NOT: arith.cmpi sgt, {{.*}} : i32
+//       CHECK-NOT: arith.cmpi sle, {{.*}} : i32
+//       CHECK-NOT: arith.cmpi sge, {{.*}} : i32
+//       CHECK:  %[[A:.*]] = test.with_bounds {smax = 4292870144 : index, smin = 0 : index, umax = 4292870144 : index, umin = 0 : index} : index
+//       CHECK:  %[[B:.*]] = test.with_bounds {smax = 4292870144 : index, smin = 0 : index, umax = 4292870144 : index, umin = 0 : index} : index
+//       CHECK:  %[[SLT:.*]] = arith.cmpi slt, %[[B]], %[[A]] : index
+//       CHECK:  %[[C:.*]] = test.with_bounds {smax = 2147483648 : index, smin = 0 : index, umax = 2147483648 : index, umin = 0 : index} : index
+//       CHECK:  %[[ZERO:.*]] = test.with_bounds {smax = 0 : index, smin = 0 : index, umax = 0 : index, umin = 0 : index} : index
+//       CHECK:  %[[SGT:.*]] = arith.cmpi sgt, %[[C]], %[[ZERO]] : index
+//       CHECK:  %[[SLE:.*]] = arith.cmpi sle, %[[A]], %[[ZERO]] : index
+//       CHECK:  %[[SGE:.*]] = arith.cmpi sge, %[[C]], %[[ZERO]] : index
+//       CHECK:  %[[AND0:.*]] = arith.andi %[[SLT]], %[[SGT]] : i1
+//       CHECK:  %[[AND1:.*]] = arith.andi %[[SLE]], %[[SGE]] : i1
+//       CHECK:  %[[AND2:.*]] = arith.andi %[[AND0]], %[[AND1]] : i1
+//       CHECK:  return %[[AND2]] : i1
+func.func @test_cmpi_si_pred_out_of_signed_bounds() -> i1 {
+  %0 = test.with_bounds { umin = 0 : index, umax = 4292870144 : index, smin = 0 : index, smax = 4292870144 : index } : index
+  %1 = test.with_bounds { umin = 0 : index, umax = 4292870144 : index, smin = 0 : index, smax = 4292870144 : index } : index
+  %2 = arith.cmpi slt, %1, %0 : index
+  %3 = test.with_bounds { umin = 0 : index, umax = 2147483648 : index, smin = 0 : index, smax = 2147483648 : index } : index
+  %4 = test.with_bounds { umin = 0 : index, umax = 0 : index, smin = 0 : index, smax = 0 : index } : index
+  %5 = arith.cmpi sgt, %3, %4 : index
+  %6 = arith.cmpi sle, %0, %4 : index
+  %7 = arith.cmpi sge, %3, %4 : index
+  %8 = arith.andi %2, %5 : i1
+  %9 = arith.andi %6, %7 : i1
+  %10 = arith.andi %8, %9 : i1
+  return %10 : i1
+}
+
+// CHECK-LABEL: func @test_cmpi_ui_pred_out_of_signed_bounds
+//       CHECK:  %[[A:.*]] = test.with_bounds {smax = 4292870144 : index, smin = 0 : index, umax = 4292870144 : index, umin = 0 : index} : index
+//       CHECK:  %[[B:.*]] = test.with_bounds {smax = 4292870144 : index, smin = 0 : index, umax = 4292870144 : index, umin = 0 : index} : index
+//       CHECK:  %[[A_I32:.*]] = arith.index_castui %[[A]] : index to i32
+//       CHECK:  %[[B_I32:.*]] = arith.index_castui %[[B]] : index to i32
+//       CHECK:  %[[RES:.*]] = arith.cmpi ult, %[[A_I32]], %[[B_I32]] : i32
+//       CHECK:  return %[[RES]] : i1
+func.func @test_cmpi_ui_pred_out_of_signed_bounds() -> i1 {
+  %0 = test.with_bounds { umin = 0 : index, umax = 4292870144 : index, smin = 0 : index, smax = 4292870144 : index } : index
+  %1 = test.with_bounds { umin = 0 : index, umax = 4292870144 : index, smin = 0 : index, smax = 4292870144 : index } : index
+  %2 = arith.cmpi ult, %0, %1 : index
+  return %2 : i1
+}
+
 // CHECK-LABEL: func @test_cmpi_vec
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex>
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex>
@@ -224,6 +270,57 @@ func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
   return %r : i32
 }
 
+// CHECK-LABEL: func.func @signed_ops_out_of_narrowed_signed_range
+// CHECK-NOT: arith.divsi {{.*}} : i32
+// CHECK-NOT: arith.ceildivsi {{.*}} : i32
+// CHECK-NOT: arith.floordivsi {{.*}} : i32
+// CHECK-NOT: arith.remsi {{.*}} : i32
+// CHECK-NOT: arith.maxsi {{.*}} : i32
+// CHECK-NOT: arith.minsi {{.*}} : i32
+// CHECK-NOT: arith.shrsi {{.*}} : i32
+// CHECK: %[[DIV_I64:.*]] = arith.divsi {{.*}} : i64
+// CHECK: %[[CEIL_I64:.*]] = arith.ceildivsi {{.*}} : i64
+// CHECK: %[[FLOOR_I64:.*]] = arith.floordivsi {{.*}} : i64
+// CHECK: %[[REM_I64:.*]] = arith.remsi {{.*}} : i64
+// CHECK: %[[MAX_I64:.*]] = arith.maxsi {{.*}} : i64
+// CHECK: %[[MIN_I64:.*]] = arith.minsi {{.*}} : i64
+// CHECK: %[[SHR_I64:.*]] = arith.shrsi {{.*}} : i64
+// CHECK: return %[[DIV_I64]], %[[CEIL_I64]], %[[FLOOR_I64]], %[[REM_I64]], %[[MAX_I64]], %[[MIN_I64]], %[[SHR_I64]] : i64, i64, i64, i64, i64, i64, i64
+func.func @signed_ops_out_of_narrowed_signed_range() -> (i64, i64, i64, i64, i64, i64, i64) {
+  %0 = test.with_bounds { umin = 0 : i64, umax = 4292870144 : i64, smin = 0 : i64, smax = 4292870144 : i64 } : i64
+  %1 = test.with_bounds { umin = 1 : i64, umax = 8 : i64, smin = 1 : i64, smax = 8 : i64 } : i64
+  %2 = test.with_bounds { umin = 0 : i64, umax = 0 : i64, smin = 0 : i64, smax = 0 : i64 } : i64
+  %3 = arith.divsi %0, %1 : i64
+  %4 = arith.ceildivsi %0, %1 : i64
+  %5 = arith.floordivsi %0, %1 : i64
+  %6 = arith.remsi %0, %1 : i64
+  %7 = arith.maxsi %0, %2 : i64
+  %8 = arith.minsi %0, %2 : i64
+  %9 = arith.shrsi %0, %1 : i64
+  return %3, %4, %5, %6, %7, %8, %9 : i64, i64, i64, i64, i64, i64, i64
+}
+
+// CHECK-LABEL: func.func @unsigned_ops_out_of_narrowed_signed_range
+// CHECK: arith.divui {{.*}} : i32
+// CHECK: arith.ceildivui {{.*}} : i32
+// CHECK: arith.remui {{.*}} : i32
+// CHECK: arith.maxui {{.*}} : i32
+// CHECK: arith.minui {{.*}} : i32
+// CHECK: arith.shrui {{.*}} : i32
+// CHECK: return %{{.*}} : i64, i64, i64, i64, i64, i64
+func.func @unsigned_ops_out_of_narrowed_signed_range() -> (i64, i64, i64, i64, i64, i64) {
+  %0 = test.with_bounds { umin = 0 : i64, umax = 4292870144 : i64, smin = 0 : i64, smax = 4292870144 : i64 } : i64
+  %1 = test.with_bounds { umin = 1 : i64, umax = 8 : i64, smin = 1 : i64, smax = 8 : i64 } : i64
+  %2 = test.with_bounds { umin = 0 : i64, umax = 0 : i64, smin = 0 : i64, smax = 0 : i64 } : i64
+  %3 = arith.divui %0, %1 : i64
+  %4 = arith.ceildivui %0, %1 : i64
+  %5 = arith.remui %0, %1 : i64
+  %6 = arith.maxui %0, %2 : i64
+  %7 = arith.minui %0, %2 : i64
+  %8 = arith.shrui %0, %1 : i64
+  return %3, %4, %5, %6, %7, %8 : i64, i64, i64, i64, i64, i64
+}
+
 //===----------------------------------------------------------------------===//
 // arith.muli
 //===----------------------------------------------------------------------===//

>From 51f93e2b0daf9542355ac1cc1275322528a24283 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Tue, 31 Mar 2026 15:53:31 +0000
Subject: [PATCH 2/2] [fixup] reduce the less important check line

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
 mlir/test/Dialect/Arith/int-range-narrowing.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index e64ca3b50f6e7..7ba22af0c0f1b 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -285,7 +285,7 @@ func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
 // CHECK: %[[MAX_I64:.*]] = arith.maxsi {{.*}} : i64
 // CHECK: %[[MIN_I64:.*]] = arith.minsi {{.*}} : i64
 // CHECK: %[[SHR_I64:.*]] = arith.shrsi {{.*}} : i64
-// CHECK: return %[[DIV_I64]], %[[CEIL_I64]], %[[FLOOR_I64]], %[[REM_I64]], %[[MAX_I64]], %[[MIN_I64]], %[[SHR_I64]] : i64, i64, i64, i64, i64, i64, i64
+// CHECK: return %{{.*}} : i64, i64, i64, i64, i64, i64, i64
 func.func @signed_ops_out_of_narrowed_signed_range() -> (i64, i64, i64, i64, i64, i64, i64) {
   %0 = test.with_bounds { umin = 0 : i64, umax = 4292870144 : i64, smin = 0 : i64, smax = 4292870144 : i64 } : i64
   %1 = test.with_bounds { umin = 1 : i64, umax = 8 : i64, smin = 1 : i64, smax = 8 : i64 } : i64



More information about the Mlir-commits mailing list