[llvm] [LoopVectorize] Add support for reverse loops in isDereferenceableAndAlignedInLoop (PR #96752)

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 6 02:14:59 PST 2024


================
@@ -276,84 +277,85 @@ static bool AreEquivalentAddressValues(const Value *A, const Value *B) {
 bool llvm::isDereferenceableAndAlignedInLoop(
     LoadInst *LI, Loop *L, ScalarEvolution &SE, DominatorTree &DT,
     AssumptionCache *AC, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
-  auto &DL = LI->getDataLayout();
-  Value *Ptr = LI->getPointerOperand();
-
-  APInt EltSize(DL.getIndexTypeSizeInBits(Ptr->getType()),
-                DL.getTypeStoreSize(LI->getType()).getFixedValue());
-  const Align Alignment = LI->getAlign();
+  const SCEV *Ptr = SE.getSCEV(LI->getPointerOperand());
+  auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ptr);
 
-  Instruction *HeaderFirstNonPHI = L->getHeader()->getFirstNonPHI();
-
-  // If given a uniform (i.e. non-varying) address, see if we can prove the
-  // access is safe within the loop w/o needing predication.
-  if (L->isLoopInvariant(Ptr))
-    return isDereferenceableAndAlignedPointer(Ptr, Alignment, EltSize, DL,
-                                              HeaderFirstNonPHI, AC, &DT);
-
-  // Otherwise, check to see if we have a repeating access pattern where we can
-  // prove that all accesses are well aligned and dereferenceable.
-  auto *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(Ptr));
+  // Check to see if we have a repeating access pattern and it's possible
+  // to prove all accesses are well aligned.
   if (!AddRec || AddRec->getLoop() != L || !AddRec->isAffine())
     return false;
+
   auto* Step = dyn_cast<SCEVConstant>(AddRec->getStepRecurrence(SE));
   if (!Step)
     return false;
 
-  auto TC = SE.getSmallConstantMaxTripCount(L, Predicates);
-  if (!TC)
+  // For the moment, restrict ourselves to the case where the access size is a
+  // multiple of the requested alignment and the base is aligned.
+  // TODO: generalize if a case found which warrants
+  const Align Alignment = LI->getAlign();
+  auto &DL = LI->getDataLayout();
+  APInt EltSize(DL.getIndexTypeSizeInBits(Ptr->getType()),
+                DL.getTypeStoreSize(LI->getType()).getFixedValue());
+  if (EltSize.urem(Alignment.value()) != 0)
     return false;
 
   // TODO: Handle overlapping accesses.
-  // We should be computing AccessSize as (TC - 1) * Step + EltSize.
-  if (EltSize.sgt(Step->getAPInt()))
+  if (EltSize.ugt(Step->getAPInt().abs()))
     return false;
 
-  // Compute the total access size for access patterns with unit stride and
-  // patterns with gaps. For patterns with unit stride, Step and EltSize are the
-  // same.
-  // For patterns with gaps (i.e. non unit stride), we are
-  // accessing EltSize bytes at every Step.
-  APInt AccessSize = TC * Step->getAPInt();
+  const SCEV *MaxBECount =
+      SE.getPredicatedSymbolicMaxBackedgeTakenCount(L, *Predicates);
+  if (isa<SCEVCouldNotCompute>(MaxBECount))
+    return false;
 
-  assert(SE.isLoopInvariant(AddRec->getStart(), L) &&
-         "implied by addrec definition");
-  Value *Base = nullptr;
-  if (auto *StartS = dyn_cast<SCEVUnknown>(AddRec->getStart())) {
-    Base = StartS->getValue();
-  } else if (auto *StartS = dyn_cast<SCEVAddExpr>(AddRec->getStart())) {
-    // Handle (NewBase + offset) as start value.
-    const auto *Offset = dyn_cast<SCEVConstant>(StartS->getOperand(0));
-    const auto *NewBase = dyn_cast<SCEVUnknown>(StartS->getOperand(1));
-    if (StartS->getNumOperands() == 2 && Offset && NewBase) {
-      // The following code below assumes the offset is unsigned, but GEP
-      // offsets are treated as signed so we can end up with a signed value
-      // here too. For example, suppose the initial PHI value is (i8 255),
-      // the offset will be treated as (i8 -1) and sign-extended to (i64 -1).
-      if (Offset->getAPInt().isNegative())
-        return false;
+  const auto &[AccessStart, AccessEnd] =
+      getStartAndEndForAccess(L, Ptr, LI->getType(), MaxBECount, &SE, nullptr);
+  if (isa<SCEVCouldNotCompute>(AccessStart) ||
+      isa<SCEVCouldNotCompute>(AccessEnd))
+    return false;
 
-      // For the moment, restrict ourselves to the case where the offset is a
-      // multiple of the requested alignment and the base is aligned.
-      // TODO: generalize if a case found which warrants
-      if (Offset->getAPInt().urem(Alignment.value()) != 0)
-        return false;
-      Base = NewBase->getValue();
-      bool Overflow = false;
-      AccessSize = AccessSize.uadd_ov(Offset->getAPInt(), Overflow);
-      if (Overflow)
-        return false;
-    }
-  }
+  // Try to get the access size.
+  const SCEV *PtrDiff = SE.getMinusSCEV(AccessEnd, AccessStart);
+  APInt MaxPtrDiff = SE.getUnsignedRangeMax(PtrDiff);
 
-  if (!Base)
+  // If the (max) pointer difference is > 32 bits then it's unlikely to be
+  // dereferenceable.
+  if (MaxPtrDiff.getActiveBits() > 32)
----------------
david-arm wrote:

Again, I added this check because, in order to reuse `getStartAndEndForAccess`, I have to pass in a SCEV value for the backedge-taken count, since the original version uses it. That also means I can no longer call `getSmallConstantMaxTripCount`, since that returns an unsigned value. However, `getSmallConstantMaxTripCount` returns 0 if the max count is > 32 bits, so in order to maintain the same behaviour I added the explicit check here. I'm happy to remove it, though, if you feel comfortable with removing this restriction?

https://github.com/llvm/llvm-project/pull/96752


More information about the llvm-commits mailing list