[llvm] 7692d10 - [SLP][NFC]Remove dead code + use nlogn lookups instead of n^2

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 4 15:33:45 PDT 2024


Author: Alexey Bataev
Date: 2024-10-04T15:32:04-07:00
New Revision: 7692d106b480a16861f5ed63378ec1b857a20bc1

URL: https://github.com/llvm/llvm-project/commit/7692d106b480a16861f5ed63378ec1b857a20bc1
DIFF: https://github.com/llvm/llvm-project/commit/7692d106b480a16861f5ed63378ec1b857a20bc1.diff

LOG: [SLP][NFC]Remove dead code + use nlogn lookups instead of n^2

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index dc9ad5335f8a52..401597af35bdac 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6525,48 +6525,38 @@ static void gatherPossiblyVectorizableLoads(
   Type *ScalarTy = getValueType(VL.front());
   if (!isValidElementType(ScalarTy))
     return;
-  const int NumScalars = VL.size();
-  int NumParts = 1;
-  if (NumScalars > 1) {
-    auto *VecTy = getWidenedType(ScalarTy, NumScalars);
-    NumParts = TTI.getNumberOfParts(VecTy);
-    if (NumParts == 0 || NumParts >= NumScalars ||
-        VecTy->getNumElements() % NumParts != 0 ||
-        !hasFullVectorsOrPowerOf2(TTI, VecTy->getElementType(),
-                                  VecTy->getNumElements() / NumParts))
-      NumParts = 1;
-  }
-  unsigned VF = PowerOf2Ceil(NumScalars / NumParts);
   SmallVector<SmallVector<std::pair<LoadInst *, int>>> ClusteredLoads;
-  for (int I : seq<int>(NumParts)) {
-    for (Value *V :
-         VL.slice(I * VF, std::min<unsigned>(VF, VL.size() - I * VF))) {
-      auto *LI = dyn_cast<LoadInst>(V);
-      if (!LI)
+  SmallVector<DenseMap<int, LoadInst *>> ClusteredDistToLoad;
+  for (Value *V : VL) {
+    auto *LI = dyn_cast<LoadInst>(V);
+    if (!LI)
+      continue;
+    if (R.isDeleted(LI) || R.isVectorized(LI) || !LI->isSimple())
+      continue;
+    bool IsFound = false;
+    for (auto [Map, Data] : zip(ClusteredDistToLoad, ClusteredLoads)) {
+      if (LI->getParent() != Data.front().first->getParent() ||
+          LI->getType() != Data.front().first->getType())
         continue;
-      if (R.isDeleted(LI) || R.isVectorized(LI) || !LI->isSimple())
+      std::optional<int> Dist = getPointersDiff(
+          LI->getType(), LI->getPointerOperand(), Data.front().first->getType(),
+          Data.front().first->getPointerOperand(), DL, SE,
+          /*StrictCheck=*/true);
+      if (!Dist)
         continue;
-      bool IsFound = false;
-      for (auto &Data : ClusteredLoads) {
-        if (LI->getParent() != Data.front().first->getParent())
-          continue;
-        std::optional<int> Dist =
-            getPointersDiff(LI->getType(), LI->getPointerOperand(),
-                            Data.front().first->getType(),
-                            Data.front().first->getPointerOperand(), DL, SE,
-                            /*StrictCheck=*/true);
-        if (Dist && all_of(Data, [&](const std::pair<LoadInst *, int> &Pair) {
-              IsFound |= Pair.first == LI;
-              return IsFound || Pair.second != *Dist;
-            })) {
-          if (!IsFound)
-            Data.emplace_back(LI, *Dist);
-          IsFound = true;
-          break;
-        }
+      auto It = Map.find(*Dist);
+      if (It != Map.end() && It->second != LI)
+        continue;
+      if (It == Map.end()) {
+        Data.emplace_back(LI, *Dist);
+        Map.try_emplace(*Dist, LI);
       }
-      if (!IsFound)
-        ClusteredLoads.emplace_back().emplace_back(LI, 0);
+      IsFound = true;
+      break;
+    }
+    if (!IsFound) {
+      ClusteredLoads.emplace_back().emplace_back(LI, 0);
+      ClusteredDistToLoad.emplace_back().try_emplace(0, LI);
     }
   }
   auto FindMatchingLoads =
@@ -6591,38 +6581,37 @@ static void gatherPossiblyVectorizableLoads(
                               Data.front().first->getType(),
                               Data.front().first->getPointerOperand(), DL, SE,
                               /*StrictCheck=*/true);
-          if (Dist) {
-            // Found matching gathered loads - check if all loads are unique or
-            // can be effectively vectorized.
-            unsigned NumUniques = 0;
-            for (auto [Cnt, Pair] : enumerate(Loads)) {
-              bool Used = any_of(
-                  Data, [&, &P = Pair](const std::pair<LoadInst *, int> &PD) {
-                    return PD.first == P.first;
-                  });
-              if (!Used &&
-                  none_of(Data,
-                          [&, &P = Pair](const std::pair<LoadInst *, int> &PD) {
-                            return *Dist + P.second == PD.second;
-                          })) {
-                ++NumUniques;
-                ToAdd.insert(Cnt);
-              } else if (Used) {
-                Repeated.insert(Cnt);
-              }
-            }
-            if (NumUniques > 0 &&
-                (Loads.size() == NumUniques ||
-                 (Loads.size() - NumUniques >= 2 &&
-                  Loads.size() - NumUniques >= Loads.size() / 2 &&
-                  (has_single_bit(Data.size() + NumUniques) ||
-                   bit_ceil(Data.size()) <
-                       bit_ceil(Data.size() + NumUniques))))) {
-              Offset = *Dist;
-              Start = Idx + 1;
-              return std::next(GatheredLoads.begin(), Idx);
+          if (!Dist)
+            continue;
+          SmallSet<int, 4> DataDists;
+          SmallPtrSet<LoadInst *, 4> DataLoads;
+          for (std::pair<LoadInst *, int> P : Data) {
+            DataDists.insert(P.second);
+            DataLoads.insert(P.first);
+          }
+          // Found matching gathered loads - check if all loads are unique or
+          // can be effectively vectorized.
+          unsigned NumUniques = 0;
+          for (auto [Cnt, Pair] : enumerate(Loads)) {
+            bool Used = DataLoads.contains(Pair.first);
+            if (!Used && !DataDists.contains(*Dist + Pair.second)) {
+              ++NumUniques;
+              ToAdd.insert(Cnt);
+            } else if (Used) {
+              Repeated.insert(Cnt);
             }
           }
+          if (NumUniques > 0 &&
+              (Loads.size() == NumUniques ||
+               (Loads.size() - NumUniques >= 2 &&
+                Loads.size() - NumUniques >= Loads.size() / 2 &&
+                (has_single_bit(Data.size() + NumUniques) ||
+                 bit_ceil(Data.size()) <
+                     bit_ceil(Data.size() + NumUniques))))) {
+            Offset = *Dist;
+            Start = Idx + 1;
+            return std::next(GatheredLoads.begin(), Idx);
+          }
         }
         ToAdd.clear();
         return GatheredLoads.end();


        


More information about the llvm-commits mailing list