[Mlir-commits] [mlir] [mlir][scf] Fix trip count signedness and overflow in SCF Utils (PR #178782)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 29 15:57:57 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-scf

Author: Jhalak Patel (jhalakpatel)

<details>
<summary>Changes</summary>

Fixes #<!-- -->178506

This commit addresses multiple issues with trip count handling in mlir/lib/Dialect/SCF/Utils/Utils.cpp:

1. Changed getConstLoopTripCounts return type from SmallVector<int64_t> to SmallVector<uint64_t> since trip counts are inherently unsigned.

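2. Added overflow checks before calling APInt::getZExtValue() (or getSExtValue() for the step) to ensure values fit in 64 bits:
   - loopUnrollByFactor: Check that the trip count fits in 64 bits, and that both its largest multiple of the unroll factor and the unroll factor itself do not exceed INT64_MAX
   - loopUnrollFull: Check that the trip count fits in 64 bits before extracting it
   - loopUnrollJamByFactor: Check that the trip count and the constant step value fit in 64 bits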

3. Fixed getConstLoopTripCounts to derive signedness from the loop operation itself (using ForOp::getUnsignedCmp() when available) rather than hardcoding isSigned=true. This ensures correct handling of both signed and unsigned scf.for comparisons; loop operations other than scf::ForOp now default to an unsigned comparison.

4. Added checks to ensure trip counts fit in int64_t before using them in signed arithmetic operations (when multiplying with potentially negative step values).

5. Updated parallelLoopUnrollByFactors to use uint64_t for trip counts.

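Values that do not fit in 64 bits now make these utilities return failure() (or, for getConstLoopTripCounts, an empty vector) instead of being silently truncated, making the code more robust. The signed products `tripCountEvenMultiple * stepCst` and `stepCst * unrollFactor` in loopUnrollByFactor are not yet checked for overflow.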

---
Full diff: https://github.com/llvm/llvm-project/pull/178782.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/Utils/Utils.h (+1-2) 
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+55-21) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 881125bd0da2f..59e66567c5ae8 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -251,8 +251,7 @@ FailureOr<scf::ParallelOp> parallelLoopUnrollByFactors(
 /// Get constant trip counts for each of the induction variables of the given
 /// loop operation. If any of the loop's trip counts is not constant, return an
 /// empty vector.
-/// TODO(#178506): Should return SmallVector<uint64_t> for correct signedness.
-llvm::SmallVector<int64_t>
+llvm::SmallVector<uint64_t>
 getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp);
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 3892904646da7..4c172bc798e97 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -397,14 +397,26 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
       return UnrolledLoopInfo{forOp, std::nullopt};
     }
 
-    // TODO(#178506): This may overflow for large trip counts. Should use
-    // uint64_t.
-    int64_t tripCountEvenMultiple =
-        constTripCount->getZExtValue() -
-        (constTripCount->getZExtValue() % unrollFactor);
-    // TODO(#178506): This may overflow when computing upperBoundUnrolledCst.
-    int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
-    int64_t stepUnrolledCst = stepCst * unrollFactor;
+    // Check for overflow before extracting trip count.
+    if (constTripCount->getActiveBits() > 64)
+      return failure();
+    uint64_t tripCountValue = constTripCount->getZExtValue();
+    uint64_t tripCountEvenMultiple =
+        tripCountValue - (tripCountValue % unrollFactor);
+
+    // Since we need to compute: upperBoundUnrolled = lbCst + tripCountEvenMultiple * stepCst
+    // and the result must fit in int64_t, check that tripCountEvenMultiple fits in int64_t
+    // when we need to multiply it.
+    if (tripCountEvenMultiple > static_cast<uint64_t>(INT64_MAX))
+      return failure();
+
+    int64_t upperBoundUnrolledCst = lbCst +
+        static_cast<int64_t>(tripCountEvenMultiple) * stepCst;
+
+    // Check step * unrollFactor fits in int64_t.
+    if (unrollFactor > static_cast<uint64_t>(INT64_MAX))
+      return failure();
+    int64_t stepUnrolledCst = stepCst * static_cast<int64_t>(unrollFactor);
 
     // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
     generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
@@ -504,9 +516,13 @@ LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
   const APInt &tripCount = *mayBeConstantTripCount;
   if (tripCount.isZero())
     return success();
-  if (tripCount.getZExtValue() == 1)
+  // Check for overflow before extracting trip count.
+  if (tripCount.getActiveBits() > 64)
+    return failure();
+  uint64_t tripCountValue = tripCount.getZExtValue();
+  if (tripCountValue == 1)
     return forOp.promoteIfSingleIteration(rewriter);
-  return loopUnrollByFactor(forOp, tripCount.getZExtValue());
+  return loopUnrollByFactor(forOp, tripCountValue);
 }
 
 /// Check if bounds of all inner loops are defined outside of `forOp`
@@ -552,12 +568,18 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
     LDBG() << "failed to unroll and jam: trip count could not be determined";
     return failure();
   }
-  if (unrollJamFactor > tripCount->getZExtValue()) {
+  // Check for overflow before extracting trip count.
+  if (tripCount->getActiveBits() > 64) {
+    LDBG() << "failed to unroll and jam: trip count too large";
+    return failure();
+  }
+  uint64_t tripCountValue = tripCount->getZExtValue();
+  if (unrollJamFactor > tripCountValue) {
     LDBG() << "unroll and jam factor is greater than trip count, set factor to "
               "trip "
               "count";
-    unrollJamFactor = tripCount->getZExtValue();
-  } else if (tripCount->getZExtValue() % unrollJamFactor != 0) {
+    unrollJamFactor = tripCountValue;
+  } else if (tripCountValue % unrollJamFactor != 0) {
     LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
               "multiple of unroll jam factor";
     return failure();
@@ -632,7 +654,11 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
 
   // Scale the step of loop being unroll-jammed by the unroll-jam factor.
   rewriter.setInsertionPoint(forOp);
-  int64_t step = forOp.getConstantStep()->getSExtValue();
+  std::optional<APInt> stepAPInt = forOp.getConstantStep();
+  // Check for overflow before extracting step value.
+  if (!stepAPInt || stepAPInt->getSignificantBits() > 64)
+    return failure();
+  int64_t step = stepAPInt->getSExtValue();
   auto newStep = rewriter.createOrFold<arith::MulIOp>(
       forOp.getLoc(), forOp.getStep(),
       rewriter.createOrFold<arith::ConstantOp>(
@@ -1563,22 +1589,30 @@ bool mlir::isPerfectlyNestedForLoops(
   return true;
 }
 
-llvm::SmallVector<int64_t>
+llvm::SmallVector<uint64_t>
 mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) {
   std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
   std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
   std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
   if (!loBnds || !upBnds || !steps)
     return {};
-  // TODO(#178506): The result should be SmallVector<uint64_t> and use uint64_t
-  // for trip counts.
-  llvm::SmallVector<int64_t> tripCounts;
+
+  // Determine signedness from the loop operation if possible.
+  // For scf::ForOp, use the loop's comparison signedness.
+  // Otherwise, default to unsigned (trip counts are conceptually unsigned).
+  bool isSigned = false;
+  if (auto forOp = dyn_cast<scf::ForOp>(loopOp.getOperation()))
+    isSigned = !forOp.getUnsignedCmp();
+
+  llvm::SmallVector<uint64_t> tripCounts;
   for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
-    // TODO(#178506): Signedness is not handled correctly here.
     std::optional<llvm::APInt> numIter = constantTripCount(
-        lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb);
+        lb, ub, step, isSigned, scf::computeUbMinusLb);
     if (!numIter)
       return {};
+    // Check for overflow before extracting the value.
+    if (numIter->getActiveBits() > 64)
+      return {};
     tripCounts.push_back(numIter->getZExtValue());
   }
   return tripCounts;
@@ -1610,7 +1644,7 @@ FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors(
 
   // Make sure that the unroll factors divide the iteration space evenly
   // TODO: Support unrolling loops with dynamic iteration spaces.
-  const llvm::SmallVector<int64_t> tripCounts = getConstLoopTripCounts(op);
+  const llvm::SmallVector<uint64_t> tripCounts = getConstLoopTripCounts(op);
   if (tripCounts.empty())
     return rewriter.notifyMatchFailure(
         op, "Failed to compute constant trip counts for the loop. Note that "

``````````

</details>


https://github.com/llvm/llvm-project/pull/178782


More information about the Mlir-commits mailing list