[llvm] [SLP]Add cost estimation for gather node reshuffling (PR #115201)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 15 03:03:35 PST 2024
================
@@ -13077,26 +13093,149 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
// Pair.first is the offset to the vector, while Pair.second is the index of
// scalar in the list.
for (const std::pair<unsigned, int> &Pair : EntryLanes) {
- unsigned Idx = Part * VL.size() + Pair.second;
+ int Idx = Part * VL.size() + Pair.second;
Mask[Idx] =
Pair.first * VF +
(ForOrder ? std::distance(
Entries[Pair.first]->Scalars.begin(),
find(Entries[Pair.first]->Scalars, VL[Pair.second]))
: Entries[Pair.first]->findLaneForValue(VL[Pair.second]));
- IsIdentity &= Mask[Idx] == Pair.second;
+ IsIdentity &= Mask[Idx] % VL.size() == Idx % VL.size();
}
- switch (Entries.size()) {
- case 1:
- if (IsIdentity || EntryLanes.size() > 1 || VL.size() <= 2)
- return TargetTransformInfo::SK_PermuteSingleSrc;
- break;
- case 2:
- if (EntryLanes.size() > 2 || VL.size() <= 2)
- return TargetTransformInfo::SK_PermuteTwoSrc;
- break;
- default:
- break;
+ if (ForOrder || IsIdentity || Entries.empty()) {
+ switch (Entries.size()) {
+ case 1:
+ if (IsIdentity || EntryLanes.size() > 1 || VL.size() <= 2)
+ return TargetTransformInfo::SK_PermuteSingleSrc;
+ break;
+ case 2:
+ if (EntryLanes.size() > 2 || VL.size() <= 2)
+ return TargetTransformInfo::SK_PermuteTwoSrc;
+ break;
+ default:
+ break;
+ }
+ } else if (!isa<VectorType>(VL.front()->getType()) &&
+ (EntryLanes.size() > Entries.size() || VL.size() <= 2)) {
+ // Do the cost estimation if shuffle beneficial than buildvector.
+ SmallVector<int> SubMask(std::next(Mask.begin(), Part * VL.size()),
+ std::next(Mask.begin(), (Part + 1) * VL.size()));
+ int MinElement = SubMask.front(), MaxElement = SubMask.front();
+ for (int Idx : SubMask) {
+ if (Idx == PoisonMaskElem)
+ continue;
+ if (MinElement == PoisonMaskElem || MinElement % VF > Idx % VF)
+ MinElement = Idx;
+ if (MaxElement == PoisonMaskElem || MaxElement % VF < Idx % VF)
+ MaxElement = Idx;
+ }
+ assert(MaxElement >= 0 && MinElement >= 0 &&
+ "Expected at least single element.");
----------------
RKSimon wrote:
assert MaxElement >= MinElement?
https://github.com/llvm/llvm-project/pull/115201
More information about the llvm-commits mailing list