[Mlir-commits] [mlir] [mlir][Arith] Avoid sign overflow when narrowing signed operations (PR #189676)
Artem Gindinson
llvmlistbot at llvm.org
Tue Mar 31 07:13:56 PDT 2026
https://github.com/AGindinson created https://github.com/llvm/llvm-project/pull/189676
Whether an arith operation can be truncated to a given bitwidth should also depend on the sign semantics of the operation itself. Consider:
```
%input = /* upper bound > INT32_MAX, <= UINT32_MAX */ : index
%c0 = arith.constant 0 : index
%cmp = arith.cmpi sle, %input, %c0 : index
```
Previously, `checkTruncatability()` would correctly judge that only an unsigned truncation could be legal; however, the narrowing would still proceed even though the `sle` predicate treats the MSB as the sign bit.
Ensure that the sign is preserved for signed comparison predicates and for signed elementwise operations by imposing a `CastKind::Signed` restriction: the narrowing patterns now bail out when the signedness of the input range is incompatible with that of the operation.
**AI tooling usage disclaimer**
LIT tests were expanded from manual reproducer examples with LLM assistance. The additional test cases were verified to catch the regression, and were proofread and edited manually in accordance with the "Human in the loop" policy. LLMs/generative tooling were not used for implementation or documentation purposes.
>From c3b544f3e780f2445b87d7e880e926721298d26e Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Fri, 27 Mar 2026 11:15:09 +0000
Subject: [PATCH] [mlir][Arith] Avoid sign overflow when narrowing signed
operations
Whether an arith operation can be truncated to a given bitwidth should also
depend on the sign semantics of the operation itself. Consider:
```
%input = /* upper bound > INT32_MAX, <= UINT32_MAX */ : index
%c0 = arith.constant 0 : index
%cmp = arith.cmpi sle, %input, %c0 : index
```
Previously, `checkTruncatability()` would correctly judge that only an unsigned
truncation could be legal; however, the narrowing would still proceed even
though the `sle` predicate treats the MSB as the sign bit.
Ensure that the sign is preserved for signed comparison predicates and for
signed elementwise operations by imposing a `CastKind::Signed` restriction:
the narrowing patterns now bail out when the signedness of the input range is
incompatible with that of the operation.
**AI tooling usage disclaimer**
LIT tests were expanded from manual reproducer examples with LLM assistance.
The additional test cases were verified to catch the regression, and were
proofread and edited manually in accordance with the "Human in the loop"
policy. LLMs/generative tooling were not used for implementation or
documentation purposes.
Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
Co-authored-by: GPT 5.4 <codex at openai.com>
---
.../Transforms/IntRangeOptimizations.cpp | 28 +++++-
.../Dialect/Arith/int-range-narrowing.mlir | 97 +++++++++++++++++++
2 files changed, 124 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 6acfc2c15af42..85578c22799c6 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -8,6 +8,8 @@
#include <utility>
+#include "llvm/ADT/TypeSwitch.h"
+
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/Utils.h"
#include "mlir/Analysis/DataFlowFramework.h"
@@ -356,6 +358,16 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
if (castKind == CastKind::None)
break;
}
+ // For operations that explicitly treat the values as signed, we should
+ // only do signed casts, if those are deemed possible as such based on the
+ // value range.
+ auto castKindForOp =
+ llvm::TypeSwitch<Operation *, CastKind>(op)
+ .Case<arith::DivSIOp, arith::CeilDivSIOp, arith::FloorDivSIOp,
+ arith::RemSIOp, arith::MaxSIOp, arith::MinSIOp,
+ arith::ShRSIOp>([](auto) { return CastKind::Signed; })
+ .Default(CastKind::Both);
+ castKind = mergeCastKinds(castKind, castKindForOp);
if (castKind == CastKind::None)
continue;
Type targetType = getTargetType(srcType, targetBitwidth);
@@ -414,12 +426,26 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
const ConstantIntRanges &lhsRange = ranges[0];
const ConstantIntRanges &rhsRange = ranges[1];
+ auto isSignedCmpPredicate = [](arith::CmpIPredicate pred) -> bool {
+ return pred == arith::CmpIPredicate::sge ||
+ pred == arith::CmpIPredicate::sgt ||
+ pred == arith::CmpIPredicate::sle ||
+ pred == arith::CmpIPredicate::slt;
+ };
+ // If we're to narrow the input values via a cast, we should preserve the
+ // sign.
+ CastKind predicateBasedCastRestriction =
+ isSignedCmpPredicate(op.getPredicate()) ? CastKind::Signed
+ : CastKind::Both;
+
Type srcType = lhs.getType();
for (unsigned targetBitwidth : targetBitwidths) {
CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
- // Note: this includes target width > src width.
+ castKind = mergeCastKinds(castKind, predicateBasedCastRestriction);
+ // Note: this includes target width > src width, as well as the unsigned
+ // truncatability & signed predicate scenario.
if (castKind == CastKind::None)
continue;
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index c3b0d280b1350..e64ca3b50f6e7 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -79,6 +79,52 @@ func.func @test_cmpi() -> i1 {
return %2 : i1
}
+// CHECK-LABEL: func @test_cmpi_si_pred_out_of_signed_bounds
+// CHECK-NOT: arith.cmpi slt, {{.*}} : i32
+// CHECK-NOT: arith.cmpi sgt, {{.*}} : i32
+// CHECK-NOT: arith.cmpi sle, {{.*}} : i32
+// CHECK-NOT: arith.cmpi sge, {{.*}} : i32
+// CHECK: %[[A:.*]] = test.with_bounds {smax = 4292870144 : index, smin = 0 : index, umax = 4292870144 : index, umin = 0 : index} : index
+// CHECK: %[[B:.*]] = test.with_bounds {smax = 4292870144 : index, smin = 0 : index, umax = 4292870144 : index, umin = 0 : index} : index
+// CHECK: %[[SLT:.*]] = arith.cmpi slt, %[[B]], %[[A]] : index
+// CHECK: %[[C:.*]] = test.with_bounds {smax = 2147483648 : index, smin = 0 : index, umax = 2147483648 : index, umin = 0 : index} : index
+// CHECK: %[[ZERO:.*]] = test.with_bounds {smax = 0 : index, smin = 0 : index, umax = 0 : index, umin = 0 : index} : index
+// CHECK: %[[SGT:.*]] = arith.cmpi sgt, %[[C]], %[[ZERO]] : index
+// CHECK: %[[SLE:.*]] = arith.cmpi sle, %[[A]], %[[ZERO]] : index
+// CHECK: %[[SGE:.*]] = arith.cmpi sge, %[[C]], %[[ZERO]] : index
+// CHECK: %[[AND0:.*]] = arith.andi %[[SLT]], %[[SGT]] : i1
+// CHECK: %[[AND1:.*]] = arith.andi %[[SLE]], %[[SGE]] : i1
+// CHECK: %[[AND2:.*]] = arith.andi %[[AND0]], %[[AND1]] : i1
+// CHECK: return %[[AND2]] : i1
+func.func @test_cmpi_si_pred_out_of_signed_bounds() -> i1 {
+ %0 = test.with_bounds { umin = 0 : index, umax = 4292870144 : index, smin = 0 : index, smax = 4292870144 : index } : index
+ %1 = test.with_bounds { umin = 0 : index, umax = 4292870144 : index, smin = 0 : index, smax = 4292870144 : index } : index
+ %2 = arith.cmpi slt, %1, %0 : index
+ %3 = test.with_bounds { umin = 0 : index, umax = 2147483648 : index, smin = 0 : index, smax = 2147483648 : index } : index
+ %4 = test.with_bounds { umin = 0 : index, umax = 0 : index, smin = 0 : index, smax = 0 : index } : index
+ %5 = arith.cmpi sgt, %3, %4 : index
+ %6 = arith.cmpi sle, %0, %4 : index
+ %7 = arith.cmpi sge, %3, %4 : index
+ %8 = arith.andi %2, %5 : i1
+ %9 = arith.andi %6, %7 : i1
+ %10 = arith.andi %8, %9 : i1
+ return %10 : i1
+}
+
+// CHECK-LABEL: func @test_cmpi_ui_pred_out_of_signed_bounds
+// CHECK: %[[A:.*]] = test.with_bounds {smax = 4292870144 : index, smin = 0 : index, umax = 4292870144 : index, umin = 0 : index} : index
+// CHECK: %[[B:.*]] = test.with_bounds {smax = 4292870144 : index, smin = 0 : index, umax = 4292870144 : index, umin = 0 : index} : index
+// CHECK: %[[A_I32:.*]] = arith.index_castui %[[A]] : index to i32
+// CHECK: %[[B_I32:.*]] = arith.index_castui %[[B]] : index to i32
+// CHECK: %[[RES:.*]] = arith.cmpi ult, %[[A_I32]], %[[B_I32]] : i32
+// CHECK: return %[[RES]] : i1
+func.func @test_cmpi_ui_pred_out_of_signed_bounds() -> i1 {
+ %0 = test.with_bounds { umin = 0 : index, umax = 4292870144 : index, smin = 0 : index, smax = 4292870144 : index } : index
+ %1 = test.with_bounds { umin = 0 : index, umax = 4292870144 : index, smin = 0 : index, smax = 4292870144 : index } : index
+ %2 = arith.cmpi ult, %0, %1 : index
+ return %2 : i1
+}
+
// CHECK-LABEL: func @test_cmpi_vec
// CHECK: %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex>
// CHECK: %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex>
@@ -224,6 +270,57 @@ func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
return %r : i32
}
+// CHECK-LABEL: func.func @signed_ops_out_of_narrowed_signed_range
+// CHECK-NOT: arith.divsi {{.*}} : i32
+// CHECK-NOT: arith.ceildivsi {{.*}} : i32
+// CHECK-NOT: arith.floordivsi {{.*}} : i32
+// CHECK-NOT: arith.remsi {{.*}} : i32
+// CHECK-NOT: arith.maxsi {{.*}} : i32
+// CHECK-NOT: arith.minsi {{.*}} : i32
+// CHECK-NOT: arith.shrsi {{.*}} : i32
+// CHECK: %[[DIV_I64:.*]] = arith.divsi {{.*}} : i64
+// CHECK: %[[CEIL_I64:.*]] = arith.ceildivsi {{.*}} : i64
+// CHECK: %[[FLOOR_I64:.*]] = arith.floordivsi {{.*}} : i64
+// CHECK: %[[REM_I64:.*]] = arith.remsi {{.*}} : i64
+// CHECK: %[[MAX_I64:.*]] = arith.maxsi {{.*}} : i64
+// CHECK: %[[MIN_I64:.*]] = arith.minsi {{.*}} : i64
+// CHECK: %[[SHR_I64:.*]] = arith.shrsi {{.*}} : i64
+// CHECK: return %[[DIV_I64]], %[[CEIL_I64]], %[[FLOOR_I64]], %[[REM_I64]], %[[MAX_I64]], %[[MIN_I64]], %[[SHR_I64]] : i64, i64, i64, i64, i64, i64, i64
+func.func @signed_ops_out_of_narrowed_signed_range() -> (i64, i64, i64, i64, i64, i64, i64) {
+ %0 = test.with_bounds { umin = 0 : i64, umax = 4292870144 : i64, smin = 0 : i64, smax = 4292870144 : i64 } : i64
+ %1 = test.with_bounds { umin = 1 : i64, umax = 8 : i64, smin = 1 : i64, smax = 8 : i64 } : i64
+ %2 = test.with_bounds { umin = 0 : i64, umax = 0 : i64, smin = 0 : i64, smax = 0 : i64 } : i64
+ %3 = arith.divsi %0, %1 : i64
+ %4 = arith.ceildivsi %0, %1 : i64
+ %5 = arith.floordivsi %0, %1 : i64
+ %6 = arith.remsi %0, %1 : i64
+ %7 = arith.maxsi %0, %2 : i64
+ %8 = arith.minsi %0, %2 : i64
+ %9 = arith.shrsi %0, %1 : i64
+ return %3, %4, %5, %6, %7, %8, %9 : i64, i64, i64, i64, i64, i64, i64
+}
+
+// CHECK-LABEL: func.func @unsigned_ops_out_of_narrowed_signed_range
+// CHECK: arith.divui {{.*}} : i32
+// CHECK: arith.ceildivui {{.*}} : i32
+// CHECK: arith.remui {{.*}} : i32
+// CHECK: arith.maxui {{.*}} : i32
+// CHECK: arith.minui {{.*}} : i32
+// CHECK: arith.shrui {{.*}} : i32
+// CHECK: return %{{.*}} : i64, i64, i64, i64, i64, i64
+func.func @unsigned_ops_out_of_narrowed_signed_range() -> (i64, i64, i64, i64, i64, i64) {
+ %0 = test.with_bounds { umin = 0 : i64, umax = 4292870144 : i64, smin = 0 : i64, smax = 4292870144 : i64 } : i64
+ %1 = test.with_bounds { umin = 1 : i64, umax = 8 : i64, smin = 1 : i64, smax = 8 : i64 } : i64
+ %2 = test.with_bounds { umin = 0 : i64, umax = 0 : i64, smin = 0 : i64, smax = 0 : i64 } : i64
+ %3 = arith.divui %0, %1 : i64
+ %4 = arith.ceildivui %0, %1 : i64
+ %5 = arith.remui %0, %1 : i64
+ %6 = arith.maxui %0, %2 : i64
+ %7 = arith.minui %0, %2 : i64
+ %8 = arith.shrui %0, %1 : i64
+ return %3, %4, %5, %6, %7, %8 : i64, i64, i64, i64, i64, i64
+}
+
//===----------------------------------------------------------------------===//
// arith.muli
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list