[llvm] da4e0f7 - [SLP][NFC]Fix PR58476: Fix compile time for reductions, NFC.
Alexey Bataev via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 24 10:14:02 PDT 2022
Author: Alexey Bataev
Date: 2022-10-24T10:13:24-07:00
New Revision: da4e0f7ac58711ddeaf86943ae82c1928e26801f
URL: https://github.com/llvm/llvm-project/commit/da4e0f7ac58711ddeaf86943ae82c1928e26801f
DIFF: https://github.com/llvm/llvm-project/commit/da4e0f7ac58711ddeaf86943ae82c1928e26801f.diff
LOG: [SLP][NFC]Fix PR58476: Fix compile time for reductions, NFC.
Improve O(N^2) to O(N) in some cases, and reduce the number of allocations
by reserving memory.
Also, improve the analysis of load reduction values to avoid analyzing
non-vectorizable cases.
Added:
Modified:
llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index d60c9f65e3acd..583ba283b07ed 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -2108,7 +2108,7 @@ class BoUpSLP {
}
/// Checks if the provided list of reduced values was checked already for
/// vectorization.
- bool areAnalyzedReductionVals(ArrayRef<Value *> VL) {
+ bool areAnalyzedReductionVals(ArrayRef<Value *> VL) const {
return AnalyzedReductionVals.contains(hash_value(VL));
}
/// Adds the list of reduced values to list of already checked values for the
@@ -3539,6 +3539,24 @@ namespace {
enum class LoadsState { Gather, Vectorize, ScatterVectorize };
} // anonymous namespace
+static bool arePointersCompatible(Value *Ptr1, Value *Ptr2,
+ bool CompareOpcodes = true) {
+ if (getUnderlyingObject(Ptr1) != getUnderlyingObject(Ptr2))
+ return false;
+ auto *GEP1 = dyn_cast<GetElementPtrInst>(Ptr1);
+ if (!GEP1)
+ return false;
+ auto *GEP2 = dyn_cast<GetElementPtrInst>(Ptr2);
+ if (!GEP2)
+ return false;
+ return GEP1->getNumOperands() == 2 && GEP2->getNumOperands() == 2 &&
+ ((isConstant(GEP1->getOperand(1)) &&
+ isConstant(GEP2->getOperand(1))) ||
+ !CompareOpcodes ||
+ getSameOpcode({GEP1->getOperand(1), GEP2->getOperand(1)})
+ .getOpcode());
+}
+
/// Checks if the given array of loads can be represented as a vectorized,
/// scatter or just simple gather.
static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
@@ -3575,17 +3593,7 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
// Check the order of pointer operands or that all pointers are the same.
bool IsSorted = sortPtrAccesses(PointerOps, ScalarTy, DL, SE, Order);
if (IsSorted || all_of(PointerOps, [&PointerOps](Value *P) {
- if (getUnderlyingObject(P) != getUnderlyingObject(PointerOps.front()))
- return false;
- auto *GEP = dyn_cast<GetElementPtrInst>(P);
- if (!GEP)
- return false;
- auto *GEP0 = cast<GetElementPtrInst>(PointerOps.front());
- return GEP->getNumOperands() == 2 &&
- ((isConstant(GEP->getOperand(1)) &&
- isConstant(GEP0->getOperand(1))) ||
- getSameOpcode({GEP->getOperand(1), GEP0->getOperand(1)})
- .getOpcode());
+ return arePointersCompatible(P, PointerOps.front());
})) {
if (IsSorted) {
Value *Ptr0;
@@ -4628,11 +4636,11 @@ static std::pair<size_t, size_t> generateKeySubkey(
hash_code SubKey = hash_value(0);
// Sort the loads by the distance between the pointers.
if (auto *LI = dyn_cast<LoadInst>(V)) {
- Key = hash_combine(hash_value(Instruction::Load), Key);
+ Key = hash_combine(LI->getType(), hash_value(Instruction::Load), Key);
if (LI->isSimple())
SubKey = hash_value(LoadsSubkeyGenerator(Key, LI));
else
- SubKey = hash_value(LI);
+ Key = SubKey = hash_value(LI);
} else if (isVectorLikeInstWithConstOps(V)) {
// Sort extracts by the vector operands.
if (isa<ExtractElementInst, UndefValue>(V))
@@ -4660,7 +4668,7 @@ static std::pair<size_t, size_t> generateKeySubkey(
if (isa<CastInst>(I)) {
std::pair<size_t, size_t> OpVals =
generateKeySubkey(I->getOperand(0), TLI, LoadsSubkeyGenerator,
- /*=AllowAlternate*/ true);
+ /*AllowAlternate=*/true);
Key = hash_combine(OpVals.first, Key);
SubKey = hash_combine(OpVals.first, SubKey);
}
@@ -4719,7 +4727,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
&UserTreeIdx,
this](const InstructionsState &S) {
// Check that every instruction appears once in this bundle.
- DenseMap<Value *, unsigned> UniquePositions;
+ DenseMap<Value *, unsigned> UniquePositions(VL.size());
for (Value *V : VL) {
if (isConstant(V)) {
ReuseShuffleIndicies.emplace_back(
@@ -4877,7 +4885,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
BB &&
sortPtrAccesses(VL, UserTreeIdx.UserTE->getMainOp()->getType(), *DL, *SE,
SortedIndices));
- if (allConstant(VL) || isSplat(VL) || !AreAllSameInsts ||
+ if (!AreAllSameInsts || allConstant(VL) || isSplat(VL) ||
(isa<InsertElementInst, ExtractValueInst, ExtractElementInst>(
S.OpValue) &&
!all_of(VL, isVectorLikeInstWithConstOps)) ||
@@ -4951,9 +4959,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
// Special processing for sorted pointers for ScatterVectorize node with
// constant indeces only.
- if (AreAllSameInsts && !(S.getOpcode() && allSameBlock(VL)) &&
- UserTreeIdx.UserTE &&
- UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize) {
+ if (AreAllSameInsts && UserTreeIdx.UserTE &&
+ UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize &&
+ !(S.getOpcode() && allSameBlock(VL))) {
assert(S.OpValue->getType()->isPointerTy() &&
count_if(VL, [](Value *V) { return isa<GetElementPtrInst>(V); }) >=
2 &&
@@ -11104,6 +11112,13 @@ class HorizontalReduction {
return I->getOperand(getFirstOperandIndex(I) + 1);
}
+ static bool isGoodForReduction(ArrayRef<Value *> Data) {
+ int Sz = Data.size();
+ auto *I = dyn_cast<Instruction>(Data.front());
+ return Sz > 1 || isConstant(Data.front()) ||
+ (I && !isa<LoadInst>(I) && isValidForAlternation(I->getOpcode()));
+ }
+
public:
HorizontalReduction() = default;
@@ -11199,6 +11214,9 @@ class HorizontalReduction {
MapVector<size_t, MapVector<size_t, MapVector<Value *, unsigned>>>
PossibleReducedVals;
initReductionOps(Inst);
+ DenseMap<Value *, SmallVector<LoadInst *>> LoadsMap;
+ SmallSet<size_t, 2> LoadKeyUsed;
+ SmallPtrSet<Value *, 4> DoNotReverseVals;
while (!Worklist.empty()) {
Instruction *TreeN = Worklist.pop_back_val();
SmallVector<Value *> Args;
@@ -11220,18 +11238,36 @@ class HorizontalReduction {
size_t Key, Idx;
std::tie(Key, Idx) = generateKeySubkey(
V, &TLI,
- [&PossibleReducedVals, &DL, &SE](size_t Key, LoadInst *LI) {
- auto It = PossibleReducedVals.find(Key);
- if (It != PossibleReducedVals.end()) {
- for (const auto &LoadData : It->second) {
- auto *RLI = cast<LoadInst>(LoadData.second.front().first);
- if (getPointersDiff(RLI->getType(),
- RLI->getPointerOperand(), LI->getType(),
- LI->getPointerOperand(), DL, SE,
- /*StrictCheck=*/true))
- return hash_value(RLI->getPointerOperand());
+ [&](size_t Key, LoadInst *LI) {
+ Value *Ptr = getUnderlyingObject(LI->getPointerOperand());
+ if (LoadKeyUsed.contains(Key)) {
+ auto LIt = LoadsMap.find(Ptr);
+ if (LIt != LoadsMap.end()) {
+ for (LoadInst *RLI: LIt->second) {
+ if (getPointersDiff(
+ RLI->getType(), RLI->getPointerOperand(),
+ LI->getType(), LI->getPointerOperand(), DL, SE,
+ /*StrictCheck=*/true))
+ return hash_value(RLI->getPointerOperand());
+ }
+ for (LoadInst *RLI : LIt->second) {
+ if (arePointersCompatible(RLI->getPointerOperand(),
+ LI->getPointerOperand())) {
+ hash_code SubKey = hash_value(RLI->getPointerOperand());
+ DoNotReverseVals.insert(RLI);
+ return SubKey;
+ }
+ }
+ if (LIt->second.size() > 2) {
+ hash_code SubKey =
+ hash_value(LIt->second.back()->getPointerOperand());
+ DoNotReverseVals.insert(LIt->second.back());
+ return SubKey;
+ }
}
}
+ LoadKeyUsed.insert(Key);
+ LoadsMap.try_emplace(Ptr).first->second.push_back(LI);
return hash_value(LI->getPointerOperand());
},
/*AllowAlternate=*/false);
@@ -11245,17 +11281,35 @@ class HorizontalReduction {
size_t Key, Idx;
std::tie(Key, Idx) = generateKeySubkey(
TreeN, &TLI,
- [&PossibleReducedVals, &DL, &SE](size_t Key, LoadInst *LI) {
- auto It = PossibleReducedVals.find(Key);
- if (It != PossibleReducedVals.end()) {
- for (const auto &LoadData : It->second) {
- auto *RLI = cast<LoadInst>(LoadData.second.front().first);
- if (getPointersDiff(RLI->getType(), RLI->getPointerOperand(),
- LI->getType(), LI->getPointerOperand(),
- DL, SE, /*StrictCheck=*/true))
- return hash_value(RLI->getPointerOperand());
+ [&](size_t Key, LoadInst *LI) {
+ Value *Ptr = getUnderlyingObject(LI->getPointerOperand());
+ if (LoadKeyUsed.contains(Key)) {
+ auto LIt = LoadsMap.find(Ptr);
+ if (LIt != LoadsMap.end()) {
+ for (LoadInst *RLI: LIt->second) {
+ if (getPointersDiff(RLI->getType(),
+ RLI->getPointerOperand(), LI->getType(),
+ LI->getPointerOperand(), DL, SE,
+ /*StrictCheck=*/true))
+ return hash_value(RLI->getPointerOperand());
+ }
+ for (LoadInst *RLI : LIt->second) {
+ if (arePointersCompatible(RLI->getPointerOperand(),
+ LI->getPointerOperand())) {
+ hash_code SubKey = hash_value(RLI->getPointerOperand());
+ DoNotReverseVals.insert(RLI);
+ return SubKey;
+ }
+ }
+ if (LIt->second.size() > 2) {
+ hash_code SubKey = hash_value(LIt->second.back()->getPointerOperand());
+ DoNotReverseVals.insert(LIt->second.back());
+ return SubKey;
+ }
}
}
+ LoadKeyUsed.insert(Key);
+ LoadsMap.try_emplace(Ptr).first->second.push_back(LI);
return hash_value(LI->getPointerOperand());
},
/*AllowAlternate=*/false);
@@ -11281,9 +11335,27 @@ class HorizontalReduction {
stable_sort(PossibleRedValsVect, [](const auto &P1, const auto &P2) {
return P1.size() > P2.size();
});
- ReducedVals.emplace_back();
- for (ArrayRef<Value *> Data : PossibleRedValsVect)
- ReducedVals.back().append(Data.rbegin(), Data.rend());
+ int NewIdx = -1;
+ for (ArrayRef<Value *> Data : PossibleRedValsVect) {
+ if (isGoodForReduction(Data) ||
+ (isa<LoadInst>(Data.front()) && NewIdx >= 0 &&
+ isa<LoadInst>(ReducedVals[NewIdx].front()) &&
+ getUnderlyingObject(
+ cast<LoadInst>(Data.front())->getPointerOperand()) ==
+ getUnderlyingObject(cast<LoadInst>(ReducedVals[NewIdx].front())
+ ->getPointerOperand()))) {
+ if (NewIdx < 0) {
+ NewIdx = ReducedVals.size();
+ ReducedVals.emplace_back();
+ }
+ if (DoNotReverseVals.contains(Data.front()))
+ ReducedVals[NewIdx].append(Data.begin(), Data.end());
+ else
+ ReducedVals[NewIdx].append(Data.rbegin(), Data.rend());
+ } else {
+ ReducedVals.emplace_back().append(Data.rbegin(), Data.rend());
+ }
+ }
}
// Sort the reduced values by number of same/alternate opcode and/or pointer
// operand.
@@ -11301,18 +11373,28 @@ class HorizontalReduction {
// If there are a sufficient number of reduction values, reduce
// to a nearby power-of-2. We can safely generate oversized
// vectors and rely on the backend to split them to legal sizes.
- unsigned NumReducedVals = std::accumulate(
- ReducedVals.begin(), ReducedVals.end(), 0,
- [](int Num, ArrayRef<Value *> Vals) { return Num + Vals.size(); });
- if (NumReducedVals < ReductionLimit)
+ size_t NumReducedVals =
+ std::accumulate(ReducedVals.begin(), ReducedVals.end(), 0,
+ [](size_t Num, ArrayRef<Value *> Vals) {
+ if (!isGoodForReduction(Vals))
+ return Num;
+ return Num + Vals.size();
+ });
+ if (NumReducedVals < ReductionLimit) {
+ for (ReductionOpsType &RdxOps : ReductionOps)
+ for (Value *RdxOp : RdxOps)
+ V.analyzedReductionRoot(cast<Instruction>(RdxOp));
return nullptr;
+ }
IRBuilder<> Builder(cast<Instruction>(ReductionRoot));
// Track the reduced values in case if they are replaced by extractelement
// because of the vectorization.
- DenseMap<Value *, WeakTrackingVH> TrackedVals;
+ DenseMap<Value *, WeakTrackingVH> TrackedVals(
+ ReducedVals.size() * ReducedVals.front().size() + ExtraArgs.size());
BoUpSLP::ExtraValueToDebugLocsMap ExternallyUsedValues;
+ ExternallyUsedValues.reserve(ExtraArgs.size() + 1);
// The same extra argument may be used several times, so log each attempt
// to use it.
for (const std::pair<Instruction *, Value *> &Pair : ExtraArgs) {
@@ -11335,7 +11417,8 @@ class HorizontalReduction {
// The reduction root is used as the insertion point for new instructions,
// so set it as externally used to prevent it from being deleted.
ExternallyUsedValues[ReductionRoot];
- SmallDenseSet<Value *> IgnoreList;
+ SmallDenseSet<Value *> IgnoreList(ReductionOps.size() *
+ ReductionOps.front().size());
for (ReductionOpsType &RdxOps : ReductionOps)
for (Value *RdxOp : RdxOps) {
if (!RdxOp)
@@ -11350,7 +11433,7 @@ class HorizontalReduction {
for (Value *V : Candidates)
TrackedVals.try_emplace(V, V);
- DenseMap<Value *, unsigned> VectorizedVals;
+ DenseMap<Value *, unsigned> VectorizedVals(ReducedVals.size());
Value *VectorizedTree = nullptr;
bool CheckForReusedReductionOps = false;
// Try to vectorize elements based on their type.
@@ -11358,7 +11441,8 @@ class HorizontalReduction {
ArrayRef<Value *> OrigReducedVals = ReducedVals[I];
InstructionsState S = getSameOpcode(OrigReducedVals);
SmallVector<Value *> Candidates;
- DenseMap<Value *, Value *> TrackedToOrig;
+ Candidates.reserve(2 * OrigReducedVals.size());
+ DenseMap<Value *, Value *> TrackedToOrig(2 * OrigReducedVals.size());
for (unsigned Cnt = 0, Sz = OrigReducedVals.size(); Cnt < Sz; ++Cnt) {
Value *RdxVal = TrackedVals.find(OrigReducedVals[Cnt])->second;
// Check if the reduction value was not overriden by the extractelement
@@ -11483,18 +11567,14 @@ class HorizontalReduction {
});
}
// Number of uses of the candidates in the vector of values.
- SmallDenseMap<Value *, unsigned> NumUses;
+ SmallDenseMap<Value *, unsigned> NumUses(Candidates.size());
for (unsigned Cnt = 0; Cnt < Pos; ++Cnt) {
Value *V = Candidates[Cnt];
- if (NumUses.count(V) > 0)
- continue;
- NumUses[V] = std::count(VL.begin(), VL.end(), V);
+ ++NumUses.try_emplace(V, 0).first->getSecond();
}
for (unsigned Cnt = Pos + ReduxWidth; Cnt < NumReducedVals; ++Cnt) {
Value *V = Candidates[Cnt];
- if (NumUses.count(V) > 0)
- continue;
- NumUses[V] = std::count(VL.begin(), VL.end(), V);
+ ++NumUses.try_emplace(V, 0).first->getSecond();
}
// Gather externally used values.
SmallPtrSet<Value *, 4> Visited;
@@ -11545,9 +11625,8 @@ class HorizontalReduction {
}
InstructionCost Cost = TreeCost + ReductionCost;
LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost << " for reduction\n");
- if (!Cost.isValid()) {
+ if (!Cost.isValid())
return nullptr;
- }
if (Cost >= -SLPCostThreshold) {
V.getORE()->emit([&]() {
return OptimizationRemarkMissed(
More information about the llvm-commits
mailing list