[Mlir-commits] [mlir] [MLIR] Add a getStaticTripCount method to LoopLikeOpInterface (PR #158679)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 15 09:33:44 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-scf

Author: Mehdi Amini (joker-eph)

<details>
<summary>Changes</summary>

This patch adds a `getStaticTripCount` method to the LoopLikeOpInterface, allowing loops to optionally return a static trip count when one can be computed. It is implemented for the SCF ForOp, revamping the implementation of `constantTripCount` and removing duplicate implementations from SCF.cpp.

---

Patch is 47.73 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158679.diff


10 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+1-1) 
- (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+22-3) 
- (modified) mlir/include/mlir/Interfaces/LoopLikeInterface.td (+11) 
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+61-53) 
- (modified) mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp (+9-1) 
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+13-19) 
- (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+95-16) 
- (modified) mlir/test/Dialect/SCF/for-loop-peeling.mlir (+19-17) 
- (added) mlir/test/Dialect/SCF/trip_count.mlir (+589) 
- (modified) mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp (+14) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index d3c01c31636a7..fadd3fc10bfc4 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -152,7 +152,7 @@ def ForOp : SCF_Op<"for",
       [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
        ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
         "getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
-        "getLoopUpperBounds", "getYieldedValuesMutable",
+        "getLoopUpperBounds", "getStaticTripCount", "getYieldedValuesMutable",
         "promoteIfSingleIteration", "replaceWithAdditionalYields",
         "yieldTiledValuesAndReplace"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 77c376fb9973a..aa0cc35a1d675 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -105,6 +105,10 @@ OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val);
 SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
                                                  ArrayRef<int64_t> values);
 
+/// If ofr is a constant integer or an IntegerAttr, return the integer.
+/// The second return value indicates whether the value is an index type
+/// and thus the bitwidth is not defined.
+std::optional<std::pair<APInt, bool>> getConstantAPIntValue(OpFoldResult ofr);
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
 std::optional<int64_t> getConstantIntValue(OpFoldResult ofr);
 /// If all ofrs are constant integers or IntegerAttrs, return the integers.
@@ -201,9 +205,24 @@ foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);
 LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
 
 /// Return the number of iterations for a loop with a lower bound `lb`, upper
-/// bound `ub` and step `step`.
-std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
-                                         OpFoldResult step);
+/// bound `ub` and step `step`. The `isSigned` flag indicates whether the loop
+/// comparison between lb and ub is signed or unsigned. A negative step or a
+/// lower bound greater than the upper bound are considered invalid and will
+/// yield a zero trip count.
+/// The `computeUbMinusLb` callback is invoked to compute the difference between
+/// the upper and lower bound when not constant. It can be used by the client
+/// to match the pattern:
+///
+///   %ub = arith.addi nsw %lb, %c16_i32 : i32
+///   %1 = scf.for %arg0 = %lb to %ub ...
+///
+/// where %ub is computed as a static offset from %lb.
+/// Note: the matched addition should be nsw/nuw (matching the loop comparison)
+/// to avoid overflow, otherwise an overflow would imply a zero trip count.
+std::optional<APInt> constantTripCount(
+    OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned,
+    llvm::function_ref<std::optional<llvm::APSInt>(Value, Value, bool)>
+        computeUbMinusLb);
 
 /// Idiomatic saturated operations on values like offsets, sizes, and strides.
 struct SaturatedInteger {
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 6c95b4802837b..cfd15a7746e19 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -232,6 +232,17 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       /*defaultImplementation=*/[{
         return ::mlir::failure();
       }]
+    >,
+    InterfaceMethod<[{
+        Compute the static trip count if possible.
+      }],
+      /*retTy=*/"::std::optional<APInt>",
+      /*methodName=*/"getStaticTripCount",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::std::nullopt;
+      }]
     >
   ];
 
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c35989ecba6cd..e68dc04de231b 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -19,6 +19,8 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/ParallelCombiningOpInterface.h"
@@ -26,6 +28,9 @@
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/DebugLog.h"
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::scf;
@@ -105,6 +110,23 @@ static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
   return nullptr;
 }
 
+/// Helper function to compute the difference between two values. This is used
+/// by the loop implementations to compute the trip count.
+static std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub,
+                                                    bool isSigned) {
+  llvm::APSInt diff;
+  auto addOp = ub.getDefiningOp<arith::AddIOp>();
+  if (!addOp)
+    return std::nullopt;
+  if ((isSigned && !addOp.hasNoSignedWrap()) ||
+      (!isSigned && !addOp.hasNoUnsignedWrap()))
+    return std::nullopt;
+
+  if (!matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
+    return std::nullopt;
+  return diff;
+}
+
 //===----------------------------------------------------------------------===//
 // ExecuteRegionOp
 //===----------------------------------------------------------------------===//
@@ -408,11 +430,19 @@ std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
 /// Promotes the loop body of a forOp to its containing block if the forOp
 /// it can be determined that the loop has a single iteration.
 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
-  std::optional<int64_t> tripCount =
-      constantTripCount(getLowerBound(), getUpperBound(), getStep());
-  if (!tripCount.has_value() || tripCount != 1)
+  std::optional<APInt> tripCount = getStaticTripCount();
+  LDBG() << "promoteIfSingleIteration tripCount is " << tripCount
+         << " for loop "
+         << OpWithFlags(getOperation(), OpPrintingFlags().skipRegions());
+  if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
     return failure();
 
+  if (*tripCount == 0) {
+    rewriter.replaceAllUsesWith(getResults(), getInitArgs());
+    rewriter.eraseOp(*this);
+    return success();
+  }
+
   // Replace all results with the yielded values.
   auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
   rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
@@ -646,7 +676,8 @@ SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
 LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
   for (auto [lb, ub, step] :
        llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
-    auto tripCount = constantTripCount(lb, ub, step);
+    auto tripCount =
+        constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
     if (!tripCount.has_value() || *tripCount != 1)
       return failure();
   }
@@ -1003,27 +1034,6 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
   }
 };
 
-/// Util function that tries to compute a constant diff between u and l.
-/// Returns std::nullopt when the difference between two AffineValueMap is
-/// dynamic.
-static std::optional<APInt> computeConstDiff(Value l, Value u) {
-  IntegerAttr clb, cub;
-  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
-    llvm::APInt lbValue = clb.getValue();
-    llvm::APInt ubValue = cub.getValue();
-    return ubValue - lbValue;
-  }
-
-  // Else a simple pattern match for x + c or c + x
-  llvm::APInt diff;
-  if (matchPattern(
-          u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
-      matchPattern(
-          u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
-    return diff;
-  return std::nullopt;
-}
-
 /// Rewriting pattern that erases loops that are known not to iterate, replaces
 /// single-iteration loops with their bodies, and removes empty loops that
 /// iterate at least once and only return values defined outside of the loop.
@@ -1032,34 +1042,21 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
 
   LogicalResult matchAndRewrite(ForOp op,
                                 PatternRewriter &rewriter) const override {
-    // If the upper bound is the same as the lower bound, the loop does not
-    // iterate, just remove it.
-    if (op.getLowerBound() == op.getUpperBound()) {
+    std::optional<APInt> tripCount = op.getStaticTripCount();
+    if (!tripCount.has_value())
+      return rewriter.notifyMatchFailure(op,
+                                         "can't compute constant trip count");
+
+    if (tripCount->isZero()) {
+      LDBG() << "SimplifyTrivialLoops tripCount is 0 for loop "
+             << OpWithFlags(op, OpPrintingFlags().skipRegions());
       rewriter.replaceOp(op, op.getInitArgs());
       return success();
     }
 
-    std::optional<APInt> diff =
-        computeConstDiff(op.getLowerBound(), op.getUpperBound());
-    if (!diff)
-      return failure();
-
-    // If the loop is known to have 0 iterations, remove it.
-    bool zeroOrLessIterations =
-        diff->isZero() || (!op.getUnsignedCmp() && diff->isNegative());
-    if (zeroOrLessIterations) {
-      rewriter.replaceOp(op, op.getInitArgs());
-      return success();
-    }
-
-    std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
-    if (!maybeStepValue)
-      return failure();
-
-    // If the loop is known to have 1 iteration, inline its body and remove the
-    // loop.
-    llvm::APInt stepValue = *maybeStepValue;
-    if (stepValue.sge(*diff)) {
+    if (tripCount->getSExtValue() == 1) {
+      LDBG() << "SimplifyTrivialLoops tripCount is 1 for loop "
+             << OpWithFlags(op, OpPrintingFlags().skipRegions());
       SmallVector<Value, 4> blockArgs;
       blockArgs.reserve(op.getInitArgs().size() + 1);
       blockArgs.push_back(op.getLowerBound());
@@ -1072,11 +1069,14 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
     Block &block = op.getRegion().front();
     if (!llvm::hasSingleElement(block))
       return failure();
-    // If the loop is empty, iterates at least once, and only returns values
+    // The loop is empty and iterates at least once, if it only returns values
     // defined outside of the loop, remove it and replace it with yield values.
     if (llvm::any_of(op.getYieldedValues(),
                      [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
       return failure();
+    LDBG() << "SimplifyTrivialLoops empty body loop allows replacement with "
+              "yield operands for loop "
+           << OpWithFlags(op, OpPrintingFlags().skipRegions());
     rewriter.replaceOp(op, op.getYieldedValues());
     return success();
   }
@@ -1172,6 +1172,11 @@ Speculation::Speculatability ForOp::getSpeculatability() {
   return Speculation::NotSpeculatable;
 }
 
+std::optional<APInt> ForOp::getStaticTripCount() {
+  return constantTripCount(getLowerBound(), getUpperBound(), getStep(),
+                           /*isSigned=*/!getUnsignedCmp(), computeUbMinusLb);
+}
+
 //===----------------------------------------------------------------------===//
 // ForallOp
 //===----------------------------------------------------------------------===//
@@ -1768,7 +1773,8 @@ struct ForallOpSingleOrZeroIterationDimsFolder
     for (auto [lb, ub, step, iv] :
          llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
                    op.getMixedStep(), op.getInductionVars())) {
-      auto numIterations = constantTripCount(lb, ub, step);
+      auto numIterations =
+          constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
       if (numIterations.has_value()) {
         // Remove the loop if it performs zero iterations.
         if (*numIterations == 0) {
@@ -1839,7 +1845,8 @@ struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
                    op.getMixedStep(), op.getInductionVars())) {
       if (iv.hasNUses(0))
         continue;
-      auto numIterations = constantTripCount(lb, ub, step);
+      auto numIterations =
+          constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
       if (!numIterations.has_value() || numIterations.value() != 1) {
         continue;
       }
@@ -3084,7 +3091,8 @@ struct ParallelOpSingleOrZeroIterationDimsFolder
     for (auto [lb, ub, step, iv] :
          llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
                    op.getInductionVars())) {
-      auto numIterations = constantTripCount(lb, ub, step);
+      auto numIterations =
+          constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
       if (numIterations.has_value()) {
         // Remove the loop if it performs zero iterations.
         if (*numIterations == 0) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index f1203b2bdfee5..e3717aa9d940e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -94,7 +94,9 @@ static void specializeForLoopForUnrolling(ForOp op) {
 
   OpBuilder b(op);
   IRMapping map;
-  Value constant = arith::ConstantIndexOp::create(b, op.getLoc(), minConstant);
+  Value constant = arith::ConstantOp::create(
+      b, op.getLoc(),
+      IntegerAttr::get(op.getUpperBound().getType(), minConstant));
   Value cond = arith::CmpIOp::create(b, op.getLoc(), arith::CmpIPredicate::eq,
                                      bound, constant);
   map.map(bound, constant);
@@ -150,6 +152,9 @@ static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
                                              ValueRange{forOp.getLowerBound(),
                                                         forOp.getUpperBound(),
                                                         forOp.getStep()});
+  if (splitBound.getType() != forOp.getLowerBound().getType())
+    splitBound = b.createOrFold<arith::IndexCastOp>(
+        loc, forOp.getLowerBound().getType(), splitBound);
 
   // Create ForOp for partial iteration.
   b.setInsertionPointAfter(forOp);
@@ -230,6 +235,9 @@ LogicalResult mlir::scf::peelForLoopFirstIteration(RewriterBase &b, ForOp forOp,
   auto loc = forOp.getLoc();
   Value splitBound = b.createOrFold<AffineApplyOp>(
       loc, ubMap, ValueRange{forOp.getLowerBound(), forOp.getStep()});
+  if (splitBound.getType() != forOp.getUpperBound().getType())
+    splitBound = b.createOrFold<arith::IndexCastOp>(
+        loc, forOp.getUpperBound().getType(), splitBound);
 
   // Peel the first iteration.
   IRMapping map;
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 684dff8121de6..57c4deab89321 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/APInt.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/DebugLog.h"
@@ -291,14 +292,6 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
   return arith::DivUIOp::create(builder, loc, sum, divisor);
 }
 
-/// Returns the trip count of `forOp` if its' low bound, high bound and step are
-/// constants, or optional otherwise. Trip count is computed as
-/// ceilDiv(highBound - lowBound, step).
-static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
-  return constantTripCount(forOp.getLowerBound(), forOp.getUpperBound(),
-                           forOp.getStep());
-}
-
 /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
 /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
 /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
@@ -377,7 +370,7 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
   Value stepUnrolled;
   bool generateEpilogueLoop = true;
 
-  std::optional<int64_t> constTripCount = getConstantTripCount(forOp);
+  std::optional<APInt> constTripCount = forOp.getStaticTripCount();
   if (constTripCount) {
     // Constant loop bounds computation.
     int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
@@ -391,7 +384,8 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
     }
 
     int64_t tripCountEvenMultiple =
-        *constTripCount - (*constTripCount % unrollFactor);
+        constTripCount->getSExtValue() -
+        (constTripCount->getSExtValue() % unrollFactor);
     int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
     int64_t stepUnrolledCst = stepCst * unrollFactor;
 
@@ -487,15 +481,15 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
 /// Unrolls this loop completely.
 LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
   IRRewriter rewriter(forOp.getContext());
-  std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+  std::optional<APInt> mayBeConstantTripCount = forOp.getStaticTripCount();
   if (!mayBeConstantTripCount.has_value())
     return failure();
-  uint64_t tripCount = *mayBeConstantTripCount;
-  if (tripCount == 0)
+  APInt &tripCount = *mayBeConstantTripCount;
+  if (tripCount.isZero())
     return success();
-  if (tripCount == 1)
+  if (tripCount.getSExtValue() == 1)
     return forOp.promoteIfSingleIteration(rewriter);
-  return loopUnrollByFactor(forOp, tripCount);
+  return loopUnrollByFactor(forOp, tripCount.getSExtValue());
 }
 
 /// Check if bounds of all inner loops are defined outside of `forOp`
@@ -535,18 +529,18 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
 
   // Currently, only constant trip count that divided by the unroll factor is
   // supported.
-  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+  std::optional<APInt> tripCount = forOp.getStaticTripCount();
   if (!tripCount.has_value()) {
     // If the trip count is dynamic, do not unroll & jam.
     LDBG() << "failed to unroll and jam: trip count could not be determined";
     return failure();
   }
-  if (unrollJamFactor > *tripCount) {
+  if (unrollJamFactor > tripCount->getZExtValue()) {
     LDBG() << "unroll and jam factor is greater than trip count, set factor to "
               "trip "
               "count";
-    unrollJamFactor = *tripCount;
-  } else if (*tripCount % unrollJamFactor != 0) {
+    unrollJamFactor = tripCount->getZExtValue();
+  } else if (tripCount->getSExtValue() % unrollJamFactor != 0) {
     LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
               "multiple of unroll jam factor";
     return failure();
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 34385d76f133a..e5deea6fa21ab 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
 #include "llvm/Support/MathExtras.h"
 
 namespace mlir {
@@ -112,21 +113,30 @@ SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
 }
 
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
-std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
+/// The boolean indicates whether the value is an index type.
+std::optional<std::pair<APInt, bool>> getConstantAPIntValue(OpFoldResult ofr) {
   // Case 1: Check for Constant integer.
   if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
-    APSInt intVal;
+    APInt intVal;
     if (matchPattern(val, m_ConstantInt(&intVal)))
-      return intVal.getSExtValue();
+      return std::make_pair(intVal, val.getType().isIndex());
     return std::null...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/158679


More information about the Mlir-commits mailing list