[Mlir-commits] [mlir] [mlir][scf] Fix trip count signedness and overflow in SCF Utils (PR #178782)

Jhalak Patel llvmlistbot at llvm.org
Mon Feb 2 09:50:50 PST 2026


https://github.com/jhalakpatel updated https://github.com/llvm/llvm-project/pull/178782

>From 1cc34aa1e36f9e09a0f3ea41281ebfdcd8006164 Mon Sep 17 00:00:00 2001
From: Jhalak Patel <jhalakp at nvidia.com>
Date: Thu, 29 Jan 2026 10:52:01 -0800
Subject: [PATCH] [mlir][scf] Fix trip count signedness and overflow in SCF
 Utils

Change `getConstLoopTripCounts` to return `SmallVector<llvm::APInt>` instead
of `SmallVector<int64_t>` to properly handle signedness and prevent potential
overflow issues. Update all call sites to use APInt methods and uint64_t for
intermediate calculations.

- Use APInt::isOne() instead of direct comparison with 1
- Store trip counts in uint64_t so the modulo / even-multiple computations
  no longer narrow the zero-extended value into int64_t
- Remove the TODOs about signedness and overflow that this change resolves
  (the `isSigned` TODO inside `getConstLoopTripCounts` is intentionally kept)

Fixes #178506
---
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h |  3 +-
 mlir/lib/Dialect/SCF/Utils/Utils.cpp        | 31 +++++++++------------
 2 files changed, 14 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 881125bd0da2f..c85f3b02c4a44 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -251,8 +251,7 @@ FailureOr<scf::ParallelOp> parallelLoopUnrollByFactors(
 /// Get constant trip counts for each of the induction variables of the given
 /// loop operation. If any of the loop's trip counts is not constant, return an
 /// empty vector.
-/// TODO(#178506): Should return SmallVector<uint64_t> for correct signedness.
-llvm::SmallVector<int64_t>
+llvm::SmallVector<llvm::APInt>
 getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp);
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 3892904646da7..f8a4f057c9f0d 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -391,18 +391,14 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
     int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
     int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
     if (unrollFactor == 1) {
-      if (*constTripCount == 1 &&
+      if (constTripCount->isOne() &&
           failed(forOp.promoteIfSingleIteration(rewriter)))
         return failure();
       return UnrolledLoopInfo{forOp, std::nullopt};
     }
 
-    // TODO(#178506): This may overflow for large trip counts. Should use
-    // uint64_t.
-    int64_t tripCountEvenMultiple =
-        constTripCount->getZExtValue() -
-        (constTripCount->getZExtValue() % unrollFactor);
-    // TODO(#178506): This may overflow when computing upperBoundUnrolledCst.
+    uint64_t tripCount = constTripCount->getZExtValue();
+    uint64_t tripCountEvenMultiple = tripCount - tripCount % unrollFactor;
     int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
     int64_t stepUnrolledCst = stepCst * unrollFactor;
 
@@ -504,7 +500,7 @@ LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
   const APInt &tripCount = *mayBeConstantTripCount;
   if (tripCount.isZero())
     return success();
-  if (tripCount.getZExtValue() == 1)
+  if (tripCount.isOne())
     return forOp.promoteIfSingleIteration(rewriter);
   return loopUnrollByFactor(forOp, tripCount.getZExtValue());
 }
@@ -552,12 +548,13 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
     LDBG() << "failed to unroll and jam: trip count could not be determined";
     return failure();
   }
-  if (unrollJamFactor > tripCount->getZExtValue()) {
+  uint64_t tripCountValue = tripCount->getZExtValue();
+  if (unrollJamFactor > tripCountValue) {
     LDBG() << "unroll and jam factor is greater than trip count, set factor to "
               "trip "
               "count";
-    unrollJamFactor = tripCount->getZExtValue();
-  } else if (tripCount->getZExtValue() % unrollJamFactor != 0) {
+    unrollJamFactor = tripCountValue;
+  } else if (tripCountValue % unrollJamFactor != 0) {
     LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
               "multiple of unroll jam factor";
     return failure();
@@ -1563,23 +1560,21 @@ bool mlir::isPerfectlyNestedForLoops(
   return true;
 }
 
-llvm::SmallVector<int64_t>
+llvm::SmallVector<llvm::APInt>
 mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) {
   std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
   std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
   std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
   if (!loBnds || !upBnds || !steps)
     return {};
-  // TODO(#178506): The result should be SmallVector<uint64_t> and use uint64_t
-  // for trip counts.
-  llvm::SmallVector<int64_t> tripCounts;
+  llvm::SmallVector<llvm::APInt> tripCounts;
   for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
     // TODO(#178506): Signedness is not handled correctly here.
     std::optional<llvm::APInt> numIter = constantTripCount(
         lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb);
     if (!numIter)
       return {};
-    tripCounts.push_back(numIter->getZExtValue());
+    tripCounts.push_back(*numIter);
   }
   return tripCounts;
 }
@@ -1610,7 +1605,7 @@ FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors(
 
   // Make sure that the unroll factors divide the iteration space evenly
   // TODO: Support unrolling loops with dynamic iteration spaces.
-  const llvm::SmallVector<int64_t> tripCounts = getConstLoopTripCounts(op);
+  const llvm::SmallVector<llvm::APInt> tripCounts = getConstLoopTripCounts(op);
   if (tripCounts.empty())
     return rewriter.notifyMatchFailure(
         op, "Failed to compute constant trip counts for the loop. Note that "
@@ -1618,7 +1613,7 @@ FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors(
 
   for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
     const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
-    if (tripCounts[dimIdx] % unrollFactor)
+    if (tripCounts[dimIdx].urem(unrollFactor) != 0)
       return rewriter.notifyMatchFailure(
           op, "Unroll factors don't divide the iteration space evenly");
   }



More information about the Mlir-commits mailing list