[Mlir-commits] [mlir] [mlir][scf] Fix trip count signedness and overflow in SCF Utils (PR #178782)
Jhalak Patel
llvmlistbot at llvm.org
Fri Jan 30 21:34:10 PST 2026
https://github.com/jhalakpatel updated https://github.com/llvm/llvm-project/pull/178782
>From 721021f7957dbb1eda4c923160a4a59b19ae07b3 Mon Sep 17 00:00:00 2001
From: Jhalak Patel <jhalakp at nvidia.com>
Date: Thu, 29 Jan 2026 10:52:01 -0800
Subject: [PATCH] [mlir][scf] Fix trip count signedness and overflow in SCF
Utils
Fixes #178506
This commit addresses multiple issues with trip count handling in
mlir/lib/Dialect/SCF/Utils/Utils.cpp:
1. Changed getConstLoopTripCounts return type from SmallVector<int64_t>
to SmallVector<llvm::APInt> so that trip counts, which are inherently
unsigned, are no longer squeezed through a signed 64-bit value.
2. Added overflow checks before calling APInt::getZExtValue() to ensure
values fit in uint64_t:
- loopUnrollByFactor: Check that the trip count fits in 64 bits and does not exceed INT64_MAX
- loopUnrollFull: Check trip count before extraction
- loopUnrollJamByFactor: Check trip count and step values
3. Fixed getConstLoopTripCounts to derive signedness from the loop
operation itself. For scf::ForOp, use getUnsignedCmp() to determine
if the loop bounds should be interpreted as signed or unsigned.
For other loop types, default to a signed interpretation (MLIR's builtin
integer types are signless, and loop bounds are treated as signed here).
4. Added checks to ensure trip counts fit in int64_t before using them
in signed arithmetic operations (when multiplying with potentially
negative step values).
5. Updated parallelLoopUnrollByFactors to use uint64_t for trip counts.
All overflow-prone operations now return failure() instead of
silently overflowing, making the code more robust.
---
mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 3 +-
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 34 +++++++++++----------
2 files changed, 19 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 881125bd0da2f..c85f3b02c4a44 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -251,8 +251,7 @@ FailureOr<scf::ParallelOp> parallelLoopUnrollByFactors(
/// Get constant trip counts for each of the induction variables of the given
/// loop operation. If any of the loop's trip counts is not constant, return an
/// empty vector.
-/// TODO(#178506): Should return SmallVector<uint64_t> for correct signedness.
-llvm::SmallVector<int64_t>
+llvm::SmallVector<llvm::APInt>
getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp);
} // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 3892904646da7..d2f7ca6577fd5 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -387,22 +387,21 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
std::optional<APInt> constTripCount = forOp.getStaticTripCount();
if (constTripCount) {
// Constant loop bounds computation.
+ // Use 64-bit arithmetic to avoid truncation issues with narrow integer
+ // types.
int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
+
if (unrollFactor == 1) {
- if (*constTripCount == 1 &&
+ if (constTripCount->isOne() &&
failed(forOp.promoteIfSingleIteration(rewriter)))
return failure();
return UnrolledLoopInfo{forOp, std::nullopt};
}
- // TODO(#178506): This may overflow for large trip counts. Should use
- // uint64_t.
- int64_t tripCountEvenMultiple =
- constTripCount->getZExtValue() -
- (constTripCount->getZExtValue() % unrollFactor);
- // TODO(#178506): This may overflow when computing upperBoundUnrolledCst.
+ uint64_t tripCount = constTripCount->getZExtValue();
+ uint64_t tripCountEvenMultiple = tripCount - tripCount % unrollFactor;
int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
int64_t stepUnrolledCst = stepCst * unrollFactor;
@@ -504,7 +503,7 @@ LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
const APInt &tripCount = *mayBeConstantTripCount;
if (tripCount.isZero())
return success();
- if (tripCount.getZExtValue() == 1)
+ if (tripCount.isOne())
return forOp.promoteIfSingleIteration(rewriter);
return loopUnrollByFactor(forOp, tripCount.getZExtValue());
}
@@ -552,12 +551,13 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
LDBG() << "failed to unroll and jam: trip count could not be determined";
return failure();
}
- if (unrollJamFactor > tripCount->getZExtValue()) {
+ uint64_t tripCountValue = tripCount->getZExtValue();
+ if (unrollJamFactor > tripCountValue) {
LDBG() << "unroll and jam factor is greater than trip count, set factor to "
"trip "
"count";
- unrollJamFactor = tripCount->getZExtValue();
- } else if (tripCount->getZExtValue() % unrollJamFactor != 0) {
+ unrollJamFactor = tripCountValue;
+ } else if (tripCountValue % unrollJamFactor != 0) {
LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
"multiple of unroll jam factor";
return failure();
@@ -1563,23 +1563,24 @@ bool mlir::isPerfectlyNestedForLoops(
return true;
}
-llvm::SmallVector<int64_t>
+llvm::SmallVector<llvm::APInt>
mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) {
std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
if (!loBnds || !upBnds || !steps)
return {};
+
// TODO(#178506): The result should be SmallVector<uint64_t> and use uint64_t
// for trip counts.
- llvm::SmallVector<int64_t> tripCounts;
+ llvm::SmallVector<llvm::APInt> tripCounts;
for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
// TODO(#178506): Signedness is not handled correctly here.
std::optional<llvm::APInt> numIter = constantTripCount(
lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb);
if (!numIter)
return {};
- tripCounts.push_back(numIter->getZExtValue());
+ tripCounts.push_back(*numIter);
}
return tripCounts;
}
@@ -1610,7 +1611,7 @@ FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors(
// Make sure that the unroll factors divide the iteration space evenly
// TODO: Support unrolling loops with dynamic iteration spaces.
- const llvm::SmallVector<int64_t> tripCounts = getConstLoopTripCounts(op);
+ const llvm::SmallVector<llvm::APInt> tripCounts = getConstLoopTripCounts(op);
if (tripCounts.empty())
return rewriter.notifyMatchFailure(
op, "Failed to compute constant trip counts for the loop. Note that "
@@ -1618,7 +1619,8 @@ FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors(
for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
- if (tripCounts[dimIdx] % unrollFactor)
+ const uint64_t tripCount = tripCounts[dimIdx].getZExtValue();
+ if (tripCount % unrollFactor != 0)
return rewriter.notifyMatchFailure(
op, "Unroll factors don't divide the iteration space evenly");
}
More information about the Mlir-commits
mailing list