[Mlir-commits] [mlir] [MLIR] Add a getStaticTripCount method to LoopLikeOpInterface (PR #158679)

Mehdi Amini llvmlistbot at llvm.org
Wed Sep 17 02:50:47 PDT 2025


================
@@ -264,22 +274,108 @@ getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
 
 /// Return the number of iterations for a loop with a lower bound `lb`, upper
 /// bound `ub` and step `step`.
-std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
-                                         OpFoldResult step) {
+std::optional<APInt> constantTripCount(
+    OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned,
+    llvm::function_ref<std::optional<llvm::APSInt>(Value, Value, bool)>
+        computeUbMinusLb) {
+  // This is the bitwidth used to return 0 when loop does not execute.
+  // We infer it from the type of the bound if it isn't an index type.
+  bool isIndex = true;
+  auto getBitwidth = [&](OpFoldResult ofr) -> int {
+    if (auto attr = dyn_cast<Attribute>(ofr)) {
+      if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+        if (auto intType = dyn_cast<IntegerType>(intAttr.getType())) {
+          isIndex = intType.isIndex();
+          return intType.getWidth();
+        }
+      }
+    } else {
+      auto val = cast<Value>(ofr);
+      if (auto intType = dyn_cast<IntegerType>(val.getType())) {
+        isIndex = intType.isIndex();
+        return intType.getWidth();
+      }
+    }
+    return IndexType::kInternalStorageBitWidth;
+  };
+  int bitwidth = getBitwidth(lb);
+  assert(bitwidth == getBitwidth(ub) &&
+         "lb and ub must have the same bitwidth");
   if (lb == ub)
-    return 0;
+    return APInt(bitwidth, 0);
+
+  std::optional<std::pair<APInt, bool>> maybeStepCst =
+      getConstantAPIntValue(step);
+
+  if (maybeStepCst) {
+    auto &stepCst = maybeStepCst->first;
+    assert(static_cast<int>(stepCst.getBitWidth()) == bitwidth &&
+           "step must have the same bitwidth as lb and ub");
+    if (stepCst.isZero())
+      return stepCst;
+    if (stepCst.isNegative())
+      return APInt(bitwidth, 0);
+  }
 
-  std::optional<int64_t> lbConstant = getConstantIntValue(lb);
-  if (!lbConstant)
-    return std::nullopt;
-  std::optional<int64_t> ubConstant = getConstantIntValue(ub);
-  if (!ubConstant)
-    return std::nullopt;
-  std::optional<int64_t> stepConstant = getConstantIntValue(step);
-  if (!stepConstant || *stepConstant == 0)
-    return std::nullopt;
+  if (isIndex) {
+    LDBG()
+        << "Computing loop trip count for index type may break with overflow";
+    // TODO: we can't compute the trip count for index type. We should fix this
+    // but too many tests are failing right now.
+    //   return {};
+  }
 
-  return llvm::divideCeilSigned(*ubConstant - *lbConstant, *stepConstant);
+  /// Compute the difference between the upper and lower bound: either from the
+  /// constant value or using the computeUbMinusLb callback.
----------------
joker-eph wrote:

Because we don't want every caller to reimplement the callback, we provide the "normal" constant-difference logic here.
Users can inject their own "offset" op through the callback. We need the callback because this utility can't depend on arith.add, and the callback also makes it possible to work with other add operations.

https://github.com/llvm/llvm-project/pull/158679


More information about the Mlir-commits mailing list