[llvm] [SCEV] Apply loop guards to End computeMaxBECountForLT (PR #116187)

via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 14 01:22:18 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Luke Lau (lukel97)

<details>
<summary>Changes</summary>

This is a follow-on to #<!-- -->115705. Applying the loop guards to `End` allows us to calculate the maximum trip count in more places, which in turn allows `isIndvarOverflowCheckKnownFalse` to skip the overflow check.


---
Full diff: https://github.com/llvm/llvm-project/pull/116187.diff


3 Files Affected:

- (modified) llvm/include/llvm/Analysis/ScalarEvolution.h (+2-2) 
- (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+10-9) 
- (modified) llvm/test/Transforms/LoopVectorize/RISCV/vectorize-force-tail-with-evl-known-no-overflow.ll (+2-7) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 4b8cb3a39a86db..8b7745e6ab1034 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -2218,8 +2218,8 @@ class ScalarEvolution {
   ///   actually doesn't, or we'd have to immediately execute UB)
   /// We *don't* assert these preconditions so please be careful.
   const SCEV *computeMaxBECountForLT(const SCEV *Start, const SCEV *Stride,
-                                     const SCEV *End, unsigned BitWidth,
-                                     bool IsSigned);
+                                     const SCEV *End, const Loop *L,
+                                     unsigned BitWidth, bool IsSigned);
 
   /// Verify if an linear IV with positive stride can overflow when in a
   /// less-than comparison, knowing the invariant term of the comparison,
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index b10811133770e1..bb7306f5ba3778 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -12857,11 +12857,10 @@ const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
   return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
 }
 
-const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
-                                                    const SCEV *Stride,
-                                                    const SCEV *End,
-                                                    unsigned BitWidth,
-                                                    bool IsSigned) {
+const SCEV *
+ScalarEvolution::computeMaxBECountForLT(const SCEV *Start, const SCEV *Stride,
+                                        const SCEV *End, const Loop *L,
+                                        unsigned BitWidth, bool IsSigned) {
   // The logic in this function assumes we can represent a positive stride.
   // If we can't, the backedge-taken count must be zero.
   if (IsSigned && BitWidth == 1)
@@ -12895,8 +12894,10 @@ const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
   // the case End = RHS of the loop termination condition. This is safe because
   // in the other case (End - Start) is zero, leading to a zero maximum backedge
   // taken count.
-  APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
-                          : APIntOps::umin(getUnsignedRangeMax(End), Limit);
+  const SCEV *GuardedEnd = applyLoopGuards(End, L);
+  APInt MaxEnd = IsSigned
+                     ? APIntOps::smin(getSignedRangeMax(GuardedEnd), Limit)
+                     : APIntOps::umin(getUnsignedRangeMax(GuardedEnd), Limit);
 
   // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
   MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
@@ -13150,7 +13151,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       // loop (RHS), and the fact that IV does not overflow (which is
       // checked above).
       const SCEV *MaxBECount = computeMaxBECountForLT(
-          Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
+          Start, Stride, RHS, L, getTypeSizeInBits(LHS->getType()), IsSigned);
       return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
                        MaxBECount, false /*MaxOrZero*/, Predicates);
     }
@@ -13334,7 +13335,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
     MaxOrZero = true;
   } else {
     ConstantMaxBECount = computeMaxBECountForLT(
-        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
+        Start, Stride, RHS, L, getTypeSizeInBits(LHS->getType()), IsSigned);
   }
 
   if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/vectorize-force-tail-with-evl-known-no-overflow.ll b/llvm/test/Transforms/LoopVectorize/RISCV/vectorize-force-tail-with-evl-known-no-overflow.ll
index a8cf002182e240..735421a4f65ee8 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/vectorize-force-tail-with-evl-known-no-overflow.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/vectorize-force-tail-with-evl-known-no-overflow.ll
@@ -4,8 +4,7 @@
 ; RUN: -prefer-predicate-over-epilogue=predicate-dont-vectorize \
 ; RUN: -mtriple=riscv64 -mattr=+v -S < %s | FileCheck %s
 
-; TODO: We know the IV will never overflow here so we can skip the overflow
-; check
+; We know the IV will never overflow here so we can skip the overflow check
 
 define void @trip_count_max_1024(ptr %p, i64 %tc) vscale_range(2, 1024) {
 ; CHECK-LABEL: define void @trip_count_max_1024(
@@ -15,11 +14,7 @@ define void @trip_count_max_1024(ptr %p, i64 %tc) vscale_range(2, 1024) {
 ; CHECK-NEXT:    br i1 [[GUARD]], label %[[EXIT:.*]], label %[[LOOP_PREHEADER:.*]]
 ; CHECK:       [[LOOP_PREHEADER]]:
 ; CHECK-NEXT:    [[UMAX:%.*]] = call i64 @llvm.umax.i64(i64 [[TC]], i64 1)
-; CHECK-NEXT:    [[TMP0:%.*]] = sub i64 -1, [[UMAX]]
-; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 [[TMP1]], 2
-; CHECK-NEXT:    [[TMP3:%.*]] = icmp ult i64 [[TMP0]], [[TMP2]]
-; CHECK-NEXT:    br i1 [[TMP3]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
+; CHECK-NEXT:    br i1 false, label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
 ; CHECK:       [[VECTOR_PH]]:
 ; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 2

``````````

</details>


https://github.com/llvm/llvm-project/pull/116187


More information about the llvm-commits mailing list