[Mlir-commits] [mlir] [MLIR] Add a getStaticTripCount method to LoopLikeOpInterface (PR #158679)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 15 09:33:44 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf
Author: Mehdi Amini (joker-eph)
<details>
<summary>Changes</summary>
This patch adds a `getStaticTripCount` method to the LoopLikeOpInterface, allowing loops to optionally return a static trip count when possible. It is implemented on the SCF ForOp by revamping the implementation of `constantTripCount`, and it removes redundant duplicate implementations from SCF.cpp.
---
Patch is 47.73 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158679.diff
10 Files Affected:
- (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+1-1)
- (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+22-3)
- (modified) mlir/include/mlir/Interfaces/LoopLikeInterface.td (+11)
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+61-53)
- (modified) mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp (+9-1)
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+13-19)
- (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+95-16)
- (modified) mlir/test/Dialect/SCF/for-loop-peeling.mlir (+19-17)
- (added) mlir/test/Dialect/SCF/trip_count.mlir (+589)
- (modified) mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp (+14)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index d3c01c31636a7..fadd3fc10bfc4 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -152,7 +152,7 @@ def ForOp : SCF_Op<"for",
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
"getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
- "getLoopUpperBounds", "getYieldedValuesMutable",
+ "getLoopUpperBounds", "getStaticTripCount", "getYieldedValuesMutable",
"promoteIfSingleIteration", "replaceWithAdditionalYields",
"yieldTiledValuesAndReplace"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 77c376fb9973a..aa0cc35a1d675 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -105,6 +105,10 @@ OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val);
SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
ArrayRef<int64_t> values);
+/// If ofr is a constant integer or an IntegerAttr, return the integer.
+/// The second return value indicates whether the value is an index type
+/// and thus the bitwidth is not defined.
+std::optional<std::pair<APInt, bool>> getConstantAPIntValue(OpFoldResult ofr);
/// If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional<int64_t> getConstantIntValue(OpFoldResult ofr);
/// If all ofrs are constant integers or IntegerAttrs, return the integers.
@@ -201,9 +205,24 @@ foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);
LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
/// Return the number of iterations for a loop with a lower bound `lb`, upper
-/// bound `ub` and step `step`.
-std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
- OpFoldResult step);
+/// bound `ub` and step `step`. The `isSigned` flag indicates whether the loop
+/// comparison between lb and ub is signed or unsigned. A negative step or a
+/// lower bound greater than the upper bound are considered invalid and will
+/// yield a zero trip count.
+/// The `computeUbMinusLb` callback is invoked to compute the difference between
+/// the upper and lower bound when not constant. It can be used by the client
+/// to match the pattern:
+///
+/// %ub = arith.addi nsw %lb, %c16_i32 : i32
+/// %1 = scf.for %arg0 = %lb to %ub ...
+///
+/// where %ub is computed as a static offset from %lb.
+/// Note: the matched addition should be nsw/nuw (matching the loop comparison)
+/// to avoid overflow, otherwise an overflow would imply a zero trip count.
+std::optional<APInt> constantTripCount(
+ OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned,
+ llvm::function_ref<std::optional<llvm::APSInt>(Value, Value, bool)>
+ computeUbMinusLb);
/// Idiomatic saturated operations on values like offsets, sizes, and strides.
struct SaturatedInteger {
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 6c95b4802837b..cfd15a7746e19 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -232,6 +232,17 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/*defaultImplementation=*/[{
return ::mlir::failure();
}]
+ >,
+ InterfaceMethod<[{
+ Compute the static trip count if possible.
+ }],
+ /*retTy=*/"::std::optional<APInt>",
+ /*methodName=*/"getStaticTripCount",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::std::nullopt;
+ }]
>
];
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c35989ecba6cd..e68dc04de231b 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -19,6 +19,8 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
@@ -26,6 +28,9 @@
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/DebugLog.h"
+#include <optional>
using namespace mlir;
using namespace mlir::scf;
@@ -105,6 +110,23 @@ static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
return nullptr;
}
+/// Helper function to compute the difference between two values. This is used
+/// by the loop implementations to compute the trip count.
+static std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub,
+ bool isSigned) {
+ llvm::APSInt diff;
+ auto addOp = ub.getDefiningOp<arith::AddIOp>();
+ if (!addOp)
+ return std::nullopt;
+ if ((isSigned && !addOp.hasNoSignedWrap()) ||
+ (!isSigned && !addOp.hasNoUnsignedWrap()))
+ return std::nullopt;
+
+ if (!matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
+ return std::nullopt;
+ return diff;
+}
+
//===----------------------------------------------------------------------===//
// ExecuteRegionOp
//===----------------------------------------------------------------------===//
@@ -408,11 +430,19 @@ std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
/// Promotes the loop body of a forOp to its containing block if the forOp
/// it can be determined that the loop has a single iteration.
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
- std::optional<int64_t> tripCount =
- constantTripCount(getLowerBound(), getUpperBound(), getStep());
- if (!tripCount.has_value() || tripCount != 1)
+ std::optional<APInt> tripCount = getStaticTripCount();
+ LDBG() << "promoteIfSingleIteration tripCount is " << tripCount
+ << " for loop "
+ << OpWithFlags(getOperation(), OpPrintingFlags().skipRegions());
+ if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
return failure();
+ if (*tripCount == 0) {
+ rewriter.replaceAllUsesWith(getResults(), getInitArgs());
+ rewriter.eraseOp(*this);
+ return success();
+ }
+
// Replace all results with the yielded values.
auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
@@ -646,7 +676,8 @@ SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
for (auto [lb, ub, step] :
llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
- auto tripCount = constantTripCount(lb, ub, step);
+ auto tripCount =
+ constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
if (!tripCount.has_value() || *tripCount != 1)
return failure();
}
@@ -1003,27 +1034,6 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
}
};
-/// Util function that tries to compute a constant diff between u and l.
-/// Returns std::nullopt when the difference between two AffineValueMap is
-/// dynamic.
-static std::optional<APInt> computeConstDiff(Value l, Value u) {
- IntegerAttr clb, cub;
- if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
- llvm::APInt lbValue = clb.getValue();
- llvm::APInt ubValue = cub.getValue();
- return ubValue - lbValue;
- }
-
- // Else a simple pattern match for x + c or c + x
- llvm::APInt diff;
- if (matchPattern(
- u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
- matchPattern(
- u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
- return diff;
- return std::nullopt;
-}
-
/// Rewriting pattern that erases loops that are known not to iterate, replaces
/// single-iteration loops with their bodies, and removes empty loops that
/// iterate at least once and only return values defined outside of the loop.
@@ -1032,34 +1042,21 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
LogicalResult matchAndRewrite(ForOp op,
PatternRewriter &rewriter) const override {
- // If the upper bound is the same as the lower bound, the loop does not
- // iterate, just remove it.
- if (op.getLowerBound() == op.getUpperBound()) {
+ std::optional<APInt> tripCount = op.getStaticTripCount();
+ if (!tripCount.has_value())
+ return rewriter.notifyMatchFailure(op,
+ "can't compute constant trip count");
+
+ if (tripCount->isZero()) {
+ LDBG() << "SimplifyTrivialLoops tripCount is 0 for loop "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
rewriter.replaceOp(op, op.getInitArgs());
return success();
}
- std::optional<APInt> diff =
- computeConstDiff(op.getLowerBound(), op.getUpperBound());
- if (!diff)
- return failure();
-
- // If the loop is known to have 0 iterations, remove it.
- bool zeroOrLessIterations =
- diff->isZero() || (!op.getUnsignedCmp() && diff->isNegative());
- if (zeroOrLessIterations) {
- rewriter.replaceOp(op, op.getInitArgs());
- return success();
- }
-
- std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
- if (!maybeStepValue)
- return failure();
-
- // If the loop is known to have 1 iteration, inline its body and remove the
- // loop.
- llvm::APInt stepValue = *maybeStepValue;
- if (stepValue.sge(*diff)) {
+ if (tripCount->getSExtValue() == 1) {
+ LDBG() << "SimplifyTrivialLoops tripCount is 1 for loop "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
SmallVector<Value, 4> blockArgs;
blockArgs.reserve(op.getInitArgs().size() + 1);
blockArgs.push_back(op.getLowerBound());
@@ -1072,11 +1069,14 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
Block &block = op.getRegion().front();
if (!llvm::hasSingleElement(block))
return failure();
- // If the loop is empty, iterates at least once, and only returns values
+ // The loop is empty and iterates at least once, if it only returns values
// defined outside of the loop, remove it and replace it with yield values.
if (llvm::any_of(op.getYieldedValues(),
[&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
return failure();
+ LDBG() << "SimplifyTrivialLoops empty body loop allows replacement with "
+ "yield operands for loop "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
rewriter.replaceOp(op, op.getYieldedValues());
return success();
}
@@ -1172,6 +1172,11 @@ Speculation::Speculatability ForOp::getSpeculatability() {
return Speculation::NotSpeculatable;
}
+std::optional<APInt> ForOp::getStaticTripCount() {
+ return constantTripCount(getLowerBound(), getUpperBound(), getStep(),
+ /*isSigned=*/!getUnsignedCmp(), computeUbMinusLb);
+}
+
//===----------------------------------------------------------------------===//
// ForallOp
//===----------------------------------------------------------------------===//
@@ -1768,7 +1773,8 @@ struct ForallOpSingleOrZeroIterationDimsFolder
for (auto [lb, ub, step, iv] :
llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
op.getMixedStep(), op.getInductionVars())) {
- auto numIterations = constantTripCount(lb, ub, step);
+ auto numIterations =
+ constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
if (numIterations.has_value()) {
// Remove the loop if it performs zero iterations.
if (*numIterations == 0) {
@@ -1839,7 +1845,8 @@ struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
op.getMixedStep(), op.getInductionVars())) {
if (iv.hasNUses(0))
continue;
- auto numIterations = constantTripCount(lb, ub, step);
+ auto numIterations =
+ constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
if (!numIterations.has_value() || numIterations.value() != 1) {
continue;
}
@@ -3084,7 +3091,8 @@ struct ParallelOpSingleOrZeroIterationDimsFolder
for (auto [lb, ub, step, iv] :
llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
op.getInductionVars())) {
- auto numIterations = constantTripCount(lb, ub, step);
+ auto numIterations =
+ constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
if (numIterations.has_value()) {
// Remove the loop if it performs zero iterations.
if (*numIterations == 0) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index f1203b2bdfee5..e3717aa9d940e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -94,7 +94,9 @@ static void specializeForLoopForUnrolling(ForOp op) {
OpBuilder b(op);
IRMapping map;
- Value constant = arith::ConstantIndexOp::create(b, op.getLoc(), minConstant);
+ Value constant = arith::ConstantOp::create(
+ b, op.getLoc(),
+ IntegerAttr::get(op.getUpperBound().getType(), minConstant));
Value cond = arith::CmpIOp::create(b, op.getLoc(), arith::CmpIPredicate::eq,
bound, constant);
map.map(bound, constant);
@@ -150,6 +152,9 @@ static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
ValueRange{forOp.getLowerBound(),
forOp.getUpperBound(),
forOp.getStep()});
+ if (splitBound.getType() != forOp.getLowerBound().getType())
+ splitBound = b.createOrFold<arith::IndexCastOp>(
+ loc, forOp.getLowerBound().getType(), splitBound);
// Create ForOp for partial iteration.
b.setInsertionPointAfter(forOp);
@@ -230,6 +235,9 @@ LogicalResult mlir::scf::peelForLoopFirstIteration(RewriterBase &b, ForOp forOp,
auto loc = forOp.getLoc();
Value splitBound = b.createOrFold<AffineApplyOp>(
loc, ubMap, ValueRange{forOp.getLowerBound(), forOp.getStep()});
+ if (splitBound.getType() != forOp.getUpperBound().getType())
+ splitBound = b.createOrFold<arith::IndexCastOp>(
+ loc, forOp.getUpperBound().getType(), splitBound);
// Peel the first iteration.
IRMapping map;
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 684dff8121de6..57c4deab89321 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/DebugLog.h"
@@ -291,14 +292,6 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
return arith::DivUIOp::create(builder, loc, sum, divisor);
}
-/// Returns the trip count of `forOp` if its' low bound, high bound and step are
-/// constants, or optional otherwise. Trip count is computed as
-/// ceilDiv(highBound - lowBound, step).
-static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
- return constantTripCount(forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep());
-}
-
/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
@@ -377,7 +370,7 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
Value stepUnrolled;
bool generateEpilogueLoop = true;
- std::optional<int64_t> constTripCount = getConstantTripCount(forOp);
+ std::optional<APInt> constTripCount = forOp.getStaticTripCount();
if (constTripCount) {
// Constant loop bounds computation.
int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
@@ -391,7 +384,8 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
}
int64_t tripCountEvenMultiple =
- *constTripCount - (*constTripCount % unrollFactor);
+ constTripCount->getSExtValue() -
+ (constTripCount->getSExtValue() % unrollFactor);
int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
int64_t stepUnrolledCst = stepCst * unrollFactor;
@@ -487,15 +481,15 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
/// Unrolls this loop completely.
LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
IRRewriter rewriter(forOp.getContext());
- std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+ std::optional<APInt> mayBeConstantTripCount = forOp.getStaticTripCount();
if (!mayBeConstantTripCount.has_value())
return failure();
- uint64_t tripCount = *mayBeConstantTripCount;
- if (tripCount == 0)
+ APInt &tripCount = *mayBeConstantTripCount;
+ if (tripCount.isZero())
return success();
- if (tripCount == 1)
+ if (tripCount.getSExtValue() == 1)
return forOp.promoteIfSingleIteration(rewriter);
- return loopUnrollByFactor(forOp, tripCount);
+ return loopUnrollByFactor(forOp, tripCount.getSExtValue());
}
/// Check if bounds of all inner loops are defined outside of `forOp`
@@ -535,18 +529,18 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
// Currently, only constant trip count that divided by the unroll factor is
// supported.
- std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+ std::optional<APInt> tripCount = forOp.getStaticTripCount();
if (!tripCount.has_value()) {
// If the trip count is dynamic, do not unroll & jam.
LDBG() << "failed to unroll and jam: trip count could not be determined";
return failure();
}
- if (unrollJamFactor > *tripCount) {
+ if (unrollJamFactor > tripCount->getZExtValue()) {
LDBG() << "unroll and jam factor is greater than trip count, set factor to "
"trip "
"count";
- unrollJamFactor = *tripCount;
- } else if (*tripCount % unrollJamFactor != 0) {
+ unrollJamFactor = tripCount->getZExtValue();
+ } else if (tripCount->getSExtValue() % unrollJamFactor != 0) {
LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
"multiple of unroll jam factor";
return failure();
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 34385d76f133a..e5deea6fa21ab 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -11,6 +11,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"
namespace mlir {
@@ -112,21 +113,30 @@ SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
}
/// If ofr is a constant integer or an IntegerAttr, return the integer.
-std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
+/// The boolean indicates whether the value is an index type.
+std::optional<std::pair<APInt, bool>> getConstantAPIntValue(OpFoldResult ofr) {
// Case 1: Check for Constant integer.
if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
- APSInt intVal;
+ APInt intVal;
if (matchPattern(val, m_ConstantInt(&intVal)))
- return intVal.getSExtValue();
+ return std::make_pair(intVal, val.getType().isIndex());
return std::null...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/158679
More information about the Mlir-commits
mailing list