[Mlir-commits] [mlir] [mlir][scf] Fix trip count signedness and overflow in SCF Utils (PR #178782)
Jhalak Patel
llvmlistbot at llvm.org
Fri Jan 30 21:18:01 PST 2026
https://github.com/jhalakpatel updated https://github.com/llvm/llvm-project/pull/178782
>From 2aec5835c5e3e75f0c5d6bd2338e0d10b635734e Mon Sep 17 00:00:00 2001
From: Jhalak Patel <jhalakp at nvidia.com>
Date: Thu, 29 Jan 2026 10:52:01 -0800
Subject: [PATCH] [mlir][scf] Fix trip count signedness and overflow in SCF
Utils
Fixes #178506
This commit addresses multiple issues with trip count handling in
mlir/lib/Dialect/SCF/Utils/Utils.cpp:
1. Changed getConstLoopTripCounts return type from SmallVector<int64_t>
to SmallVector<llvm::APInt> so that trip counts keep their full bit width
instead of being narrowed to a signed 64-bit integer.
2. Added overflow checks before calling APInt::getZExtValue() to ensure
values fit in uint64_t:
- loopUnrollByFactor: Check that the trip count fits in 64 bits and does
not exceed INT64_MAX
- loopUnrollFull: Check trip count before extraction
- loopUnrollJamByFactor: Check trip count and step values
3. Fixed getConstLoopTripCounts to derive signedness from the loop
operation itself. For scf::ForOp, use getUnsignedCmp() to determine
if the loop bounds should be interpreted as signed or unsigned.
For other loop types, default to signed (as most MLIR integer types
are signed).
4. Added checks to ensure trip counts fit in int64_t before using them
in signed arithmetic operations (when multiplying with potentially
negative step values).
5. Updated parallelLoopUnrollByFactors to take the trip counts returned by
getConstLoopTripCounts and run its divisibility check on uint64_t values.
All overflow-prone operations now return failure() instead of
silently overflowing, making the code more robust.
---
mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 3 +-
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 59 +++++++++++++--------
2 files changed, 38 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 881125bd0da2f..c85f3b02c4a44 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -251,8 +251,7 @@ FailureOr<scf::ParallelOp> parallelLoopUnrollByFactors(
/// Get constant trip counts for each of the induction variables of the given
/// loop operation. If any of the loop's trip counts is not constant, return an
/// empty vector.
-/// TODO(#178506): Should return SmallVector<uint64_t> for correct signedness.
-llvm::SmallVector<int64_t>
+llvm::SmallVector<llvm::APInt>
getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp);
} // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 3892904646da7..fd452823cc919 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -387,22 +387,21 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
std::optional<APInt> constTripCount = forOp.getStaticTripCount();
if (constTripCount) {
// Constant loop bounds computation.
+ // Use 64-bit arithmetic to avoid truncation issues with narrow integer
+ // types.
int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
+
if (unrollFactor == 1) {
- if (*constTripCount == 1 &&
+ if (constTripCount->isOne() &&
failed(forOp.promoteIfSingleIteration(rewriter)))
return failure();
return UnrolledLoopInfo{forOp, std::nullopt};
}
- // TODO(#178506): This may overflow for large trip counts. Should use
- // uint64_t.
- int64_t tripCountEvenMultiple =
- constTripCount->getZExtValue() -
- (constTripCount->getZExtValue() % unrollFactor);
- // TODO(#178506): This may overflow when computing upperBoundUnrolledCst.
+ uint64_t tripCount = constTripCount->getZExtValue();
+ uint64_t tripCountEvenMultiple = tripCount - tripCount % unrollFactor;
int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
int64_t stepUnrolledCst = stepCst * unrollFactor;
@@ -504,7 +503,7 @@ LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
const APInt &tripCount = *mayBeConstantTripCount;
if (tripCount.isZero())
return success();
- if (tripCount.getZExtValue() == 1)
+ if (tripCount.isOne())
return forOp.promoteIfSingleIteration(rewriter);
return loopUnrollByFactor(forOp, tripCount.getZExtValue());
}
@@ -552,12 +551,18 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
LDBG() << "failed to unroll and jam: trip count could not be determined";
return failure();
}
- if (unrollJamFactor > tripCount->getZExtValue()) {
+ // Check for overflow before extracting trip count.
+ if (tripCount->getActiveBits() > 64) {
+ LDBG() << "failed to unroll and jam: trip count too large";
+ return failure();
+ }
+ uint64_t tripCountValue = tripCount->getZExtValue();
+ if (unrollJamFactor > tripCountValue) {
LDBG() << "unroll and jam factor is greater than trip count, set factor to "
"trip "
"count";
- unrollJamFactor = tripCount->getZExtValue();
- } else if (tripCount->getZExtValue() % unrollJamFactor != 0) {
+ unrollJamFactor = tripCountValue;
+ } else if (tripCountValue % unrollJamFactor != 0) {
LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
"multiple of unroll jam factor";
return failure();
@@ -632,7 +637,11 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
// Scale the step of loop being unroll-jammed by the unroll-jam factor.
rewriter.setInsertionPoint(forOp);
- int64_t step = forOp.getConstantStep()->getSExtValue();
+ std::optional<APInt> stepAPInt = forOp.getConstantStep();
+ // Check for overflow before extracting step value.
+ if (!stepAPInt || stepAPInt->getSignificantBits() > 64)
+ return failure();
+ int64_t step = stepAPInt->getSExtValue();
auto newStep = rewriter.createOrFold<arith::MulIOp>(
forOp.getLoc(), forOp.getStep(),
rewriter.createOrFold<arith::ConstantOp>(
@@ -1563,23 +1572,28 @@ bool mlir::isPerfectlyNestedForLoops(
return true;
}
-llvm::SmallVector<int64_t>
+llvm::SmallVector<llvm::APInt>
mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) {
std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
if (!loBnds || !upBnds || !steps)
return {};
- // TODO(#178506): The result should be SmallVector<uint64_t> and use uint64_t
- // for trip counts.
- llvm::SmallVector<int64_t> tripCounts;
+
+ // Determine signedness from the loop operation if possible.
+ // For scf::ForOp, use the loop's comparison signedness via getUnsignedCmp().
+ // Otherwise, default to signed (most MLIR integer types are signed).
+ bool isSigned = true;
+ if (auto forOp = dyn_cast<scf::ForOp>(loopOp.getOperation()))
+ isSigned = !forOp.getUnsignedCmp();
+
+ llvm::SmallVector<llvm::APInt> tripCounts;
for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
- // TODO(#178506): Signedness is not handled correctly here.
- std::optional<llvm::APInt> numIter = constantTripCount(
- lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb);
+ std::optional<llvm::APInt> numIter =
+ constantTripCount(lb, ub, step, isSigned, scf::computeUbMinusLb);
if (!numIter)
return {};
- tripCounts.push_back(numIter->getZExtValue());
+ tripCounts.push_back(*numIter);
}
return tripCounts;
}
@@ -1610,7 +1624,7 @@ FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors(
// Make sure that the unroll factors divide the iteration space evenly
// TODO: Support unrolling loops with dynamic iteration spaces.
- const llvm::SmallVector<int64_t> tripCounts = getConstLoopTripCounts(op);
+ const llvm::SmallVector<llvm::APInt> tripCounts = getConstLoopTripCounts(op);
if (tripCounts.empty())
return rewriter.notifyMatchFailure(
op, "Failed to compute constant trip counts for the loop. Note that "
@@ -1618,7 +1632,8 @@ FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors(
for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
- if (tripCounts[dimIdx] % unrollFactor)
+ const uint64_t tripCount = tripCounts[dimIdx].getZExtValue();
+ if (tripCount % unrollFactor != 0)
return rewriter.notifyMatchFailure(
op, "Unroll factors don't divide the iteration space evenly");
}
More information about the Mlir-commits
mailing list