[llvm] 7692d10 - [SLP][NFC]Remove dead code + use nlogn lookups instead of n^2
Alexey Bataev via llvm-commits
llvm-commits at lists.llvm.org
Fri Oct 4 15:33:45 PDT 2024
Author: Alexey Bataev
Date: 2024-10-04T15:32:04-07:00
New Revision: 7692d106b480a16861f5ed63378ec1b857a20bc1
URL: https://github.com/llvm/llvm-project/commit/7692d106b480a16861f5ed63378ec1b857a20bc1
DIFF: https://github.com/llvm/llvm-project/commit/7692d106b480a16861f5ed63378ec1b857a20bc1.diff
LOG: [SLP][NFC]Remove dead code + use nlogn lookups instead of n^2
Added:
Modified:
llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index dc9ad5335f8a52..401597af35bdac 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6525,48 +6525,38 @@ static void gatherPossiblyVectorizableLoads(
Type *ScalarTy = getValueType(VL.front());
if (!isValidElementType(ScalarTy))
return;
- const int NumScalars = VL.size();
- int NumParts = 1;
- if (NumScalars > 1) {
- auto *VecTy = getWidenedType(ScalarTy, NumScalars);
- NumParts = TTI.getNumberOfParts(VecTy);
- if (NumParts == 0 || NumParts >= NumScalars ||
- VecTy->getNumElements() % NumParts != 0 ||
- !hasFullVectorsOrPowerOf2(TTI, VecTy->getElementType(),
- VecTy->getNumElements() / NumParts))
- NumParts = 1;
- }
- unsigned VF = PowerOf2Ceil(NumScalars / NumParts);
SmallVector<SmallVector<std::pair<LoadInst *, int>>> ClusteredLoads;
- for (int I : seq<int>(NumParts)) {
- for (Value *V :
- VL.slice(I * VF, std::min<unsigned>(VF, VL.size() - I * VF))) {
- auto *LI = dyn_cast<LoadInst>(V);
- if (!LI)
+ SmallVector<DenseMap<int, LoadInst *>> ClusteredDistToLoad;
+ for (Value *V : VL) {
+ auto *LI = dyn_cast<LoadInst>(V);
+ if (!LI)
+ continue;
+ if (R.isDeleted(LI) || R.isVectorized(LI) || !LI->isSimple())
+ continue;
+ bool IsFound = false;
+ for (auto [Map, Data] : zip(ClusteredDistToLoad, ClusteredLoads)) {
+ if (LI->getParent() != Data.front().first->getParent() ||
+ LI->getType() != Data.front().first->getType())
continue;
- if (R.isDeleted(LI) || R.isVectorized(LI) || !LI->isSimple())
+ std::optional<int> Dist = getPointersDiff(
+ LI->getType(), LI->getPointerOperand(), Data.front().first->getType(),
+ Data.front().first->getPointerOperand(), DL, SE,
+ /*StrictCheck=*/true);
+ if (!Dist)
continue;
- bool IsFound = false;
- for (auto &Data : ClusteredLoads) {
- if (LI->getParent() != Data.front().first->getParent())
- continue;
- std::optional<int> Dist =
- getPointersDiff(LI->getType(), LI->getPointerOperand(),
- Data.front().first->getType(),
- Data.front().first->getPointerOperand(), DL, SE,
- /*StrictCheck=*/true);
- if (Dist && all_of(Data, [&](const std::pair<LoadInst *, int> &Pair) {
- IsFound |= Pair.first == LI;
- return IsFound || Pair.second != *Dist;
- })) {
- if (!IsFound)
- Data.emplace_back(LI, *Dist);
- IsFound = true;
- break;
- }
+ auto It = Map.find(*Dist);
+ if (It != Map.end() && It->second != LI)
+ continue;
+ if (It == Map.end()) {
+ Data.emplace_back(LI, *Dist);
+ Map.try_emplace(*Dist, LI);
}
- if (!IsFound)
- ClusteredLoads.emplace_back().emplace_back(LI, 0);
+ IsFound = true;
+ break;
+ }
+ if (!IsFound) {
+ ClusteredLoads.emplace_back().emplace_back(LI, 0);
+ ClusteredDistToLoad.emplace_back().try_emplace(0, LI);
}
}
auto FindMatchingLoads =
@@ -6591,38 +6581,37 @@ static void gatherPossiblyVectorizableLoads(
Data.front().first->getType(),
Data.front().first->getPointerOperand(), DL, SE,
/*StrictCheck=*/true);
- if (Dist) {
- // Found matching gathered loads - check if all loads are unique or
- // can be effectively vectorized.
- unsigned NumUniques = 0;
- for (auto [Cnt, Pair] : enumerate(Loads)) {
- bool Used = any_of(
- Data, [&, &P = Pair](const std::pair<LoadInst *, int> &PD) {
- return PD.first == P.first;
- });
- if (!Used &&
- none_of(Data,
- [&, &P = Pair](const std::pair<LoadInst *, int> &PD) {
- return *Dist + P.second == PD.second;
- })) {
- ++NumUniques;
- ToAdd.insert(Cnt);
- } else if (Used) {
- Repeated.insert(Cnt);
- }
- }
- if (NumUniques > 0 &&
- (Loads.size() == NumUniques ||
- (Loads.size() - NumUniques >= 2 &&
- Loads.size() - NumUniques >= Loads.size() / 2 &&
- (has_single_bit(Data.size() + NumUniques) ||
- bit_ceil(Data.size()) <
- bit_ceil(Data.size() + NumUniques))))) {
- Offset = *Dist;
- Start = Idx + 1;
- return std::next(GatheredLoads.begin(), Idx);
+ if (!Dist)
+ continue;
+ SmallSet<int, 4> DataDists;
+ SmallPtrSet<LoadInst *, 4> DataLoads;
+ for (std::pair<LoadInst *, int> P : Data) {
+ DataDists.insert(P.second);
+ DataLoads.insert(P.first);
+ }
+ // Found matching gathered loads - check if all loads are unique or
+ // can be effectively vectorized.
+ unsigned NumUniques = 0;
+ for (auto [Cnt, Pair] : enumerate(Loads)) {
+ bool Used = DataLoads.contains(Pair.first);
+ if (!Used && !DataDists.contains(*Dist + Pair.second)) {
+ ++NumUniques;
+ ToAdd.insert(Cnt);
+ } else if (Used) {
+ Repeated.insert(Cnt);
}
}
+ if (NumUniques > 0 &&
+ (Loads.size() == NumUniques ||
+ (Loads.size() - NumUniques >= 2 &&
+ Loads.size() - NumUniques >= Loads.size() / 2 &&
+ (has_single_bit(Data.size() + NumUniques) ||
+ bit_ceil(Data.size()) <
+ bit_ceil(Data.size() + NumUniques))))) {
+ Offset = *Dist;
+ Start = Idx + 1;
+ return std::next(GatheredLoads.begin(), Idx);
+ }
}
ToAdd.clear();
return GatheredLoads.end();
More information about the llvm-commits
mailing list