[Mlir-commits] [mlir] [mlir][DialectUtils] Fix 0 step handling in `constantTripCount` (PR #177329)

Matthias Springer llvmlistbot at llvm.org
Sun Jan 25 11:38:04 PST 2026


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/177329

From 3342e6d2c19ca41f03c059d212c56c398e38b36e Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 22 Jan 2026 09:39:50 +0000
Subject: [PATCH] [mlir][DialectUtils] Fix 0 step handling in
 `constantTripCount`

---
 .../Dialect/SCF/Transforms/LoopSpecialization.cpp  |  3 +++
 mlir/lib/Dialect/Utils/StaticValueUtils.cpp        | 14 +++++++++++---
 mlir/test/Dialect/SCF/canonicalize.mlir            |  7 ++++---
 mlir/test/Dialect/SCF/for-loop-peeling.mlir        | 10 ++++++++--
 4 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index cff01d1de17e8..a39e5520a144b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -117,6 +117,9 @@ static void specializeForLoopForUnrolling(ForOp op) {
 /// The newly generated scf.if operation is returned via `ifOp`. The boundary
 /// at which the loop is split (new upper bound) is returned via `splitBound`.
 /// The return value indicates whether the loop was rewritten or not.
+///
+/// Note: Loops with a step size of 0 cannot be peeled. Applying this function
+/// to such a loop may result in IR with undefined behavior.
 static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
                                  ForOp &partialIteration, Value &splitBound) {
   RewriterBase::InsertionGuard guard(b);
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index bc9d8a2496b4b..7fb0d4e9710f8 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -316,8 +316,12 @@ std::optional<APInt> constantTripCount(
            << lb;
     return std::nullopt;
   }
-  if (lb == ub)
+  if (lb == ub) {
+    // Fast path: LB == UB. The loop has zero iterations.
+    // Note: LB and UB could match at runtime, even though they are different
+    // SSA values. That case cannot be detected here.
     return APInt(bitwidth, 0);
+  }
 
   std::optional<std::pair<APInt, bool>> maybeStepCst =
       getConstantAPIntValue(step);
@@ -326,8 +330,12 @@ std::optional<APInt> constantTripCount(
     auto &stepCst = maybeStepCst->first;
     assert(static_cast<int>(stepCst.getBitWidth()) == bitwidth &&
            "step must have the same bitwidth as lb and ub");
-    if (stepCst.isZero())
-      return stepCst;
+    if (stepCst.isZero()) {
+      // Step is zero. If LB and UB match, we have zero iterations. Otherwise,
+      // we have an infinite number of iterations. We cannot tell for sure which
+      // case applies, so the static trip count is unknown.
+      return std::nullopt;
+    }
     if (stepCst.isNegative())
       return APInt(bitwidth, 0);
   }
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 56ef00dcd6f8b..f65046ecee6da 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2207,13 +2207,14 @@ func.func @index_switch_fold_no_res() {
 
 // -----
 
-// Step 0 is invalid, the loop is eliminated.
+// Step size 0: The loop has an infinite number of iterations.
 // CHECK-LABEL: func @scf_for_all_step_size_0()
-//       CHECK-NOT:   scf.forall
+//       CHECK:   scf.forall (%[[arg0:.*]]) = (0) to (1) step (0)
+//       CHECK:     vector.print %[[arg0]]
 func.func @scf_for_all_step_size_0()  {
   %x = arith.constant 0 : index
   scf.forall (%i, %j) = (0, 4) to (1, 5) step (%x, 8) {
-    vector.print %x : index
+    vector.print %i : index
     scf.forall.in_parallel {}
   }
   return
diff --git a/mlir/test/Dialect/SCF/for-loop-peeling.mlir b/mlir/test/Dialect/SCF/for-loop-peeling.mlir
index 084576625f32c..6caa4bd50bad8 100644
--- a/mlir/test/Dialect/SCF/for-loop-peeling.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-peeling.mlir
@@ -357,9 +357,15 @@ func.func @regression(%arg0: memref<i64>, %arg1: index) {
 // -----
 
 // Regression test: Make sure that we do not crash.
-// The step is 0, the loop will be eliminated.
+
 // CHECK-LABEL: func @zero_step(
-//       CHECK-NOT:   scf.for
+//       CHECK:   %[[c0:.*]] = arith.constant 0
+//       CHECK:   %[[c1:.*]] = arith.constant 1
+//       CHECK:   %[[poison:.*]] = ub.poison
+//       CHECK:   scf.for %{{.*}} = %[[c0]] to %[[poison]] step %[[c0]]
+//       CHECK:     arith.index_cast
+//       CHECK:   scf.for %{{.*}} = %[[poison]] to %[[c1]] step %[[c0]]
+//       CHECK:     arith.index_cast
 func.func @zero_step(%arg0: memref<i64>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index



More information about the Mlir-commits mailing list