[Mlir-commits] [mlir] [mlir][scf] Fix trip count signedness and overflow in SCF Utils (PR #178782)
Jhalak Patel
llvmlistbot at llvm.org
Thu Jan 29 15:57:23 PST 2026
https://github.com/jhalakpatel created https://github.com/llvm/llvm-project/pull/178782
Fixes #178506
This commit addresses multiple issues with trip count handling in mlir/lib/Dialect/SCF/Utils/Utils.cpp:
1. Changed getConstLoopTripCounts return type from SmallVector<int64_t> to SmallVector<uint64_t> since trip counts are inherently unsigned.
2. Added overflow checks before calling APInt::getZExtValue() to ensure values fit in uint64_t:
- loopUnrollByFactor: Check that the trip count fits in 64 bits and that its even multiple of the unroll factor does not exceed INT64_MAX
- loopUnrollFull: Check trip count before extraction
- loopUnrollJamByFactor: Check trip count and step values
3. Fixed getConstLoopTripCounts to derive signedness from the loop operation itself (using ForOp::getUnsignedCmp() when available) rather than hardcoding isSigned=true. This ensures correct handling of both signed and unsigned loop comparisons.
4. Added checks to ensure trip counts fit in int64_t before using them in signed arithmetic operations (when multiplying with potentially negative step values).
5. Updated parallelLoopUnrollByFactors to use uint64_t for trip counts.
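Values that do not fit the target integer type now make these utilities bail out (returning failure(), or an empty vector from getConstLoopTripCounts) instead of silently overflowing, making the code more robust.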
From 06516615d93e332bc5f7ea8064d1be5f8a1bc680 Mon Sep 17 00:00:00 2001
From: Jhalak Patel <jhalakp at nvidia.com>
Date: Thu, 29 Jan 2026 10:52:01 -0800
Subject: [PATCH] [mlir][scf] Fix trip count signedness and overflow in SCF
Utils
Fixes #178506
This commit addresses multiple issues with trip count handling in
mlir/lib/Dialect/SCF/Utils/Utils.cpp:
1. Changed getConstLoopTripCounts return type from SmallVector<int64_t>
to SmallVector<uint64_t> since trip counts are inherently unsigned.
2. Added overflow checks before calling APInt::getZExtValue() to ensure
values fit in uint64_t:
- loopUnrollByFactor: Check that the trip count fits in 64 bits and
that its even multiple of the unroll factor does not exceed INT64_MAX
- loopUnrollFull: Check trip count before extraction
- loopUnrollJamByFactor: Check trip count and step values
3. Fixed getConstLoopTripCounts to derive signedness from the loop
operation itself (using ForOp::getUnsignedCmp() when available)
rather than hardcoding isSigned=true. This ensures correct
handling of both signed and unsigned loop comparisons.
4. Added checks to ensure trip counts fit in int64_t before using them
in signed arithmetic operations (when multiplying with potentially
negative step values).
5. Updated parallelLoopUnrollByFactors to use uint64_t for trip counts.
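Values that do not fit the target integer type now make these
utilities bail out (returning failure(), or an empty vector from
getConstLoopTripCounts) instead of silently overflowing, making the
code more robust.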
---
mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 3 +-
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 76 +++++++++++++++------
2 files changed, 56 insertions(+), 23 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 881125bd0da2f..59e66567c5ae8 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -251,8 +251,7 @@ FailureOr<scf::ParallelOp> parallelLoopUnrollByFactors(
/// Get constant trip counts for each of the induction variables of the given
/// loop operation. If any of the loop's trip counts is not constant, return an
/// empty vector.
-/// TODO(#178506): Should return SmallVector<uint64_t> for correct signedness.
-llvm::SmallVector<int64_t>
+llvm::SmallVector<uint64_t>
getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp);
} // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 3892904646da7..4c172bc798e97 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -397,14 +397,26 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
return UnrolledLoopInfo{forOp, std::nullopt};
}
- // TODO(#178506): This may overflow for large trip counts. Should use
- // uint64_t.
- int64_t tripCountEvenMultiple =
- constTripCount->getZExtValue() -
- (constTripCount->getZExtValue() % unrollFactor);
- // TODO(#178506): This may overflow when computing upperBoundUnrolledCst.
- int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
- int64_t stepUnrolledCst = stepCst * unrollFactor;
+ // Check for overflow before extracting trip count.
+ if (constTripCount->getActiveBits() > 64)
+ return failure();
+ uint64_t tripCountValue = constTripCount->getZExtValue();
+ uint64_t tripCountEvenMultiple =
+ tripCountValue - (tripCountValue % unrollFactor);
+
+ // Since we need to compute: upperBoundUnrolled = lbCst + tripCountEvenMultiple * stepCst
+ // and the result must fit in int64_t, check that tripCountEvenMultiple fits in int64_t
+ // when we need to multiply it.
+ if (tripCountEvenMultiple > static_cast<uint64_t>(INT64_MAX))
+ return failure();
+
+ int64_t upperBoundUnrolledCst = lbCst +
+ static_cast<int64_t>(tripCountEvenMultiple) * stepCst;
+
+ // Check step * unrollFactor fits in int64_t.
+ if (unrollFactor > static_cast<uint64_t>(INT64_MAX))
+ return failure();
+ int64_t stepUnrolledCst = stepCst * static_cast<int64_t>(unrollFactor);
// Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
@@ -504,9 +516,13 @@ LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
const APInt &tripCount = *mayBeConstantTripCount;
if (tripCount.isZero())
return success();
- if (tripCount.getZExtValue() == 1)
+ // Check for overflow before extracting trip count.
+ if (tripCount.getActiveBits() > 64)
+ return failure();
+ uint64_t tripCountValue = tripCount.getZExtValue();
+ if (tripCountValue == 1)
return forOp.promoteIfSingleIteration(rewriter);
- return loopUnrollByFactor(forOp, tripCount.getZExtValue());
+ return loopUnrollByFactor(forOp, tripCountValue);
}
/// Check if bounds of all inner loops are defined outside of `forOp`
@@ -552,12 +568,18 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
LDBG() << "failed to unroll and jam: trip count could not be determined";
return failure();
}
- if (unrollJamFactor > tripCount->getZExtValue()) {
+ // Check for overflow before extracting trip count.
+ if (tripCount->getActiveBits() > 64) {
+ LDBG() << "failed to unroll and jam: trip count too large";
+ return failure();
+ }
+ uint64_t tripCountValue = tripCount->getZExtValue();
+ if (unrollJamFactor > tripCountValue) {
LDBG() << "unroll and jam factor is greater than trip count, set factor to "
"trip "
"count";
- unrollJamFactor = tripCount->getZExtValue();
- } else if (tripCount->getZExtValue() % unrollJamFactor != 0) {
+ unrollJamFactor = tripCountValue;
+ } else if (tripCountValue % unrollJamFactor != 0) {
LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
"multiple of unroll jam factor";
return failure();
@@ -632,7 +654,11 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
// Scale the step of loop being unroll-jammed by the unroll-jam factor.
rewriter.setInsertionPoint(forOp);
- int64_t step = forOp.getConstantStep()->getSExtValue();
+ std::optional<APInt> stepAPInt = forOp.getConstantStep();
+ // Check for overflow before extracting step value.
+ if (!stepAPInt || stepAPInt->getSignificantBits() > 64)
+ return failure();
+ int64_t step = stepAPInt->getSExtValue();
auto newStep = rewriter.createOrFold<arith::MulIOp>(
forOp.getLoc(), forOp.getStep(),
rewriter.createOrFold<arith::ConstantOp>(
@@ -1563,22 +1589,30 @@ bool mlir::isPerfectlyNestedForLoops(
return true;
}
-llvm::SmallVector<int64_t>
+llvm::SmallVector<uint64_t>
mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) {
std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
if (!loBnds || !upBnds || !steps)
return {};
- // TODO(#178506): The result should be SmallVector<uint64_t> and use uint64_t
- // for trip counts.
- llvm::SmallVector<int64_t> tripCounts;
+
+ // Determine signedness from the loop operation if possible.
+ // For scf::ForOp, use the loop's comparison signedness.
+ // Otherwise, default to unsigned (trip counts are conceptually unsigned).
+ bool isSigned = false;
+ if (auto forOp = dyn_cast<scf::ForOp>(loopOp.getOperation()))
+ isSigned = !forOp.getUnsignedCmp();
+
+ llvm::SmallVector<uint64_t> tripCounts;
for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
- // TODO(#178506): Signedness is not handled correctly here.
std::optional<llvm::APInt> numIter = constantTripCount(
- lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb);
+ lb, ub, step, isSigned, scf::computeUbMinusLb);
if (!numIter)
return {};
+ // Check for overflow before extracting the value.
+ if (numIter->getActiveBits() > 64)
+ return {};
tripCounts.push_back(numIter->getZExtValue());
}
return tripCounts;
@@ -1610,7 +1644,7 @@ FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors(
// Make sure that the unroll factors divide the iteration space evenly
// TODO: Support unrolling loops with dynamic iteration spaces.
- const llvm::SmallVector<int64_t> tripCounts = getConstLoopTripCounts(op);
+ const llvm::SmallVector<uint64_t> tripCounts = getConstLoopTripCounts(op);
if (tripCounts.empty())
return rewriter.notifyMatchFailure(
op, "Failed to compute constant trip counts for the loop. Note that "
More information about the Mlir-commits
mailing list