[llvm] LAA: make stride versioning code more robust (PR #97075)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 21 11:20:59 PDT 2024


https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/97075

>From 52935f610360d345415dbeab7ec240bb95c81064 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Thu, 27 Jun 2024 17:12:33 +0100
Subject: [PATCH] LAA: make stride versioning code more robust

Rewrite the stride-versioning code in LoopAccessAnalysis to be more robust,
and make it possible to directly return the SCEVUnknown under a cast in
getStrideFromPointer, eliminating a second round of cast-stripping in
collectStridedAccess.
---
 llvm/lib/Analysis/LoopAccessAnalysis.cpp | 56 +++++++++++-------------
 1 file changed, 26 insertions(+), 30 deletions(-)

diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 980f142f113265..109c256121b774 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -2929,13 +2929,14 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
   // computation of an interesting IV - but we chose not to as we
   // don't have a cost model here, and broadening the scope exposes
   // far too many unprofitable cases.
-  const SCEV *StrideExpr = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop);
-  if (!StrideExpr)
+  ScalarEvolution *SE = PSE->getSE();
+  const SCEV *Stride = getStrideFromPointer(Ptr, SE, TheLoop);
+  if (!Stride)
     return;
 
   LLVM_DEBUG(dbgs() << "LAA: Found a strided access that is a candidate for "
                        "versioning:");
-  LLVM_DEBUG(dbgs() << "  Ptr: " << *Ptr << " Stride: " << *StrideExpr << "\n");
+  LLVM_DEBUG(dbgs() << "  Ptr: " << *Ptr << " Stride: " << *Stride << "\n");
 
   if (!SpeculateUnitStride) {
     LLVM_DEBUG(dbgs() << "  Chose not to due to -laa-speculate-unit-stride\n");
@@ -2955,40 +2956,35 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
   // of various possible stride specializations, considering the alternatives
   // of using gather/scatters (if available).
 
-  const SCEV *MaxBTC = PSE->getSymbolicMaxBackedgeTakenCount();
+  const SCEV *BTC = PSE->getSymbolicMaxBackedgeTakenCount();
 
-  // Match the types so we can compare the stride and the MaxBTC.
-  // The Stride can be positive/negative, so we sign extend Stride;
-  // The backedgeTakenCount is non-negative, so we zero extend MaxBTC.
+  // Sign-extend the stride or zero-extend the BTC, as appropriate, before
+  // performing subtraction. We take care to do this because an unknown stride
+  // might equal an unknown TC, and we don't want to version the loop in that
+  // case.
+  const SCEV *CastedStride = Stride;
+  const SCEV *CastedBTC = BTC;
   const DataLayout &DL = TheLoop->getHeader()->getDataLayout();
-  uint64_t StrideTypeSizeBits = DL.getTypeSizeInBits(StrideExpr->getType());
-  uint64_t BETypeSizeBits = DL.getTypeSizeInBits(MaxBTC->getType());
-  const SCEV *CastedStride = StrideExpr;
-  const SCEV *CastedBECount = MaxBTC;
-  ScalarEvolution *SE = PSE->getSE();
-  if (BETypeSizeBits >= StrideTypeSizeBits)
-    CastedStride = SE->getNoopOrSignExtend(StrideExpr, MaxBTC->getType());
+  if (DL.getTypeSizeInBits(BTC->getType()) >=
+      DL.getTypeSizeInBits(Stride->getType()))
+    CastedStride = SE->getNoopOrSignExtend(Stride, BTC->getType());
   else
-    CastedBECount = SE->getZeroExtendExpr(MaxBTC, StrideExpr->getType());
-  const SCEV *StrideMinusBETaken = SE->getMinusSCEV(CastedStride, CastedBECount);
-  // Since TripCount == BackEdgeTakenCount + 1, checking:
-  // "Stride >= TripCount" is equivalent to checking:
-  // Stride - MaxBTC> 0
-  if (SE->isKnownPositive(StrideMinusBETaken)) {
-    LLVM_DEBUG(
-        dbgs() << "LAA: Stride>=TripCount; No point in versioning as the "
-                  "Stride==1 predicate will imply that the loop executes "
-                  "at most once.\n");
+    CastedBTC = SE->getZeroExtendExpr(BTC, Stride->getType());
+  const SCEV *StrideMinusBTC = SE->getMinusSCEV(CastedStride, CastedBTC);
+
+  // Stride - BTC > 0 is equivalent to Stride >= TripCount, but computing
+  // TripCount from BTC would introduce more casts, and Stride - TC might fail
+  // the known-non-negative test.
+  if (SE->isKnownPositive(StrideMinusBTC)) {
+    LLVM_DEBUG(dbgs() << "LAA: Not versioning with Stride==1 predicate.\n");
     return;
   }
   LLVM_DEBUG(dbgs() << "LAA: Found a strided access that we can version.\n");
 
-  // Strip back off the integer cast, and check that our result is a
-  // SCEVUnknown as we expect.
-  const SCEV *StrideBase = StrideExpr;
-  if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(StrideBase))
-    StrideBase = C->getOperand();
-  SymbolicStrides[Ptr] = cast<SCEVUnknown>(StrideBase);
+  // Strip back off the integer cast, to get the resulting SCEVUnknown.
+  if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(Stride))
+    Stride = C->getOperand();
+  SymbolicStrides[Ptr] = cast<SCEVUnknown>(Stride);
 }
 
 LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,



More information about the llvm-commits mailing list