[llvm] 733b8b2 - [LAA] Simplify identification of speculatable strides [nfc]

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Thu May 11 11:48:50 PDT 2023


Author: Philip Reames
Date: 2023-05-11T11:48:21-07:00
New Revision: 733b8b2b4919e76e1768919461fd5eb211e41d79

URL: https://github.com/llvm/llvm-project/commit/733b8b2b4919e76e1768919461fd5eb211e41d79
DIFF: https://github.com/llvm/llvm-project/commit/733b8b2b4919e76e1768919461fd5eb211e41d79.diff

LOG: [LAA] Simplify identification of speculatable strides [nfc]

This mostly just avoids the need to keep both Values and SCEVs flowing through with consistent handling.  We can do everything in terms of SCEV, aside from the profitability heuristics, which are now isolated in one spot.

Added: 
    

Modified: 
    llvm/lib/Analysis/LoopAccessAnalysis.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 358f97f83d40..5a0b1abe96d9 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -2610,7 +2610,7 @@ static Value *getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) {
 
 /// Get the stride of a pointer access in a loop. Looks for symbolic
 /// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
-static Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
+static const SCEV *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
   auto *PtrTy = dyn_cast<PointerType>(Ptr->getType());
   if (!PtrTy || PtrTy->isAggregateType())
     return nullptr;
@@ -2664,28 +2664,27 @@ static Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
     }
   }
 
-  // Strip off casts.
-  Type *StripedOffRecurrenceCast = nullptr;
-  if (const SCEVIntegralCastExpr *C = dyn_cast<SCEVIntegralCastExpr>(V)) {
-    StripedOffRecurrenceCast = C->getType();
-    V = C->getOperand();
-  }
+  // Note that the restrictions after this loop invariant check are only
+  // profitability restrictions.
+  if (!SE->isLoopInvariant(V, Lp))
+    return nullptr;
 
   // Look for the loop invariant symbolic value.
   const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V);
-  if (!U)
-    return nullptr;
+  if (!U) {
+    const auto *C = dyn_cast<SCEVIntegralCastExpr>(V);
+    if (!C)
+      return nullptr;
+    U = dyn_cast<SCEVUnknown>(C->getOperand());
+    if (!U)
+      return nullptr;
 
-  Value *Stride = U->getValue();
-  if (!Lp->isLoopInvariant(Stride))
-    return nullptr;
-
-  // If we have stripped off the recurrence cast we have to make sure that we
-  // return the value that is used in this loop so that we can replace it later.
-  if (StripedOffRecurrenceCast)
-    Stride = getUniqueCastUse(Stride, Lp, StripedOffRecurrenceCast);
+    // Match legacy behavior - this is not needed for correctness
+    if (!getUniqueCastUse(U->getValue(), Lp, V->getType()))
+      return nullptr;
+  }
 
-  return Stride;
+  return V;
 }
 
 void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
@@ -2699,13 +2698,13 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
   // computation of an interesting IV - but we chose not to as we
   // don't have a cost model here, and broadening the scope exposes
   // far too many unprofitable cases.
-  Value *Stride = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop);
-  if (!Stride)
+  const SCEV *StrideExpr = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop);
+  if (!StrideExpr)
     return;
 
   LLVM_DEBUG(dbgs() << "LAA: Found a strided access that is a candidate for "
                        "versioning:");
-  LLVM_DEBUG(dbgs() << "  Ptr: " << *Ptr << " Stride: " << *Stride << "\n");
+  LLVM_DEBUG(dbgs() << "  Ptr: " << *Ptr << " Stride: " << *StrideExpr << "\n");
 
   if (!SpeculateUnitStride) {
     LLVM_DEBUG(dbgs() << "  Chose not to due to -laa-speculate-unit-stride\n");
@@ -2725,7 +2724,6 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
   // of various possible stride specializations, considering the alternatives
   // of using gather/scatters (if available).
 
-  const SCEV *StrideExpr = PSE->getSCEV(Stride);
   const SCEV *BETakenCount = PSE->getBackedgeTakenCount();
 
   // Match the types so we can compare the stride and the BETakenCount.
@@ -2756,8 +2754,10 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
 
   // Strip back off the integer cast, and check that our result is a
   // SCEVUnknown as we expect.
-  Value *StrideVal = stripIntegerCast(Stride);
-  SymbolicStrides[Ptr] = cast<SCEVUnknown>(PSE->getSCEV(StrideVal));
+  const SCEV *StrideBase = StrideExpr;
+  if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(StrideBase))
+    StrideBase = C->getOperand();
+  SymbolicStrides[Ptr] = cast<SCEVUnknown>(StrideBase);
 }
 
 LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,


        


More information about the llvm-commits mailing list