[llvm] [SCEV] Use power of two facts involving vscale when inferring wrap flags (PR #101380)
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 31 11:07:37 PDT 2024
https://github.com/preames created https://github.com/llvm/llvm-project/pull/101380
SCEV has logic for inferring wrap flags on AddRecs which are known to control an exit based on whether the step is a power of two. This logic only considered constants, and thus did not trigger for steps such as (4 x vscale) which are common in scalably vectorized loops.
The net effect is that we were very sensitive to the preservation of nsw/nuw flags on such IVs, and could not infer trip counts if they got lost for any reason.
>From bf7b6c6cb5c7eae5d051ca7006c0ab63a0bc289f Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Tue, 30 Jul 2024 15:06:56 -0700
Subject: [PATCH] [SCEV] Use power of two facts involving vscale when inferring
wrap flags
SCEV has logic for inferring wrap flags on AddRecs which are known to
control an exit based on whether the step is a power of two. This logic
only considered constants, and thus did not trigger for steps such
as (4 x vscale) which are common in scalably vectorized loops.
The net effect is that we were very sensitive to the preservation of
nsw/nuw flags on such IVs, and could not infer trip counts if they
got lost for any reason.
---
llvm/include/llvm/Analysis/ScalarEvolution.h | 3 +
llvm/lib/Analysis/ScalarEvolution.cpp | 130 ++++++++++--------
.../trip-count-scalable-stride.ll | 26 ++--
3 files changed, 89 insertions(+), 70 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index d9bfca763819f..fbefa2bd074dd 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1028,6 +1028,9 @@ class ScalarEvolution {
/// Test if the given expression is known to be non-zero.
bool isKnownNonZero(const SCEV *S);
+ /// Test if the given expression is known to be a power of 2.
+ bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero = false);
+
/// Splits SCEV expression \p S into two SCEVs. One of them is obtained from
/// \p S by substitution of all AddRec sub-expression related to loop \p L
/// with initial value of that SCEV. The second is obtained from \p S by
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index fb56d5d436653..159aa6e93a6ad 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9156,16 +9156,14 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
auto *InnerLHS = LHS;
if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
InnerLHS = ZExt->getOperand();
- if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS)) {
- auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
- if (!AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
- StrideC && StrideC->getAPInt().isPowerOf2()) {
- auto Flags = AR->getNoWrapFlags();
- Flags = setFlags(Flags, SCEV::FlagNW);
- SmallVector<const SCEV*> Operands{AR->operands()};
- Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
- setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
- }
+ if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
+ AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
+ isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this))) {
+ auto Flags = AR->getNoWrapFlags();
+ Flags = setFlags(Flags, SCEV::FlagNW);
+ SmallVector<const SCEV*> Operands{AR->operands()};
+ Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
+ setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
}
}
@@ -10845,6 +10843,25 @@ bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
return getUnsignedRangeMin(S) != 0;
}
+bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero) {
+ auto nonRecursive = [this](const SCEV *S) {
+ if (auto *C = dyn_cast<SCEVConstant>(S))
+ return C->getAPInt().isPowerOf2();
+ // The vscale_range indicates vscale is a power-of-two.
+ return S->getSCEVType() == scVScale && F.hasFnAttribute(Attribute::VScaleRange);;
+ };
+
+ if (nonRecursive(S))
+ return true;
+
+ auto *Mul = dyn_cast<SCEVMulExpr>(S);
+ if (!Mul || Mul->getNumOperands() != 2)
+ return false;
+ return nonRecursive(Mul->getOperand(0)) && nonRecursive(Mul->getOperand(1)) &&
+ (OrZero || isKnownNonZero(S));
+}
+
+
std::pair<const SCEV *, const SCEV *>
ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
// Compute SCEV on entry of loop L.
@@ -12775,8 +12792,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
if (!isLoopInvariant(RHS, L))
return false;
- auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
- if (!StrideC || !StrideC->getAPInt().isPowerOf2())
+ if (!isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this)))
return false;
if (!ControlsOnlyExit || !loopHasNoAbnormalExits(L))
@@ -13132,52 +13148,50 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
// "(Start - End) + (Stride - 1)" has unsigned overflow.
const SCEV *One = getOne(Stride->getType());
bool MayAddOverflow = [&] {
- if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
- if (StrideC->getAPInt().isPowerOf2()) {
- // Suppose Stride is a power of two, and Start/End are unsigned
- // integers. Let UMAX be the largest representable unsigned
- // integer.
- //
- // By the preconditions of this function, we know
- // "(Start + Stride * N) >= End", and this doesn't overflow.
- // As a formula:
- //
- // End <= (Start + Stride * N) <= UMAX
- //
- // Subtracting Start from all the terms:
- //
- // End - Start <= Stride * N <= UMAX - Start
- //
- // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
- //
- // End - Start <= Stride * N <= UMAX
- //
- // Stride * N is a multiple of Stride. Therefore,
- //
- // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
- //
- // Since Stride is a power of two, UMAX + 1 is divisible by
- // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
- // write:
- //
- // End - Start <= Stride * N <= UMAX - Stride - 1
- //
- // Dropping the middle term:
- //
- // End - Start <= UMAX - Stride - 1
- //
- // Adding Stride - 1 to both sides:
- //
- // (End - Start) + (Stride - 1) <= UMAX
- //
- // In other words, the addition doesn't have unsigned overflow.
- //
- // A similar proof works if we treat Start/End as signed values.
- // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
- // to use signed max instead of unsigned max. Note that we're
- // trying to prove a lack of unsigned overflow in either case.
- return false;
- }
+ if (isKnownToBeAPowerOfTwo(Stride)) {
+ // Suppose Stride is a power of two, and Start/End are unsigned
+ // integers. Let UMAX be the largest representable unsigned
+ // integer.
+ //
+ // By the preconditions of this function, we know
+ // "(Start + Stride * N) >= End", and this doesn't overflow.
+ // As a formula:
+ //
+ // End <= (Start + Stride * N) <= UMAX
+ //
+ // Subtracting Start from all the terms:
+ //
+ // End - Start <= Stride * N <= UMAX - Start
+ //
+ // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
+ //
+ // End - Start <= Stride * N <= UMAX
+ //
+ // Stride * N is a multiple of Stride. Therefore,
+ //
+ // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
+ //
+ // Since Stride is a power of two, UMAX + 1 is divisible by
+ // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
+ // write:
+ //
+ // End - Start <= Stride * N <= UMAX - Stride - 1
+ //
+ // Dropping the middle term:
+ //
+ // End - Start <= UMAX - Stride - 1
+ //
+ // Adding Stride - 1 to both sides:
+ //
+ // (End - Start) + (Stride - 1) <= UMAX
+ //
+ // In other words, the addition doesn't have unsigned overflow.
+ //
+ // A similar proof works if we treat Start/End as signed values.
+ // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
+ // to use signed max instead of unsigned max. Note that we're
+ // trying to prove a lack of unsigned overflow in either case.
+ return false;
}
if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
// If Start is equal to Stride, (End - Start) + (Stride - 1) == End
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-count-scalable-stride.ll b/llvm/test/Analysis/ScalarEvolution/trip-count-scalable-stride.ll
index 943389d07eb8b..50e6014734f31 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-count-scalable-stride.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-count-scalable-stride.ll
@@ -374,15 +374,16 @@ define void @vscale_slt_noflags(ptr nocapture %A, i32 %n) mustprogress vscale_ra
; CHECK-NEXT: %vscale = call i32 @llvm.vscale.i32()
; CHECK-NEXT: --> vscale U: [2,1025) S: [2,1025)
; CHECK-NEXT: %i.05 = phi i32 [ %add, %for.body ], [ 0, %entry ]
-; CHECK-NEXT: --> {0,+,vscale}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT: --> {0,+,vscale}<%for.body> U: full-set S: full-set Exits: (vscale * ((-1 + %n) /u vscale))<nuw> LoopDispositions: { %for.body: Computable }
; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %A, i32 %i.05
-; CHECK-NEXT: --> {%A,+,(4 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT: --> {%A,+,(4 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: ((4 * vscale * ((-1 + %n) /u vscale)) + %A) LoopDispositions: { %for.body: Computable }
; CHECK-NEXT: %add = add i32 %i.05, %vscale
-; CHECK-NEXT: --> {vscale,+,vscale}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT: --> {vscale,+,vscale}<nw><%for.body> U: full-set S: full-set Exits: (vscale * (1 + ((-1 + %n) /u vscale))<nuw>) LoopDispositions: { %for.body: Computable }
; CHECK-NEXT: Determining loop execution counts for: @vscale_slt_noflags
-; CHECK-NEXT: Loop %for.body: Unpredictable backedge-taken count.
-; CHECK-NEXT: Loop %for.body: Unpredictable constant max backedge-taken count.
-; CHECK-NEXT: Loop %for.body: Unpredictable symbolic max backedge-taken count.
+; CHECK-NEXT: Loop %for.body: backedge-taken count is ((-1 + %n) /u vscale)
+; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is i32 1073741822
+; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is ((-1 + %n) /u vscale)
+; CHECK-NEXT: Loop %for.body: Trip multiple is 1
;
entry:
%vscale = call i32 @llvm.vscale.i32()
@@ -411,15 +412,16 @@ define void @vscalex4_ult_noflags(ptr nocapture %A, i32 %n) mustprogress vscale_
; CHECK-NEXT: %VF = mul i32 %vscale, 4
; CHECK-NEXT: --> (4 * vscale)<nuw><nsw> U: [8,4097) S: [8,4097)
; CHECK-NEXT: %i.05 = phi i32 [ %add, %for.body ], [ 0, %entry ]
-; CHECK-NEXT: --> {0,+,(4 * vscale)<nuw><nsw>}<%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT: --> {0,+,(4 * vscale)<nuw><nsw>}<%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: (4 * vscale * ((-1 + %n) /u (4 * vscale)<nuw><nsw>)) LoopDispositions: { %for.body: Computable }
; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %A, i32 %i.05
-; CHECK-NEXT: --> {%A,+,(16 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT: --> {%A,+,(16 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: ((16 * vscale * ((-1 + %n) /u (4 * vscale)<nuw><nsw>)) + %A) LoopDispositions: { %for.body: Computable }
; CHECK-NEXT: %add = add i32 %i.05, %VF
-; CHECK-NEXT: --> {(4 * vscale)<nuw><nsw>,+,(4 * vscale)<nuw><nsw>}<%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT: --> {(4 * vscale)<nuw><nsw>,+,(4 * vscale)<nuw><nsw>}<nw><%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: (vscale * (4 + (4 * ((-1 + %n) /u (4 * vscale)<nuw><nsw>))<nuw><nsw>)<nuw>) LoopDispositions: { %for.body: Computable }
; CHECK-NEXT: Determining loop execution counts for: @vscalex4_ult_noflags
-; CHECK-NEXT: Loop %for.body: Unpredictable backedge-taken count.
-; CHECK-NEXT: Loop %for.body: Unpredictable constant max backedge-taken count.
-; CHECK-NEXT: Loop %for.body: Unpredictable symbolic max backedge-taken count.
+; CHECK-NEXT: Loop %for.body: backedge-taken count is ((-1 + %n) /u (4 * vscale)<nuw><nsw>)
+; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is i32 536870910
+; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is ((-1 + %n) /u (4 * vscale)<nuw><nsw>)
+; CHECK-NEXT: Loop %for.body: Trip multiple is 1
;
entry:
%vscale = call i32 @llvm.vscale.i32()
More information about the llvm-commits
mailing list