[Mlir-commits] [mlir] [mlir][scf] Fix unsigned narrow type trip count calculation (PR #178060)
Jhalak Patel
llvmlistbot at llvm.org
Mon Jan 26 13:50:12 PST 2026
https://github.com/jhalakpatel created https://github.com/llvm/llvm-project/pull/178060
Fix an issue where SCF loop trip counts computed for unsigned narrow integer types (i8, i16, etc.) were misinterpreted because of sign extension. For such loops, a trip count like 255 (0xFF in 8 bits) was read back as -1 by callers using getSExtValue().
Changes:
- Use APSInt with explicit signedness for step validation and arithmetic
- Zero-extend unsigned narrow type results to 64 bits to prevent misinterpretation by callers using getSExtValue()
- Add tests for unsigned i8 and i16 loops whose bounds or trip counts have the sign bit set, plus a non-unit step and a signed i32 zero-crossing case
This ensures that an unsigned i8 loop from 0 to 255 (`scf.for unsigned ... : i8`) correctly reports a trip count of 255 instead of -1.
>From 6e835868c4f80b7b53cbe0ac76b0abb4439e2c88 Mon Sep 17 00:00:00 2001
From: Jhalak Patel <jhalakp at nvidia.com>
Date: Mon, 26 Jan 2026 13:46:13 -0800
Subject: [PATCH] [mlir][scf] Fix unsigned narrow type trip count calculation
Fix an issue where SCF loop trip counts computed for unsigned narrow
integer types (i8, i16, etc.) were misinterpreted because of sign
extension. For such loops, a trip count like 255 (0xFF in 8 bits) was
read back as -1 by callers using getSExtValue().
Changes:
- Use APSInt with explicit signedness for step validation and arithmetic
- Zero-extend unsigned narrow type results to 64 bits to prevent
misinterpretation by callers using getSExtValue()
- Add tests for unsigned i8 and i16 loops whose bounds or trip counts
have the sign bit set, plus a non-unit step and a signed i32
zero-crossing case
This ensures that an unsigned i8 loop from 0 to 255
(`scf.for unsigned ... : i8`) correctly reports a trip count of 255
instead of -1.
---
.../mlir/Dialect/Utils/StaticValueUtils.h | 3 +
mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 21 +++-
mlir/test/Dialect/SCF/trip_count.mlir | 106 ++++++++++++++++++
3 files changed, 126 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index ba8a0304de9d3..c8ebce1884466 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -228,6 +228,9 @@ LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
/// where %ub is computed as a static offset from %lb.
/// Note: the matched addition should be nsw/nuw (matching the loop comparison)
/// to avoid overflow, otherwise an overflow would imply a zero trip count.
+///
+/// For unsigned types narrower than 64 bits, the result is a 64-bit APInt
+/// (zero-extended), so callers using getSExtValue() never read it as negative.
std::optional<APInt> constantTripCount(
OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned,
llvm::function_ref<std::optional<llvm::APSInt>(Value, Value, bool)>
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 7fb0d4e9710f8..74431c8cc62ea 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -336,7 +336,10 @@ std::optional<APInt> constantTripCount(
// case applies, so the static trip count is unknown.
return std::nullopt;
}
- if (stepCst.isNegative())
+ // For unsigned values, negative step is impossible; for signed, check the
+ // sign bit properly using APSInt.
+ APSInt stepSInt(stepCst, /*isUnsigned=*/!isSigned);
+ if (stepSInt.isNegative())
return APInt(bitwidth, 0);
}
@@ -391,12 +394,22 @@ std::optional<APInt> constantTripCount(
return std::nullopt;
}
auto &stepCst = maybeStepCst->first;
- llvm::APInt tripCount = isSigned ? diff.sdiv(stepCst) : diff.udiv(stepCst);
- llvm::APInt remainder = isSigned ? diff.srem(stepCst) : diff.urem(stepCst);
+ // Wrap with explicit signedness so / and % select sdiv/srem or udiv/urem.
+ llvm::APSInt diffSigned(diff, /*isUnsigned=*/!isSigned);
+ llvm::APSInt stepSInt(stepCst, /*isUnsigned=*/!isSigned);
+ llvm::APSInt tripCount = diffSigned / stepSInt;
+ llvm::APSInt remainder = diffSigned % stepSInt;
if (!remainder.isZero())
tripCount = tripCount + 1;
LDBG() << "constantTripCount found: " << tripCount;
- return tripCount;
+
+ // For unsigned narrow types, zero-extend to 64 bits to avoid misinterpretation
+ // when callers use getSExtValue(). This ensures values like 255 (0xFF in 8-bit)
+ // are treated as positive when sign-extended.
+ APInt result = tripCount;
+ if (!isSigned && bitwidth < 64)
+ result = result.zext(64);
+ return result;
}
bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) {
diff --git a/mlir/test/Dialect/SCF/trip_count.mlir b/mlir/test/Dialect/SCF/trip_count.mlir
index 54883d7bb874c..c042723376a6c 100644
--- a/mlir/test/Dialect/SCF/trip_count.mlir
+++ b/mlir/test/Dialect/SCF/trip_count.mlir
@@ -699,4 +699,110 @@ func.func @trip_count_arith_add_nuw_loop_unsigned_invalid(%lb : i32, %other : i3
scf.yield %arg0 : i32
}
return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_i8_unsigned_full_range(
+func.func @trip_count_i8_unsigned_full_range(%a : i32, %b : i32) -> i32 {
+ %c0 = arith.constant 0 : i8
+ %c255 = arith.constant 255 : i8
+ %c1 = arith.constant 1 : i8
+
+ // Unsigned i8 loop from 0 to 255: the trip count 255 (0xFF) does not fit
+ // in signed i8, so it must not be sign-extended to -1 when reported.
+ // Narrow unsigned trip counts are zero-extended to 64 bits.
+ // Expected: 255 as i64
+ // CHECK: "test.trip-count" = 255 : i64
+ %r = scf.for unsigned %i = %c0 to %c255 step %c1 iter_args(%0 = %a) -> i32 : i8 {
+ scf.yield %b : i32
+ }
+ return %r : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_i8_unsigned_partial_range(
+func.func @trip_count_i8_unsigned_partial_range(%a : i32, %b : i32) -> i32 {
+ %c0 = arith.constant 0 : i8
+ %c200 = arith.constant 200 : i8
+ %c1 = arith.constant 1 : i8
+
+ // Unsigned i8 loop from 0 to 200: 200 exceeds the signed i8 maximum (127),
+ // so the count is only correct if it is treated as unsigned.
+ // Expected: 200 as i64
+ // CHECK: "test.trip-count" = 200 : i64
+ %r = scf.for unsigned %i = %c0 to %c200 step %c1 iter_args(%0 = %a) -> i32 : i8 {
+ scf.yield %b : i32
+ }
+ return %r : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_i8_unsigned_high_range(
+func.func @trip_count_i8_unsigned_high_range(%a : i32, %b : i32) -> i32 {
+ %c200 = arith.constant 200 : i8
+ %c255 = arith.constant 255 : i8
+ %c1 = arith.constant 1 : i8
+
+ // Unsigned i8 loop from 200 to 255: both bounds have the sign bit set.
+ // Expected: 55 as i64
+ // CHECK: "test.trip-count" = 55 : i64
+ %r = scf.for unsigned %i = %c200 to %c255 step %c1 iter_args(%0 = %a) -> i32 : i8 {
+ scf.yield %b : i32
+ }
+ return %r : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_i32_signed_crossing_zero(
+func.func @trip_count_i32_signed_crossing_zero(%a : i32, %b : i32) -> i32 {
+ %c-128 = arith.constant -128 : i32
+ %c127 = arith.constant 127 : i32
+ %c1 = arith.constant 1 : i32
+
+ // Signed i32 loop from -128 to 127, crossing zero: 255 iterations.
+ // CHECK: "test.trip-count" = 255
+ %r = scf.for %i = %c-128 to %c127 step %c1 iter_args(%0 = %a) -> i32 : i32 {
+ scf.yield %b : i32
+ }
+ return %r : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_i16_unsigned_full_range(
+func.func @trip_count_i16_unsigned_full_range(%a : i32, %b : i32) -> i32 {
+ %c0 = arith.constant 0 : i16
+ %c65535 = arith.constant 65535 : i16
+ %c1 = arith.constant 1 : i16
+
+ // Unsigned i16 loop from 0 to 65535: same check as the i8 full-range case.
+ // The trip count 65535 (0xFFFF) must not be reported as -1.
+ // Expected: 65535 as i64
+ // CHECK: "test.trip-count" = 65535 : i64
+ %r = scf.for unsigned %i = %c0 to %c65535 step %c1 iter_args(%0 = %a) -> i32 : i16 {
+ scf.yield %b : i32
+ }
+ return %r : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_i8_unsigned_step_2(
+func.func @trip_count_i8_unsigned_step_2(%a : i32, %b : i32) -> i32 {
+ %c0 = arith.constant 0 : i8
+ %c255 = arith.constant 255 : i8
+ %c2 = arith.constant 2 : i8
+
+ // Unsigned i8 loop from 0 to 255 with step 2: ceil(255 / 2) = 128, which
+ // exercises the unsigned division and the remainder round-up.
+ // Expected: 128 as i64
+ // CHECK: "test.trip-count" = 128 : i64
+ %r = scf.for unsigned %i = %c0 to %c255 step %c2 iter_args(%0 = %a) -> i32 : i8 {
+ scf.yield %b : i32
+ }
+ return %r : i32
+}
\ No newline at end of file
More information about the Mlir-commits
mailing list