[llvm] [SCEV] Use power of two facts involving vscale when inferring wrap flags (PR #101380)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 31 11:08:10 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Philip Reames (preames)

<details>
<summary>Changes</summary>

SCEV has logic for inferring wrap flags on AddRecs which are known to control an exit based on whether the step is a power of two.  This logic only considered constants, and thus did not trigger for steps such as (4 x vscale) which are common in scalably vectorized loops.

The net effect is that we were very sensitive to the preservation of nsw/nuw flags on such IVs, and could not infer trip counts if they got lost for any reason.

---
Full diff: https://github.com/llvm/llvm-project/pull/101380.diff


3 Files Affected:

- (modified) llvm/include/llvm/Analysis/ScalarEvolution.h (+3) 
- (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+72-58) 
- (modified) llvm/test/Analysis/ScalarEvolution/trip-count-scalable-stride.ll (+14-12) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index d9bfca763819f..fbefa2bd074dd 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1028,6 +1028,9 @@ class ScalarEvolution {
   /// Test if the given expression is known to be non-zero.
   bool isKnownNonZero(const SCEV *S);
 
+  /// Test if the given expression is known to be a power of 2.
+  bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero = false);
+
   /// Splits SCEV expression \p S into two SCEVs. One of them is obtained from
   /// \p S by substitution of all AddRec sub-expression related to loop \p L
   /// with initial value of that SCEV. The second is obtained from \p S by
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index fb56d5d436653..159aa6e93a6ad 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9156,16 +9156,14 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
     auto *InnerLHS = LHS;
     if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
       InnerLHS = ZExt->getOperand();
-    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS)) {
-      auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
-      if (!AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
-          StrideC && StrideC->getAPInt().isPowerOf2()) {
-        auto Flags = AR->getNoWrapFlags();
-        Flags = setFlags(Flags, SCEV::FlagNW);
-        SmallVector<const SCEV*> Operands{AR->operands()};
-        Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
-        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
-      }
+    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
+        AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
+        isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this))) {
+      auto Flags = AR->getNoWrapFlags();
+      Flags = setFlags(Flags, SCEV::FlagNW);
+      SmallVector<const SCEV*> Operands{AR->operands()};
+      Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
+      setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
     }
   }
 
@@ -10845,6 +10843,25 @@ bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
   return getUnsignedRangeMin(S) != 0;
 }
 
+bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero) {
+  auto nonRecursive = [this](const SCEV *S) {
+    if (auto *C = dyn_cast<SCEVConstant>(S))
+      return C->getAPInt().isPowerOf2();
+    // The vscale_range indicates vscale is a power-of-two.
+    return S->getSCEVType() == scVScale && F.hasFnAttribute(Attribute::VScaleRange);;
+  };
+
+  if (nonRecursive(S))
+    return true;
+
+  auto *Mul = dyn_cast<SCEVMulExpr>(S);
+  if (!Mul || Mul->getNumOperands() != 2)
+    return false;
+  return nonRecursive(Mul->getOperand(0)) && nonRecursive(Mul->getOperand(1)) &&
+    (OrZero || isKnownNonZero(S));
+}
+
+
 std::pair<const SCEV *, const SCEV *>
 ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
   // Compute SCEV on entry of loop L.
@@ -12775,8 +12792,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
     if (!isLoopInvariant(RHS, L))
       return false;
 
-    auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
-    if (!StrideC || !StrideC->getAPInt().isPowerOf2())
+    if (!isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this)))
       return false;
 
     if (!ControlsOnlyExit || !loopHasNoAbnormalExits(L))
@@ -13132,52 +13148,50 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       // "(Start - End) + (Stride - 1)" has unsigned overflow.
       const SCEV *One = getOne(Stride->getType());
       bool MayAddOverflow = [&] {
-        if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
-          if (StrideC->getAPInt().isPowerOf2()) {
-            // Suppose Stride is a power of two, and Start/End are unsigned
-            // integers.  Let UMAX be the largest representable unsigned
-            // integer.
-            //
-            // By the preconditions of this function, we know
-            // "(Start + Stride * N) >= End", and this doesn't overflow.
-            // As a formula:
-            //
-            //   End <= (Start + Stride * N) <= UMAX
-            //
-            // Subtracting Start from all the terms:
-            //
-            //   End - Start <= Stride * N <= UMAX - Start
-            //
-            // Since Start is unsigned, UMAX - Start <= UMAX.  Therefore:
-            //
-            //   End - Start <= Stride * N <= UMAX
-            //
-            // Stride * N is a multiple of Stride. Therefore,
-            //
-            //   End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
-            //
-            // Since Stride is a power of two, UMAX + 1 is divisible by
-            // Stride. Therefore, UMAX mod Stride == Stride - 1.  So we can
-            // write:
-            //
-            //   End - Start <= Stride * N <= UMAX - Stride - 1
-            //
-            // Dropping the middle term:
-            //
-            //   End - Start <= UMAX - Stride - 1
-            //
-            // Adding Stride - 1 to both sides:
-            //
-            //   (End - Start) + (Stride - 1) <= UMAX
-            //
-            // In other words, the addition doesn't have unsigned overflow.
-            //
-            // A similar proof works if we treat Start/End as signed values.
-            // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
-            // to use signed max instead of unsigned max. Note that we're
-            // trying to prove a lack of unsigned overflow in either case.
-            return false;
-          }
+        if (isKnownToBeAPowerOfTwo(Stride)) {
+          // Suppose Stride is a power of two, and Start/End are unsigned
+          // integers.  Let UMAX be the largest representable unsigned
+          // integer.
+          //
+          // By the preconditions of this function, we know
+          // "(Start + Stride * N) >= End", and this doesn't overflow.
+          // As a formula:
+          //
+          //   End <= (Start + Stride * N) <= UMAX
+          //
+          // Subtracting Start from all the terms:
+          //
+          //   End - Start <= Stride * N <= UMAX - Start
+          //
+          // Since Start is unsigned, UMAX - Start <= UMAX.  Therefore:
+          //
+          //   End - Start <= Stride * N <= UMAX
+          //
+          // Stride * N is a multiple of Stride. Therefore,
+          //
+          //   End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
+          //
+          // Since Stride is a power of two, UMAX + 1 is divisible by
+          // Stride. Therefore, UMAX mod Stride == Stride - 1.  So we can
+          // write:
+          //
+          //   End - Start <= Stride * N <= UMAX - Stride - 1
+          //
+          // Dropping the middle term:
+          //
+          //   End - Start <= UMAX - Stride - 1
+          //
+          // Adding Stride - 1 to both sides:
+          //
+          //   (End - Start) + (Stride - 1) <= UMAX
+          //
+          // In other words, the addition doesn't have unsigned overflow.
+          //
+          // A similar proof works if we treat Start/End as signed values.
+          // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
+          // to use signed max instead of unsigned max. Note that we're
+          // trying to prove a lack of unsigned overflow in either case.
+          return false;
         }
         if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
           // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-count-scalable-stride.ll b/llvm/test/Analysis/ScalarEvolution/trip-count-scalable-stride.ll
index 943389d07eb8b..50e6014734f31 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-count-scalable-stride.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-count-scalable-stride.ll
@@ -374,15 +374,16 @@ define void @vscale_slt_noflags(ptr nocapture %A, i32 %n) mustprogress vscale_ra
 ; CHECK-NEXT:    %vscale = call i32 @llvm.vscale.i32()
 ; CHECK-NEXT:    --> vscale U: [2,1025) S: [2,1025)
 ; CHECK-NEXT:    %i.05 = phi i32 [ %add, %for.body ], [ 0, %entry ]
-; CHECK-NEXT:    --> {0,+,vscale}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    --> {0,+,vscale}<%for.body> U: full-set S: full-set Exits: (vscale * ((-1 + %n) /u vscale))<nuw> LoopDispositions: { %for.body: Computable }
 ; CHECK-NEXT:    %arrayidx = getelementptr inbounds i32, ptr %A, i32 %i.05
-; CHECK-NEXT:    --> {%A,+,(4 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    --> {%A,+,(4 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: ((4 * vscale * ((-1 + %n) /u vscale)) + %A) LoopDispositions: { %for.body: Computable }
 ; CHECK-NEXT:    %add = add i32 %i.05, %vscale
-; CHECK-NEXT:    --> {vscale,+,vscale}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    --> {vscale,+,vscale}<nw><%for.body> U: full-set S: full-set Exits: (vscale * (1 + ((-1 + %n) /u vscale))<nuw>) LoopDispositions: { %for.body: Computable }
 ; CHECK-NEXT:  Determining loop execution counts for: @vscale_slt_noflags
-; CHECK-NEXT:  Loop %for.body: Unpredictable backedge-taken count.
-; CHECK-NEXT:  Loop %for.body: Unpredictable constant max backedge-taken count.
-; CHECK-NEXT:  Loop %for.body: Unpredictable symbolic max backedge-taken count.
+; CHECK-NEXT:  Loop %for.body: backedge-taken count is ((-1 + %n) /u vscale)
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is i32 1073741822
+; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is ((-1 + %n) /u vscale)
+; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
 ;
 entry:
   %vscale = call i32 @llvm.vscale.i32()
@@ -411,15 +412,16 @@ define void @vscalex4_ult_noflags(ptr nocapture %A, i32 %n) mustprogress vscale_
 ; CHECK-NEXT:    %VF = mul i32 %vscale, 4
 ; CHECK-NEXT:    --> (4 * vscale)<nuw><nsw> U: [8,4097) S: [8,4097)
 ; CHECK-NEXT:    %i.05 = phi i32 [ %add, %for.body ], [ 0, %entry ]
-; CHECK-NEXT:    --> {0,+,(4 * vscale)<nuw><nsw>}<%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    --> {0,+,(4 * vscale)<nuw><nsw>}<%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: (4 * vscale * ((-1 + %n) /u (4 * vscale)<nuw><nsw>)) LoopDispositions: { %for.body: Computable }
 ; CHECK-NEXT:    %arrayidx = getelementptr inbounds i32, ptr %A, i32 %i.05
-; CHECK-NEXT:    --> {%A,+,(16 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    --> {%A,+,(16 * vscale)<nuw><nsw>}<%for.body> U: full-set S: full-set Exits: ((16 * vscale * ((-1 + %n) /u (4 * vscale)<nuw><nsw>)) + %A) LoopDispositions: { %for.body: Computable }
 ; CHECK-NEXT:    %add = add i32 %i.05, %VF
-; CHECK-NEXT:    --> {(4 * vscale)<nuw><nsw>,+,(4 * vscale)<nuw><nsw>}<%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: <<Unknown>> LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    --> {(4 * vscale)<nuw><nsw>,+,(4 * vscale)<nuw><nsw>}<nw><%for.body> U: [0,-3) S: [-2147483648,2147483645) Exits: (vscale * (4 + (4 * ((-1 + %n) /u (4 * vscale)<nuw><nsw>))<nuw><nsw>)<nuw>) LoopDispositions: { %for.body: Computable }
 ; CHECK-NEXT:  Determining loop execution counts for: @vscalex4_ult_noflags
-; CHECK-NEXT:  Loop %for.body: Unpredictable backedge-taken count.
-; CHECK-NEXT:  Loop %for.body: Unpredictable constant max backedge-taken count.
-; CHECK-NEXT:  Loop %for.body: Unpredictable symbolic max backedge-taken count.
+; CHECK-NEXT:  Loop %for.body: backedge-taken count is ((-1 + %n) /u (4 * vscale)<nuw><nsw>)
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is i32 536870910
+; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is ((-1 + %n) /u (4 * vscale)<nuw><nsw>)
+; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
 ;
 entry:
   %vscale = call i32 @llvm.vscale.i32()

``````````

</details>


https://github.com/llvm/llvm-project/pull/101380


More information about the llvm-commits mailing list