[llvm] [LAA] refactor program logic (NFC) (PR #92101)

via llvm-commits llvm-commits at lists.llvm.org
Tue May 14 04:38:07 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Ramkumar Ramachandra (artagnon)

<details>
<summary>Changes</summary>

Implement NFC improvements spotted during a cursory reading of LoopAccessAnalysis.

---
Full diff: https://github.com/llvm/llvm-project/pull/92101.diff


2 Files Affected:

- (modified) llvm/include/llvm/Analysis/LoopAccessAnalysis.h (+1-1) 
- (modified) llvm/lib/Analysis/LoopAccessAnalysis.cpp (+41-68) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
index 6ebd0fb8477a0..c22e1d470f380 100644
--- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
@@ -540,7 +540,7 @@ class RuntimePointerChecking {
-  /// Try to create add a new (pointer-difference, access size) pair to
-  /// DiffCheck for checking groups \p CGI and \p CGJ. If pointer-difference
-  /// checks cannot be used for the groups, set CanUseDiffCheck to false.
-  void tryToCreateDiffCheck(const RuntimeCheckingPtrGroup &CGI,
+  /// Try to add a new (pointer-difference, access size) pair to DiffChecks for
+  /// checking groups \p CGI and \p CGJ. Returns false if pointer-difference
+  /// checks cannot be used for the groups, true otherwise.
+  bool tryToCreateDiffCheck(const RuntimeCheckingPtrGroup &CGI,
                              const RuntimeCheckingPtrGroup &CGJ);
 
   MemoryDepChecker &DC;
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index d071e53324408..92e3d118fd6ae 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -250,18 +250,13 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr,
                         NeedsFreeze);
 }
 
-void RuntimePointerChecking::tryToCreateDiffCheck(
+bool RuntimePointerChecking::tryToCreateDiffCheck(
     const RuntimeCheckingPtrGroup &CGI, const RuntimeCheckingPtrGroup &CGJ) {
-  if (!CanUseDiffCheck)
-    return;
-
   // If either group contains multiple different pointers, bail out.
   // TODO: Support multiple pointers by using the minimum or maximum pointer,
   // depending on src & sink.
-  if (CGI.Members.size() != 1 || CGJ.Members.size() != 1) {
-    CanUseDiffCheck = false;
-    return;
-  }
+  if (CGI.Members.size() != 1 || CGJ.Members.size() != 1)
+    return false;
 
   PointerInfo *Src = &Pointers[CGI.Members[0]];
   PointerInfo *Sink = &Pointers[CGJ.Members[0]];
@@ -269,10 +264,8 @@ void RuntimePointerChecking::tryToCreateDiffCheck(
   // If either pointer is read and written, multiple checks may be needed. Bail
   // out.
   if (!DC.getOrderForAccess(Src->PointerValue, !Src->IsWritePtr).empty() ||
-      !DC.getOrderForAccess(Sink->PointerValue, !Sink->IsWritePtr).empty()) {
-    CanUseDiffCheck = false;
-    return;
-  }
+      !DC.getOrderForAccess(Sink->PointerValue, !Sink->IsWritePtr).empty())
+    return false;
 
   ArrayRef<unsigned> AccSrc =
       DC.getOrderForAccess(Src->PointerValue, Src->IsWritePtr);
@@ -280,10 +273,9 @@ void RuntimePointerChecking::tryToCreateDiffCheck(
       DC.getOrderForAccess(Sink->PointerValue, Sink->IsWritePtr);
   // If either pointer is accessed multiple times, there may not be a clear
   // src/sink relation. Bail out for now.
-  if (AccSrc.size() != 1 || AccSink.size() != 1) {
-    CanUseDiffCheck = false;
-    return;
-  }
+  if (AccSrc.size() != 1 || AccSink.size() != 1)
+    return false;
+
   // If the sink is accessed before src, swap src/sink.
   if (AccSink[0] < AccSrc[0])
     std::swap(Src, Sink);
@@ -291,10 +283,8 @@ void RuntimePointerChecking::tryToCreateDiffCheck(
   auto *SrcAR = dyn_cast<SCEVAddRecExpr>(Src->Expr);
   auto *SinkAR = dyn_cast<SCEVAddRecExpr>(Sink->Expr);
   if (!SrcAR || !SinkAR || SrcAR->getLoop() != DC.getInnermostLoop() ||
-      SinkAR->getLoop() != DC.getInnermostLoop()) {
-    CanUseDiffCheck = false;
-    return;
-  }
+      SinkAR->getLoop() != DC.getInnermostLoop())
+    return false;
 
   SmallVector<Instruction *, 4> SrcInsts =
       DC.getInstructionsForAccess(Src->PointerValue, Src->IsWritePtr);
@@ -302,10 +292,9 @@ void RuntimePointerChecking::tryToCreateDiffCheck(
       DC.getInstructionsForAccess(Sink->PointerValue, Sink->IsWritePtr);
   Type *SrcTy = getLoadStoreType(SrcInsts[0]);
   Type *DstTy = getLoadStoreType(SinkInsts[0]);
-  if (isa<ScalableVectorType>(SrcTy) || isa<ScalableVectorType>(DstTy)) {
-    CanUseDiffCheck = false;
-    return;
-  }
+  if (isa<ScalableVectorType>(SrcTy) || isa<ScalableVectorType>(DstTy))
+    return false;
+
   const DataLayout &DL =
       SinkAR->getLoop()->getHeader()->getModule()->getDataLayout();
   unsigned AllocSize =
@@ -316,10 +305,8 @@ void RuntimePointerChecking::tryToCreateDiffCheck(
   // future.
   auto *Step = dyn_cast<SCEVConstant>(SinkAR->getStepRecurrence(*SE));
   if (!Step || Step != SrcAR->getStepRecurrence(*SE) ||
-      Step->getAPInt().abs() != AllocSize) {
-    CanUseDiffCheck = false;
-    return;
-  }
+      Step->getAPInt().abs() != AllocSize)
+    return false;
 
   IntegerType *IntTy =
       IntegerType::get(Src->PointerValue->getContext(),
@@ -332,10 +319,8 @@ void RuntimePointerChecking::tryToCreateDiffCheck(
   const SCEV *SinkStartInt = SE->getPtrToIntExpr(SinkAR->getStart(), IntTy);
   const SCEV *SrcStartInt = SE->getPtrToIntExpr(SrcAR->getStart(), IntTy);
   if (isa<SCEVCouldNotCompute>(SinkStartInt) ||
-      isa<SCEVCouldNotCompute>(SrcStartInt)) {
-    CanUseDiffCheck = false;
-    return;
-  }
+      isa<SCEVCouldNotCompute>(SrcStartInt))
+    return false;
 
   const Loop *InnerLoop = SrcAR->getLoop();
   // If the start values for both Src and Sink also vary according to an outer
@@ -356,8 +341,7 @@ void RuntimePointerChecking::tryToCreateDiffCheck(
             SinkStartAR->getStepRecurrence(*SE)) {
       LLVM_DEBUG(dbgs() << "LAA: Not creating diff runtime check, since these "
                            "cannot be hoisted out of the outer loop\n");
-      CanUseDiffCheck = false;
-      return;
+      return false;
     }
   }
 
@@ -366,6 +350,7 @@ void RuntimePointerChecking::tryToCreateDiffCheck(
                     << "SinkStartInt: " << *SinkStartInt << '\n');
   DiffChecks.emplace_back(SrcStartInt, SinkStartInt, AllocSize,
                           Src->NeedsFreeze || Sink->NeedsFreeze);
+  return true;
 }
 
 SmallVector<RuntimePointerCheck, 4> RuntimePointerChecking::generateChecks() {
@@ -377,7 +362,7 @@ SmallVector<RuntimePointerCheck, 4> RuntimePointerChecking::generateChecks() {
       const RuntimeCheckingPtrGroup &CGJ = CheckingGroups[J];
 
       if (needsChecking(CGI, CGJ)) {
-        tryToCreateDiffCheck(CGI, CGJ);
+        CanUseDiffCheck = CanUseDiffCheck && tryToCreateDiffCheck(CGI, CGJ);
         Checks.push_back(std::make_pair(&CGI, &CGJ));
       }
     }
@@ -394,9 +379,9 @@ void RuntimePointerChecking::generateChecks(
 
 bool RuntimePointerChecking::needsChecking(
     const RuntimeCheckingPtrGroup &M, const RuntimeCheckingPtrGroup &N) const {
-  for (unsigned I = 0, EI = M.Members.size(); EI != I; ++I)
-    for (unsigned J = 0, EJ = N.Members.size(); EJ != J; ++J)
-      if (needsChecking(M.Members[I], N.Members[J]))
+  for (unsigned I : M.Members)
+    for (unsigned J : N.Members)
+      if (needsChecking(I, J))
         return true;
   return false;
 }
@@ -410,9 +395,7 @@ static const SCEV *getMinFromExprs(const SCEV *I, const SCEV *J,
 
   if (!C)
     return nullptr;
-  if (C->getValue()->isNegative())
-    return J;
-  return I;
+  return C->getValue()->isNegative() ? J : I;
 }
 
 bool RuntimeCheckingPtrGroup::addPointer(unsigned Index,
@@ -1646,22 +1629,19 @@ bool llvm::sortPtrAccesses(ArrayRef<Value *> VL, Type *ElemTy,
 
     // Check if the pointer with the same offset is found.
     int64_t Offset = *Diff;
-    auto Res = Offsets.emplace(Offset, Cnt);
-    if (!Res.second)
+    auto [It, IsInserted] = Offsets.emplace(Offset, Cnt);
+    if (!IsInserted)
       return false;
     // Consecutive order if the inserted element is the last one.
-    IsConsecutive = IsConsecutive && std::next(Res.first) == Offsets.end();
+    IsConsecutive = IsConsecutive && std::next(It) == Offsets.end();
     ++Cnt;
   }
   SortedIndices.clear();
   if (!IsConsecutive) {
     // Fill SortedIndices array only if it is non-consecutive.
-    SortedIndices.resize(VL.size());
-    Cnt = 0;
-    for (const std::pair<int64_t, int> &Pair : Offsets) {
-      SortedIndices[Cnt] = Pair.second;
-      ++Cnt;
-    }
+    SortedIndices.reserve(VL.size());
+    for (const std::pair<int64_t, int> &Pair : Offsets)
+      SortedIndices.push_back(Pair.second);
   }
   return true;
 }
@@ -1866,10 +1846,7 @@ static bool isSafeDependenceDistance(const DataLayout &DL, ScalarEvolution &SE,
   // (If so, then we have proven (**) because |Dist| >= -1*Dist)
   const SCEV *NegDist = SE.getNegativeSCEV(CastedDist);
   Minus = SE.getMinusSCEV(NegDist, CastedProduct);
-  if (SE.isKnownPositive(Minus))
-    return true;
-
-  return false;
+  return SE.isKnownPositive(Minus);
 }
 
 /// Check the dependence for two accesses with the same stride \p Stride.
@@ -2043,7 +2020,7 @@ MemoryDepChecker::Dependence::DepType MemoryDepChecker::isDependent(
   if (isa<SCEVCouldNotCompute>(Dist)) {
     // TODO: Relax requirement that there is a common stride to retry with
     // non-constant distance dependencies.
-    FoundNonConstantDistanceDependence |= !!CommonStride;
+    FoundNonConstantDistanceDependence |= CommonStride.has_value();
     LLVM_DEBUG(dbgs() << "LAA: Dependence because of uncomputable distance.\n");
     return Dependence::Unknown;
   }
@@ -2082,14 +2059,12 @@ MemoryDepChecker::Dependence::DepType MemoryDepChecker::isDependent(
   // Negative distances are not plausible dependencies.
   if (SE.isKnownNonPositive(Dist)) {
     if (SE.isKnownNonNegative(Dist)) {
-      if (HasSameSize) {
-        // Write to the same location with the same size.
+      // Write to the same location with the same size.
+      if (HasSameSize)
         return Dependence::Forward;
-      } else {
-        LLVM_DEBUG(dbgs() << "LAA: possibly zero dependence difference but "
-                             "different type sizes\n");
-        return Dependence::Unknown;
-      }
+      LLVM_DEBUG(dbgs() << "LAA: possibly zero dependence difference but "
+                           "different type sizes\n");
+      return Dependence::Unknown;
     }
 
     bool IsTrueDataDependence = (AIsWrite && !BIsWrite);
@@ -2335,7 +2310,7 @@ bool MemoryDepChecker::areDepsSafe(
           }
         ++OI;
       }
-      AI++;
+      ++AI;
     }
   }
 
@@ -2344,8 +2319,8 @@ bool MemoryDepChecker::areDepsSafe(
 }
 
 SmallVector<Instruction *, 4>
-MemoryDepChecker::getInstructionsForAccess(Value *Ptr, bool isWrite) const {
-  MemAccessInfo Access(Ptr, isWrite);
+MemoryDepChecker::getInstructionsForAccess(Value *Ptr, bool IsWrite) const {
+  MemAccessInfo Access(Ptr, IsWrite);
   auto &IndexVector = Accesses.find(Access)->second;
 
   SmallVector<Instruction *, 4> Insts;
@@ -2656,7 +2631,7 @@ void LoopAccessInfo::analyzeLoop(AAResults *AA, LoopInfo *LI,
                               SymbolicStrides, UncomputablePtr, false);
   if (!CanDoRTIfNeeded) {
     auto *I = dyn_cast_or_null<Instruction>(UncomputablePtr);
-    recordAnalysis("CantIdentifyArrayBounds", I) 
+    recordAnalysis("CantIdentifyArrayBounds", I)
         << "cannot identify array bounds";
     LLVM_DEBUG(dbgs() << "LAA: We can't vectorize because we can't find "
                       << "the array bounds.\n");
@@ -3050,11 +3025,10 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
   if (TTI) {
     TypeSize FixedWidth =
         TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector);
-    if (FixedWidth.isNonZero()) {
-      // Scale the vector width by 2 as rough estimate to also consider
-      // interleaving.
+    // Scale the vector width by 2 as rough estimate to also consider
+    // interleaving.
+    if (FixedWidth.isNonZero())
       MaxTargetVectorWidthInBits = FixedWidth.getFixedValue() * 2;
-    }
 
     TypeSize ScalableWidth =
         TTI->getRegisterBitWidth(TargetTransformInfo::RGK_ScalableVector);
@@ -3064,9 +3038,8 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
   DepChecker =
       std::make_unique<MemoryDepChecker>(*PSE, L, MaxTargetVectorWidthInBits);
   PtrRtChecking = std::make_unique<RuntimePointerChecking>(*DepChecker, SE);
-  if (canAnalyzeLoop()) {
+  if (canAnalyzeLoop())
     analyzeLoop(AA, LI, TLI, DT);
-  }
 }
 
 void LoopAccessInfo::print(raw_ostream &OS, unsigned Depth) const {

``````````

</details>


https://github.com/llvm/llvm-project/pull/92101


More information about the llvm-commits mailing list