[llvm] 8982786 - [SLP][NFC]Make canVectorizeLoads member of BoUpSLP class, NFC.

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 4 07:10:40 PST 2024


Author: Alexey Bataev
Date: 2024-03-04T07:10:27-08:00
New Revision: 89827863a3a65947e1665b436ca8dd1fbdb042fb

URL: https://github.com/llvm/llvm-project/commit/89827863a3a65947e1665b436ca8dd1fbdb042fb
DIFF: https://github.com/llvm/llvm-project/commit/89827863a3a65947e1665b436ca8dd1fbdb042fb.diff

LOG: [SLP][NFC]Make canVectorizeLoads member of BoUpSLP class, NFC.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index daea3bdce68893..e0e3648f718eb8 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -980,6 +980,14 @@ class BoUpSLP {
   class ShuffleInstructionBuilder;
 
 public:
+  /// Tracks the state in which we can represent the loads in the given sequence.
+  enum class LoadsState {
+    Gather,
+    Vectorize,
+    ScatterVectorize,
+    StridedVectorize
+  };
+
   using ValueList = SmallVector<Value *, 8>;
   using InstrList = SmallVector<Instruction *, 16>;
   using ValueSet = SmallPtrSet<Value *, 16>;
@@ -1184,6 +1192,19 @@ class BoUpSLP {
   ///       may not be necessary.
   bool isLoadCombineCandidate() const;
 
+  /// Checks if the given array of loads can be represented as a vectorized
+  /// load, a scatter, or just a simple gather.
+  /// \param VL list of loads.
+  /// \param VL0 main load value.
+  /// \param Order returned order of load instructions.
+  /// \param PointerOps returned list of pointer operands.
+  /// \param TryRecursiveCheck used to check if long masked gather can be
+  /// represented as a series of loads/insert subvector, if profitable.
+  LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
+                               SmallVectorImpl<unsigned> &Order,
+                               SmallVectorImpl<Value *> &PointerOps,
+                               bool TryRecursiveCheck = true) const;
+
   OptimizationRemarkEmitter *getORE() { return ORE; }
 
   /// This structure holds any data we need about the edges being traversed
@@ -3957,11 +3978,6 @@ BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) {
   return std::move(CurrentOrder);
 }
 
-namespace {
-/// Tracks the state we can represent the loads in the given sequence.
-enum class LoadsState { Gather, Vectorize, ScatterVectorize, StridedVectorize };
-} // anonymous namespace
-
 static bool arePointersCompatible(Value *Ptr1, Value *Ptr2,
                                   const TargetLibraryInfo &TLI,
                                   bool CompareOpcodes = true) {
@@ -3998,16 +4014,9 @@ static bool isReverseOrder(ArrayRef<unsigned> Order) {
   });
 }
 
-/// Checks if the given array of loads can be represented as a vectorized,
-/// scatter or just simple gather.
-static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
-                                    const Value *VL0,
-                                    const TargetTransformInfo &TTI,
-                                    const DataLayout &DL, ScalarEvolution &SE,
-                                    LoopInfo &LI, const TargetLibraryInfo &TLI,
-                                    SmallVectorImpl<unsigned> &Order,
-                                    SmallVectorImpl<Value *> &PointerOps,
-                                    bool TryRecursiveCheck = true) {
+BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
+    ArrayRef<Value *> VL, const Value *VL0, SmallVectorImpl<unsigned> &Order,
+    SmallVectorImpl<Value *> &PointerOps, bool TryRecursiveCheck) const {
   // Check that a vectorized load would load the same memory as a scalar
   // load. For example, we don't want to vectorize loads that are smaller
   // than 8-bit. Even though we have a packed struct {<i2, i2, i2, i2>} LLVM
@@ -4016,7 +4025,7 @@ static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
   // unvectorized version.
   Type *ScalarTy = VL0->getType();
 
-  if (DL.getTypeSizeInBits(ScalarTy) != DL.getTypeAllocSizeInBits(ScalarTy))
+  if (DL->getTypeSizeInBits(ScalarTy) != DL->getTypeAllocSizeInBits(ScalarTy))
     return LoadsState::Gather;
 
   // Make sure all loads in the bundle are simple - we can't vectorize
@@ -4036,9 +4045,9 @@ static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
   Order.clear();
   auto *VecTy = FixedVectorType::get(ScalarTy, Sz);
   // Check the order of pointer operands or that all pointers are the same.
-  bool IsSorted = sortPtrAccesses(PointerOps, ScalarTy, DL, SE, Order);
+  bool IsSorted = sortPtrAccesses(PointerOps, ScalarTy, *DL, *SE, Order);
   if (IsSorted || all_of(PointerOps, [&](Value *P) {
-        return arePointersCompatible(P, PointerOps.front(), TLI);
+        return arePointersCompatible(P, PointerOps.front(), *TLI);
       })) {
     if (IsSorted) {
       Value *Ptr0;
@@ -4051,7 +4060,7 @@ static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
         PtrN = PointerOps[Order.back()];
       }
       std::optional<int> Diff =
-          getPointersDiff(ScalarTy, Ptr0, ScalarTy, PtrN, DL, SE);
+          getPointersDiff(ScalarTy, Ptr0, ScalarTy, PtrN, *DL, *SE);
       // Check that the sorted loads are consecutive.
       if (static_cast<unsigned>(*Diff) == Sz - 1)
         return LoadsState::Vectorize;
@@ -4078,7 +4087,7 @@ static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
           Align Alignment =
               cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
                   ->getAlign();
-          if (TTI.isLegalStridedLoadStore(VecTy, Alignment)) {
+          if (TTI->isLegalStridedLoadStore(VecTy, Alignment)) {
             // Iterate through all pointers and check if all distances are
             // unique multiple of Dist.
             SmallSet<int, 4> Dists;
@@ -4087,7 +4096,8 @@ static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
               if (Ptr == PtrN)
                 Dist = *Diff;
               else if (Ptr != Ptr0)
-                Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, DL, SE);
+                Dist =
+                    *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
               // If the strides are not the same or repeated, we can't
               // vectorize.
               if (((Dist / Stride) * Stride) != Dist ||
@@ -4100,11 +4110,11 @@ static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
         }
       }
     }
-    auto CheckForShuffledLoads = [&](Align CommonAlignment) {
-      unsigned Sz = DL.getTypeSizeInBits(ScalarTy);
-      unsigned MinVF = R.getMinVF(Sz);
+    auto CheckForShuffledLoads = [&, &TTI = *TTI](Align CommonAlignment) {
+      unsigned Sz = DL->getTypeSizeInBits(ScalarTy);
+      unsigned MinVF = getMinVF(Sz);
       unsigned MaxVF = std::max<unsigned>(bit_floor(VL.size() / 2), MinVF);
-      MaxVF = std::min(R.getMaximumVF(Sz, Instruction::Load), MaxVF);
+      MaxVF = std::min(getMaximumVF(Sz, Instruction::Load), MaxVF);
       for (unsigned VF = MaxVF; VF >= MinVF; VF /= 2) {
         unsigned VectorizedCnt = 0;
         SmallVector<LoadsState> States;
@@ -4114,8 +4124,8 @@ static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
           SmallVector<unsigned> Order;
           SmallVector<Value *> PointerOps;
           LoadsState LS =
-              canVectorizeLoads(R, Slice, Slice.front(), TTI, DL, SE, LI, TLI,
-                                Order, PointerOps, /*TryRecursiveCheck=*/false);
+              canVectorizeLoads(Slice, Slice.front(), Order, PointerOps,
+                                /*TryRecursiveCheck=*/false);
           // Check that the sorted loads are consecutive.
           if (LS == LoadsState::Gather)
             break;
@@ -4175,7 +4185,7 @@ static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
     // TODO: need to improve analysis of the pointers, if not all of them are
     // GEPs or have > 2 operands, we end up with a gather node, which just
     // increases the cost.
-    Loop *L = LI.getLoopFor(cast<LoadInst>(VL0)->getParent());
+    Loop *L = LI->getLoopFor(cast<LoadInst>(VL0)->getParent());
     bool ProfitableGatherPointers =
         L && Sz > 2 && count_if(PointerOps, [L](Value *V) {
                          return L->isLoopInvariant(V);
@@ -4187,8 +4197,8 @@ static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
                   isa<Constant, Instruction>(GEP->getOperand(1)));
         })) {
       Align CommonAlignment = computeCommonAlignment<LoadInst>(VL);
-      if (TTI.isLegalMaskedGather(VecTy, CommonAlignment) &&
-          !TTI.forceScalarizeMaskedGather(VecTy, CommonAlignment)) {
+      if (TTI->isLegalMaskedGather(VecTy, CommonAlignment) &&
+          !TTI->forceScalarizeMaskedGather(VecTy, CommonAlignment)) {
         // Check if potential masked gather can be represented as series
         // of loads + insertsubvectors.
         if (TryRecursiveCheck && CheckForShuffledLoads(CommonAlignment)) {
@@ -5635,8 +5645,7 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
     // treats loading/storing it as an i8 struct. If we vectorize loads/stores
     // from such a struct, we read/write packed bits disagreeing with the
     // unvectorized version.
-    switch (canVectorizeLoads(*this, VL, VL0, *TTI, *DL, *SE, *LI, *TLI,
-                              CurrentOrder, PointerOps)) {
+    switch (canVectorizeLoads(VL, VL0, CurrentOrder, PointerOps)) {
     case LoadsState::Vectorize:
       return TreeEntry::Vectorize;
     case LoadsState::ScatterVectorize:
@@ -7416,9 +7425,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
               !VectorizedLoads.count(Slice.back()) && allSameBlock(Slice)) {
             SmallVector<Value *> PointerOps;
             OrdersType CurrentOrder;
-            LoadsState LS =
-                canVectorizeLoads(R, Slice, Slice.front(), TTI, *R.DL, *R.SE,
-                                  *R.LI, *R.TLI, CurrentOrder, PointerOps);
+            LoadsState LS = R.canVectorizeLoads(Slice, Slice.front(),
+                                                CurrentOrder, PointerOps);
             switch (LS) {
             case LoadsState::Vectorize:
             case LoadsState::ScatterVectorize:


        


More information about the llvm-commits mailing list