[llvm] [SLP]Fix the cost of the adjusted extracts in per-register analysis. (PR #96808)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 28 02:45:59 PDT 2024
================
@@ -8304,35 +8304,57 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
});
// FIXME: this must be moved to TTI for better estimation.
unsigned EltsPerVector = getPartNumElems(VL.size(), NumParts);
- auto CheckPerRegistersShuffle =
- [&](MutableArrayRef<int> Mask,
- SmallVector<int> Indices) -> std::optional<TTI::ShuffleKind> {
+ auto CheckPerRegistersShuffle = [&](MutableArrayRef<int> Mask,
+ SmallVectorImpl<unsigned> &Indices)
+ -> std::optional<TTI::ShuffleKind> {
if (NumElts <= EltsPerVector)
return std::nullopt;
+ int OffsetReg0 =
+ alignDown(std::accumulate(Mask.begin(), Mask.end(), INT_MAX,
+ [](int S, int I) {
+ if (I == PoisonMaskElem)
+ return S;
+ return std::min(S, I);
+ }),
+ EltsPerVector);
+ int OffsetReg1 = OffsetReg0;
DenseSet<int> RegIndices;
// Check that if trying to permute same single/2 input vectors.
TTI::ShuffleKind ShuffleKind = TTI::SK_PermuteSingleSrc;
int FirstRegId = -1;
- Indices.assign(1, -1);
- for (int &I : Mask) {
+ Indices.assign(1, OffsetReg0);
+ for (auto [Pos, I] : enumerate(Mask)) {
if (I == PoisonMaskElem)
continue;
- int RegId = (I / NumElts) * NumParts + (I % NumElts) / EltsPerVector;
+ int Idx = I - OffsetReg0;
+ int RegId =
+ (Idx / NumElts) * NumParts + (Idx % NumElts) / EltsPerVector;
if (FirstRegId < 0)
FirstRegId = RegId;
RegIndices.insert(RegId);
if (RegIndices.size() > 2)
return std::nullopt;
if (RegIndices.size() == 2) {
ShuffleKind = TTI::SK_PermuteTwoSrc;
- if (Indices.size() == 1)
- Indices.push_back(-1);
+ if (Indices.size() == 1) {
+ OffsetReg1 = alignDown(
+ std::accumulate(
+ std::next(Mask.begin(), Pos), Mask.end(), INT_MAX,
+ [&](int S, int I) {
+ if (I == PoisonMaskElem)
+ return S;
+ int RegId = ((I - OffsetReg0) / NumElts) * NumParts +
+ ((I - OffsetReg0) % NumElts) / EltsPerVector;
----------------
RKSimon wrote:
What is this doing? Don't we just need the PartIdx?
https://github.com/llvm/llvm-project/pull/96808
More information about the llvm-commits mailing list