[llvm] e41dce4 - [LAA/LV] Simplify stride speculation logic [NFC] (try 2)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Thu May 11 10:19:37 PDT 2023


Author: Philip Reames
Date: 2023-05-11T10:19:23-07:00
New Revision: e41dce4d4974f41d8e7572dfc698e5ddd55a3d4b

URL: https://github.com/llvm/llvm-project/commit/e41dce4d4974f41d8e7572dfc698e5ddd55a3d4b
DIFF: https://github.com/llvm/llvm-project/commit/e41dce4d4974f41d8e7572dfc698e5ddd55a3d4b.diff

LOG: [LAA/LV] Simplify stride speculation logic [NFC] (try 2)

The original commit wasn't quite NFC, and this was caught by an arguably overly strong assert.  Specifically, I'd failed to strip the integer cast off the stride value before saving its SCEV in the map.  The result - other than a failed assert - is that we'd speculate on the casted unknown, not the underlying unknown.  The only case I can think of where that might change behavior is a sext(i1 load).  I doubt that case is interesting in practice, but it's good for this change to be strictly NFC regardless.
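For concreteness, the distinction is a one-line difference.  A hedged sketch (not the committed code; Stride stands for the value returned by getStrideFromPointer, and the map update mirrors collectStridedAccess):

    // Wrong (the non-NFC bug in try 1): take the SCEV of the stride value
    // as-is, so for e.g. a widened load the map would hold sext(%s) rather
    // than the bare unknown %s.
    //   SymbolicStrides[Ptr] = PSE->getSCEV(Stride);

    // Right: strip any integer cast first, so the stored expression is the
    // bare SCEVUnknown that replaceSymbolicStrideSCEV asserts on and
    // speculates to be equal to one.
    SymbolicStrides[Ptr] = PSE->getSCEV(stripIntegerCast(Stride));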

Original commit message follows:

The existing code makes it hard to tell that collectStridedAccess is really about identifying a loop-invariant SCEV which it is *profitable* to speculate is equal to one. The odd dual use of both Value and SCEV obscures this point.
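As a hypothetical example of the pattern being speculated on: if the stride below is 1 at runtime, the accesses are consecutive, so LAA records it as a symbolic stride and the vectorizer versions the loop under a stride == 1 runtime check.

    // Hypothetical source pattern (illustration only): a loop-invariant
    // parameter used as the stride of every access.  Speculating
    // stride == 1 makes both accesses unit-stride in the versioned loop.
    void scale(float *a, const float *b, long n, long stride) {
      for (long i = 0; i < n; ++i)
        a[i * stride] = 3.0f * b[i * stride];
    }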

We could choose to loosen the profitability analysis if desired. I'm not proposing doing so at this time, as doing so exposes too many cases where the speculation is unprofitable.
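A hypothetical illustration of the concern: the stride below is loop invariant and feeds the induction-variable computation, so a loosened analysis could choose to speculate on it, but nothing suggests it is ever 1 at runtime; the versioned loop would almost always take the scalar path and the runtime check would be pure overhead.

    // Hypothetical unprofitable case: rows * cols is loop invariant, but
    // speculating rows * cols == 1 is very unlikely to pay off.
    void bump(float *a, long n, long rows, long cols) {
      long stride = rows * cols;
      for (long i = 0; i < n; ++i)
        a[i * stride] += 1.0f;
    }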

Differential Revision: https://reviews.llvm.org/D147750

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/LoopAccessAnalysis.h
    llvm/include/llvm/Analysis/VectorUtils.h
    llvm/lib/Analysis/LoopAccessAnalysis.cpp
    llvm/lib/Analysis/VectorUtils.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
    llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
index fae22d484a09b..11b4d621d7640 100644
--- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
@@ -184,7 +184,7 @@ class MemoryDepChecker {
   ///
   /// Only checks sets with elements in \p CheckDeps.
   bool areDepsSafe(DepCandidates &AccessSets, MemAccessInfoList &CheckDeps,
-                   const ValueToValueMap &Strides);
+                   const DenseMap<Value *, const SCEV *> &Strides);
 
   /// No memory dependence was encountered that would inhibit
   /// vectorization.
@@ -316,7 +316,7 @@ class MemoryDepChecker {
   /// Otherwise, this function returns true signaling a possible dependence.
   Dependence::DepType isDependent(const MemAccessInfo &A, unsigned AIdx,
                                   const MemAccessInfo &B, unsigned BIdx,
-                                  const ValueToValueMap &Strides);
+                                  const DenseMap<Value *, const SCEV *> &Strides);
 
   /// Check whether the data dependence could prevent store-load
   /// forwarding.
@@ -612,7 +612,9 @@ class LoopAccessInfo {
 
   /// If an access has a symbolic strides, this maps the pointer value to
   /// the stride symbol.
-  const ValueToValueMap &getSymbolicStrides() const { return SymbolicStrides; }
+  const DenseMap<Value *, const SCEV *> &getSymbolicStrides() const {
+    return SymbolicStrides;
+  }
 
   /// Pointer has a symbolic stride.
   bool hasStride(Value *V) const { return StrideSet.count(V); }
@@ -699,7 +701,7 @@ class LoopAccessInfo {
 
   /// If an access has a symbolic strides, this maps the pointer value to
   /// the stride symbol.
-  ValueToValueMap SymbolicStrides;
+  DenseMap<Value *, const SCEV *> SymbolicStrides;
 
   /// Set of symbolic strides values.
   SmallPtrSet<Value *, 8> StrideSet;
@@ -716,9 +718,10 @@ Value *stripIntegerCast(Value *V);
 ///
 /// \p PtrToStride provides the mapping between the pointer value and its
 /// stride as collected by LoopVectorizationLegality::collectStridedAccess.
-const SCEV *replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE,
-                                      const ValueToValueMap &PtrToStride,
-                                      Value *Ptr);
+const SCEV *
+replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE,
+                          const DenseMap<Value *, const SCEV *> &PtrToStride,
+                          Value *Ptr);
 
 /// If the pointer has a constant stride return it in units of the access type
 /// size.  Otherwise return std::nullopt.
@@ -737,7 +740,7 @@ const SCEV *replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE,
 std::optional<int64_t>
 getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy, Value *Ptr,
              const Loop *Lp,
-             const ValueToValueMap &StridesMap = ValueToValueMap(),
+             const DenseMap<Value *, const SCEV *> &StridesMap = DenseMap<Value *, const SCEV *>(),
              bool Assume = false, bool ShouldCheckWrap = true);
 
 /// Returns the distance between the pointers \p PtrA and \p PtrB iff they are

diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 5d824540c3896..18214fe19700e 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -917,7 +917,7 @@ class InterleavedAccessInfo {
   /// Collect all the accesses with a constant stride in program order.
   void collectConstStrideAccesses(
       MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
-      const ValueToValueMap &Strides);
+      const DenseMap<Value *, const SCEV *> &Strides);
 
   /// Returns true if \p Stride is allowed in an interleaved group.
   static bool isStrided(int Stride);

diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 351e09094d487..df21679e14448 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -154,23 +154,25 @@ Value *llvm::stripIntegerCast(Value *V) {
 }
 
 const SCEV *llvm::replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE,
-                                            const ValueToValueMap &PtrToStride,
+                                            const DenseMap<Value *, const SCEV *> &PtrToStride,
                                             Value *Ptr) {
   const SCEV *OrigSCEV = PSE.getSCEV(Ptr);
 
   // If there is an entry in the map return the SCEV of the pointer with the
   // symbolic stride replaced by one.
-  ValueToValueMap::const_iterator SI = PtrToStride.find(Ptr);
+  DenseMap<Value *, const SCEV *>::const_iterator SI = PtrToStride.find(Ptr);
   if (SI == PtrToStride.end())
     // For a non-symbolic stride, just return the original expression.
     return OrigSCEV;
 
-  Value *StrideVal = stripIntegerCast(SI->second);
-
-  ScalarEvolution *SE = PSE.getSE();
-  const SCEV *StrideSCEV = SE->getSCEV(StrideVal);
+  const SCEV *StrideSCEV = SI->second;
+  // Note: This assert is both overly strong and overly weak.  The actual
+  // invariant here is that StrideSCEV should be loop invariant.  The only
+  // such invariant strides we happen to speculate right now are unknowns
+  // and thus this is a reasonable proxy of the actual invariant.
   assert(isa<SCEVUnknown>(StrideSCEV) && "shouldn't be in map");
 
+  ScalarEvolution *SE = PSE.getSE();
   const auto *CT = SE->getOne(StrideSCEV->getType());
   PSE.addPredicate(*SE->getEqualPredicate(StrideSCEV, CT));
   auto *Expr = PSE.getSCEV(Ptr);
@@ -658,7 +660,7 @@ class AccessAnalysis {
   /// the bounds of the pointer.
   bool createCheckForAccess(RuntimePointerChecking &RtCheck,
                             MemAccessInfo Access, Type *AccessTy,
-                            const ValueToValueMap &Strides,
+                            const DenseMap<Value *, const SCEV *> &Strides,
                             DenseMap<Value *, unsigned> &DepSetId,
                             Loop *TheLoop, unsigned &RunningDepId,
                             unsigned ASId, bool ShouldCheckStride, bool Assume);
@@ -669,7 +671,7 @@ class AccessAnalysis {
   /// Returns true if we need no check or if we do and we can generate them
   /// (i.e. the pointers have computable bounds).
   bool canCheckPtrAtRT(RuntimePointerChecking &RtCheck, ScalarEvolution *SE,
-                       Loop *TheLoop, const ValueToValueMap &Strides,
+                       Loop *TheLoop, const DenseMap<Value *, const SCEV *> &Strides,
                        Value *&UncomputablePtr, bool ShouldCheckWrap = false);
 
   /// Goes over all memory accesses, checks whether a RT check is needed
@@ -764,7 +766,7 @@ static bool hasComputableBounds(PredicatedScalarEvolution &PSE, Value *Ptr,
 
 /// Check whether a pointer address cannot wrap.
 static bool isNoWrap(PredicatedScalarEvolution &PSE,
-                     const ValueToValueMap &Strides, Value *Ptr, Type *AccessTy,
+                     const DenseMap<Value *, const SCEV *> &Strides, Value *Ptr, Type *AccessTy,
                      Loop *L) {
   const SCEV *PtrScev = PSE.getSCEV(Ptr);
   if (PSE.getSE()->isLoopInvariant(PtrScev, L))
@@ -957,7 +959,7 @@ static void findForkedSCEVs(
 
 static SmallVector<PointerIntPair<const SCEV *, 1, bool>>
 findForkedPointer(PredicatedScalarEvolution &PSE,
-                  const ValueToValueMap &StridesMap, Value *Ptr,
+                  const DenseMap<Value *, const SCEV *> &StridesMap, Value *Ptr,
                   const Loop *L) {
   ScalarEvolution *SE = PSE.getSE();
   assert(SE->isSCEVable(Ptr->getType()) && "Value is not SCEVable!");
@@ -982,7 +984,7 @@ findForkedPointer(PredicatedScalarEvolution &PSE,
 
 bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck,
                                           MemAccessInfo Access, Type *AccessTy,
-                                          const ValueToValueMap &StridesMap,
+                                          const DenseMap<Value *, const SCEV *> &StridesMap,
                                           DenseMap<Value *, unsigned> &DepSetId,
                                           Loop *TheLoop, unsigned &RunningDepId,
                                           unsigned ASId, bool ShouldCheckWrap,
@@ -1043,7 +1045,7 @@ bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck,
 
 bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,
                                      ScalarEvolution *SE, Loop *TheLoop,
-                                     const ValueToValueMap &StridesMap,
+                                     const DenseMap<Value *, const SCEV *> &StridesMap,
                                      Value *&UncomputablePtr, bool ShouldCheckWrap) {
   // Find pointers with computable bounds. We are going to use this information
   // to place a runtime bound check.
@@ -1373,7 +1375,7 @@ static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR,
 std::optional<int64_t> llvm::getPtrStride(PredicatedScalarEvolution &PSE,
                                           Type *AccessTy, Value *Ptr,
                                           const Loop *Lp,
-                                          const ValueToValueMap &StridesMap,
+                                          const DenseMap<Value *, const SCEV *> &StridesMap,
                                           bool Assume, bool ShouldCheckWrap) {
   Type *Ty = Ptr->getType();
   assert(Ty->isPointerTy() && "Unexpected non-ptr");
@@ -1822,7 +1824,7 @@ static bool areStridedAccessesIndependent(uint64_t Distance, uint64_t Stride,
 MemoryDepChecker::Dependence::DepType
 MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
                               const MemAccessInfo &B, unsigned BIdx,
-                              const ValueToValueMap &Strides) {
+                              const DenseMap<Value *, const SCEV *> &Strides) {
   assert (AIdx < BIdx && "Must pass arguments in program order");
 
   auto [APtr, AIsWrite] = A;
@@ -2016,7 +2018,7 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
 
 bool MemoryDepChecker::areDepsSafe(DepCandidates &AccessSets,
                                    MemAccessInfoList &CheckDeps,
-                                   const ValueToValueMap &Strides) {
+                                   const DenseMap<Value *, const SCEV *> &Strides) {
 
   MaxSafeDepDistBytes = -1;
   SmallPtrSet<MemAccessInfo, 8> Visited;
@@ -2691,6 +2693,12 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
   if (!Ptr)
     return;
 
+  // Note: getStrideFromPointer is a *profitability* heuristic.  We
+  // could broaden the scope of values returned here - to anything
+  // which happens to be loop invariant and contributes to the
+  // computation of an interesting IV - but we chose not to as we
+  // don't have a cost model here, and broadening the scope exposes
+  // far too many unprofitable cases.
   Value *Stride = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop);
   if (!Stride)
     return;
@@ -2746,7 +2754,10 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
   }
   LLVM_DEBUG(dbgs() << "LAA: Found a strided access that we can version.\n");
 
-  SymbolicStrides[Ptr] = Stride;
+  // Strip back off the integer cast, and check that our result is a
+  // SCEVUnknown as we expect.
+  Value *StrideVal = stripIntegerCast(Stride);
+  SymbolicStrides[Ptr] = cast<SCEVUnknown>(PSE->getSCEV(StrideVal));
   StrideSet.insert(Stride);
 }
 

diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 423ef670bb3c2..5da21e75a20cb 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -1011,7 +1011,7 @@ bool InterleavedAccessInfo::isStrided(int Stride) {
 
 void InterleavedAccessInfo::collectConstStrideAccesses(
     MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
-    const ValueToValueMap &Strides) {
+    const DenseMap<Value*, const SCEV*> &Strides) {
   auto &DL = TheLoop->getHeader()->getModule()->getDataLayout();
 
   // Since it's desired that the load/store instructions be maintained in
@@ -1091,7 +1091,7 @@ void InterleavedAccessInfo::collectConstStrideAccesses(
 void InterleavedAccessInfo::analyzeInterleaving(
                                  bool EnablePredicatedInterleavedMemAccesses) {
   LLVM_DEBUG(dbgs() << "LV: Analyzing interleaved accesses...\n");
-  const ValueToValueMap &Strides = LAI->getSymbolicStrides();
+  const auto &Strides = LAI->getSymbolicStrides();
 
   // Holds all accesses with a constant stride.
   MapVector<Instruction *, StrideDescriptor> AccessStrideInfo;

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 80a3f13f304b0..ffefb94d5dbd2 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3477,7 +3477,7 @@ InstructionCost AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
 
 static bool containsDecreasingPointers(Loop *TheLoop,
                                        PredicatedScalarEvolution *PSE) {
-  const ValueToValueMap &Strides = ValueToValueMap();
+  const auto &Strides = DenseMap<Value *, const SCEV *>();
   for (BasicBlock *BB : TheLoop->blocks()) {
     // Scan the instructions in the block and look for addresses that are
     // consecutive and decreasing.

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
index 3a868e8625a85..a2b5c04dfd149 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -456,8 +456,8 @@ int LoopVectorizationLegality::isConsecutivePtr(Type *AccessTy,
   // it's collected.  This happens from canVectorizeWithIfConvert, when the
   // pointer is checked to reference consecutive elements suitable for a
   // masked access.
-  const ValueToValueMap &Strides =
-    LAI ? LAI->getSymbolicStrides() : ValueToValueMap();
+  const auto &Strides =
+    LAI ? LAI->getSymbolicStrides() : DenseMap<Value *, const SCEV *>();
 
   Function *F = TheLoop->getHeader()->getParent();
   bool OptForSize = F->hasOptSize() ||


        

