[llvm] LAA: make stride versioning code more robust (PR #97075)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 28 08:57:22 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-ir

Author: Ramkumar Ramachandra (artagnon)

<details>
<summary>Changes</summary>

The current stride versioning code in collectStridedAccess is quite fragile and has implicit effects. Make it more robust by making it explicit that Stride - 1 == BTC is a special case, and by operating on ConstantRanges directly. Query the exact backedge-taken count instead of its symbolic maximum. As a side effect, getStrideFromPointer can now directly return the SCEVUnknown under a cast, which eliminates a second round of cast-stripping in collectStridedAccess. The patch also results in a positive test update in symbolic-stride.

---
Full diff: https://github.com/llvm/llvm-project/pull/97075.diff


4 Files Affected:

- (modified) llvm/include/llvm/IR/ConstantRange.h (+3) 
- (modified) llvm/lib/Analysis/LoopAccessAnalysis.cpp (+26-36) 
- (modified) llvm/lib/IR/ConstantRange.cpp (+5) 
- (modified) llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll (+4-11) 


``````````diff
diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
index 7b94b9c6c6d11..86d0a6b35d748 100644
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -277,6 +277,9 @@ class [[nodiscard]] ConstantRange {
   /// Return true if all values in this range are non-negative.
   bool isAllNonNegative() const;
 
+  /// Return true if all values in this range are positive.
+  bool isAllPositive() const;
+
   /// Return the largest unsigned value contained in the ConstantRange.
   APInt getUnsignedMax() const;
 
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 38bf6d8160aa9..4932ed61ec1ba 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -35,6 +35,7 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugLoc.h"
@@ -61,6 +62,8 @@
 #include <cassert>
 #include <cstdint>
 #include <iterator>
+#include <optional>
+#include <sys/types.h>
 #include <utility>
 #include <variant>
 #include <vector>
@@ -2914,7 +2917,7 @@ static const SCEV *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *L
 
   if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(V))
     if (isa<SCEVUnknown>(C->getOperand()))
-      return V;
+      return C->getOperand();
 
   return nullptr;
 }
@@ -2930,7 +2933,8 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
   // computation of an interesting IV - but we chose not to as we
   // don't have a cost model here, and broadening the scope exposes
   // far too many unprofitable cases.
-  const SCEV *StrideExpr = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop);
+  ScalarEvolution *SE = PSE->getSE();
+  const SCEV *StrideExpr = getStrideFromPointer(Ptr, SE, TheLoop);
   if (!StrideExpr)
     return;
 
@@ -2943,10 +2947,6 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
     return;
   }
 
-  // Avoid adding the "Stride == 1" predicate when we know that
-  // Stride >= Trip-Count. Such a predicate will effectively optimize a single
-  // or zero iteration loop, as Trip-Count <= Stride == 1.
-  //
   // TODO: We are currently not making a very informed decision on when it is
   // beneficial to apply stride versioning. It might make more sense that the
   // users of this analysis (such as the vectorizer) will trigger it, based on
@@ -2956,40 +2956,30 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
   // of various possible stride specializations, considering the alternatives
   // of using gather/scatters (if available).
 
-  const SCEV *MaxBTC = PSE->getSymbolicMaxBackedgeTakenCount();
-
-  // Match the types so we can compare the stride and the MaxBTC.
-  // The Stride can be positive/negative, so we sign extend Stride;
-  // The backedgeTakenCount is non-negative, so we zero extend MaxBTC.
-  const DataLayout &DL = TheLoop->getHeader()->getDataLayout();
-  uint64_t StrideTypeSizeBits = DL.getTypeSizeInBits(StrideExpr->getType());
-  uint64_t BETypeSizeBits = DL.getTypeSizeInBits(MaxBTC->getType());
-  const SCEV *CastedStride = StrideExpr;
-  const SCEV *CastedBECount = MaxBTC;
-  ScalarEvolution *SE = PSE->getSE();
-  if (BETypeSizeBits >= StrideTypeSizeBits)
-    CastedStride = SE->getNoopOrSignExtend(StrideExpr, MaxBTC->getType());
-  else
-    CastedBECount = SE->getZeroExtendExpr(MaxBTC, StrideExpr->getType());
-  const SCEV *StrideMinusBETaken = SE->getMinusSCEV(CastedStride, CastedBECount);
-  // Since TripCount == BackEdgeTakenCount + 1, checking:
-  // "Stride >= TripCount" is equivalent to checking:
-  // Stride - MaxBTC> 0
-  if (SE->isKnownPositive(StrideMinusBETaken)) {
-    LLVM_DEBUG(
-        dbgs() << "LAA: Stride>=TripCount; No point in versioning as the "
-                  "Stride==1 predicate will imply that the loop executes "
-                  "at most once.\n");
+  // Get two signed ranges and compare them, after adjusting for bitwidth. BTC
+  // range could extend into -1.
+  const SCEV *BTC = PSE->getBackedgeTakenCount();
+  ConstantRange BTCRange = SE->getSignedRange(BTC);
+  ConstantRange StrideRange =
+      SE->getSignedRange(StrideExpr).sextOrTrunc(BTCRange.getBitWidth());
+
+  // Stride is zero-extended to compare with BTC.
+  const SCEV *CastedStride =
+      SE->getTruncateOrZeroExtend(StrideExpr, BTC->getType());
+  const SCEV *StrideMinusOne =
+      SE->getMinusSCEV(CastedStride, SE->getOne(CastedStride->getType()));
+
+  // Stride - 1 exactly equal to BTC is a special case for which the loop should
+  // not be versioned. Otherwise, the loop should not be versioned if the range
+  // difference is all positive.
+  if (StrideMinusOne == BTC ||
+      StrideRange.difference(BTCRange).isAllPositive()) {
+    LLVM_DEBUG(dbgs() << "LAA: Not versioning with Stride==1 predicate.\n");
     return;
   }
   LLVM_DEBUG(dbgs() << "LAA: Found a strided access that we can version.\n");
 
-  // Strip back off the integer cast, and check that our result is a
-  // SCEVUnknown as we expect.
-  const SCEV *StrideBase = StrideExpr;
-  if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(StrideBase))
-    StrideBase = C->getOperand();
-  SymbolicStrides[Ptr] = cast<SCEVUnknown>(StrideBase);
+  SymbolicStrides[Ptr] = cast<SCEVUnknown>(StrideExpr);
 }
 
 LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 19041704a40be..b942894d34467 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -440,6 +440,11 @@ bool ConstantRange::isAllNonNegative() const {
   return !isSignWrappedSet() && Lower.isNonNegative();
 }
 
+bool ConstantRange::isAllPositive() const {
+  // Empty and full set are automatically treated correctly.
+  return !isSignWrappedSet() && Lower.isStrictlyPositive();
+}
+
 APInt ConstantRange::getUnsignedMax() const {
   if (isFullSet() || isUpperWrapped())
     return APInt::getMaxValue(getBitWidth());
diff --git a/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll b/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll
index 7c1b11e22aef2..e9aeac7ac2bc5 100644
--- a/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll
+++ b/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll
@@ -170,23 +170,16 @@ define void @single_stride_castexpr_multiuse(i32 %offset, ptr %src, ptr %dst, i1
 ; CHECK-NEXT:          %gep.src = getelementptr inbounds i32, ptr %src, i64 %iv.3
 ; CHECK-NEXT:      Grouped accesses:
 ; CHECK-NEXT:        Group [[GRP3]]:
-; CHECK-NEXT:          (Low: ((4 * %iv.1) + %dst) High: (804 + (4 * %iv.1) + (-4 * (zext i32 %offset to i64))<nsw> + %dst))
-; CHECK-NEXT:            Member: {((4 * %iv.1) + %dst),+,4}<%inner.loop>
+; CHECK-NEXT:          (Low: (((4 * %iv.1) + %dst) umin ((4 * %iv.1) + (4 * (sext i32 %offset to i64) * (200 + (-1 * (zext i32 %offset to i64))<nsw>)<nsw>) + %dst)) High: (4 + (((4 * %iv.1) + %dst) umax ((4 * %iv.1) + (4 * (sext i32 %offset to i64) * (200 + (-1 * (zext i32 %offset to i64))<nsw>)<nsw>) + %dst))))
+; CHECK-NEXT:            Member: {((4 * %iv.1) + %dst),+,(4 * (sext i32 %offset to i64))<nsw>}<%inner.loop>
 ; CHECK-NEXT:        Group [[GRP4]]:
-; CHECK-NEXT:          (Low: (4 + %src) High: (808 + (-4 * (zext i32 %offset to i64))<nsw> + %src))
-; CHECK-NEXT:            Member: {(4 + %src),+,4}<%inner.loop>
+; CHECK-NEXT:          (Low: ((4 * (zext i32 %offset to i64))<nuw><nsw> + %src) High: (804 + %src))
+; CHECK-NEXT:            Member: {((4 * (zext i32 %offset to i64))<nuw><nsw> + %src),+,4}<%inner.loop>
 ; CHECK-EMPTY:
 ; CHECK-NEXT:      Non vectorizable stores to invariant address were not found in loop.
 ; CHECK-NEXT:      SCEV assumptions:
-; CHECK-NEXT:      Equal predicate: %offset == 1
 ; CHECK-EMPTY:
 ; CHECK-NEXT:      Expressions re-written:
-; CHECK-NEXT:      [PSE] %gep.src = getelementptr inbounds i32, ptr %src, i64 %iv.3:
-; CHECK-NEXT:        {((4 * (zext i32 %offset to i64))<nuw><nsw> + %src),+,4}<%inner.loop>
-; CHECK-NEXT:        --> {(4 + %src),+,4}<%inner.loop>
-; CHECK-NEXT:      [PSE] %gep.dst = getelementptr i32, ptr %dst, i64 %iv.2:
-; CHECK-NEXT:        {((4 * %iv.1) + %dst),+,(4 * (sext i32 %offset to i64))<nsw>}<%inner.loop>
-; CHECK-NEXT:        --> {((4 * %iv.1) + %dst),+,4}<%inner.loop>
 ; CHECK-NEXT:    outer.header:
 ; CHECK-NEXT:      Report: loop is not the innermost loop
 ; CHECK-NEXT:      Dependences:

``````````

</details>


https://github.com/llvm/llvm-project/pull/97075


More information about the llvm-commits mailing list