[Mlir-commits] [mlir] [MLIR][Utils] Fix overflow in constantTripCount for narrow types (PR #179985)
Jhalak Patel
llvmlistbot at llvm.org
Thu Feb 5 08:58:31 PST 2026
https://github.com/jhalakpatel created https://github.com/llvm/llvm-project/pull/179985
Extend the operands before computing ub - lb so that the subtraction cannot overflow in signed arithmetic. E.g., for i8 with ub=127 and lb=-128, ub - lb is 255, which does not fit in a signed 8-bit integer unless the operands are extended first.
From 039ee00d7505a07fb926f5f1c11f1bcac0ac6277 Mon Sep 17 00:00:00 2001
From: Jhalak Patel <jhalakp at nvidia.com>
Date: Wed, 4 Feb 2026 22:25:45 -0800
Subject: [PATCH] [MLIR][Utils] Fix overflow in constantTripCount for narrow
types
Extend operands when computing ub - lb to avoid overflow in signed
arithmetic. E.g., i8: ub=127, lb=-128 yields 255, which overflows
without extension.
---
mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 28 +++++++++++++++++++--
mlir/test/Dialect/SCF/trip_count.mlir | 18 +++++++++++++
2 files changed, 44 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 1c19b995b9f3f..47b9161cf8fcc 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -365,6 +365,13 @@ std::optional<APInt> constantTripCount(
<< (isSigned ? "isSigned" : "isUnsigned") << ")";
return APInt(bitwidth, 0);
}
+ // Extend both operands to a wider bitwidth to avoid overflow when computing
+ // ub - lb. For example, with i8: ub=127, lb=-128, the difference is 255,
+ // which overflows in 8-bit signed arithmetic. We need at least one extra
+ // bit.
+ unsigned extendedWidth = bitwidth + 1;
+ lbCst = lbCst.extend(extendedWidth);
+ ubCst = ubCst.extend(extendedWidth);
diff = ubCst - lbCst;
} else {
if (maybeUbCst)
@@ -397,11 +404,28 @@ std::optional<APInt> constantTripCount(
return std::nullopt;
}
+ // Extend stepCst to match the bitwidth of diff if needed (e.g., when diff was
+ // extended to avoid overflow). Step is always positive here, so zero-extend.
+ llvm::APInt extendedStepCst = stepCst;
+ if (extendedStepCst.getBitWidth() < diff.getBitWidth()) {
+ extendedStepCst = extendedStepCst.zext(diff.getBitWidth());
+ }
+
// Create new APSInt instances with explicit signedness to ensure they match
- llvm::APInt tripCount = isSigned ? diff.sdiv(stepCst) : diff.udiv(stepCst);
- llvm::APInt remainder = isSigned ? diff.srem(stepCst) : diff.urem(stepCst);
+ llvm::APInt tripCount =
+ isSigned ? diff.sdiv(extendedStepCst) : diff.udiv(extendedStepCst);
+ llvm::APInt remainder =
+ isSigned ? diff.srem(extendedStepCst) : diff.urem(extendedStepCst);
if (!remainder.isZero())
tripCount = tripCount + 1;
+
+ // Truncate back to original bitwidth if we extended for overflow prevention.
+ // This is safe because ceil(diff/step) ≤ 2^bitwidth - 1, which always fits
+ // in bitwidth bits when interpreted as unsigned (trip counts are inherently
+ // non-negative regardless of loop comparison signedness).
+ if (tripCount.getBitWidth() > bitwidth)
+ tripCount = tripCount.trunc(bitwidth);
+
LDBG() << "constantTripCount found: " << tripCount;
return tripCount;
}
diff --git a/mlir/test/Dialect/SCF/trip_count.mlir b/mlir/test/Dialect/SCF/trip_count.mlir
index 927b405dbaea6..7e74988b35019 100644
--- a/mlir/test/Dialect/SCF/trip_count.mlir
+++ b/mlir/test/Dialect/SCF/trip_count.mlir
@@ -770,6 +770,24 @@ func.func @trip_count_i8_signed_crossing_zero(%a : i32, %b : i32) -> i32 {
// -----
+// CHECK-LABEL:func.func @trip_count_i8_signed_overflow_fix(
+func.func @trip_count_i8_signed_overflow_fix(%a : i32, %b : i32) -> i32 {
+ %c-128 = arith.constant -128 : i8
+ %c127 = arith.constant 127 : i8
+ %c1 = arith.constant 1 : i8
+
+ // Signed i8 from -128 to 127: tests overflow fix
+ // Without the fix, computing (127 - (-128)) would overflow in i8.
+ // The trip count should be 255, but will be printed as -1 in i8 signed format.
+ // CHECK: "test.trip-count" = -1 : i8
+ %r = scf.for %i = %c-128 to %c127 step %c1 iter_args(%0 = %a) -> i32 : i8 {
+ scf.yield %b : i32
+ }
+ return %r : i32
+}
+
+// -----
+
// CHECK-LABEL:func.func @trip_count_i16_unsigned_full_range(
func.func @trip_count_i16_unsigned_full_range(%a : i32, %b : i32) -> i32 {
%c0 = arith.constant 0 : i16
More information about the Mlir-commits
mailing list