[llvm] [SLP]Allow matching and shuffling of extractelement vector operands with different VF. (PR #97414)

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 3 12:54:11 PDT 2024


https://github.com/alexey-bataev updated https://github.com/llvm/llvm-project/pull/97414

>From 4ad68bf194757814dbac65dcc2935a63e6acc355 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Tue, 2 Jul 2024 13:10:09 +0000
Subject: [PATCH] [𝘀𝗽𝗿] initial version
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.5
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 99 +++++++++----------
 .../extractelement-single-use-many-nodes.ll   | 10 +-
 .../X86/extractlements-gathered-first-node.ll |  1 -
 3 files changed, 50 insertions(+), 60 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index e233de89a33f1..3ad40c329211e 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -510,11 +510,17 @@ isFixedVectorShuffle(ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) {
   const auto *It = find_if(VL, IsaPred<ExtractElementInst>);
   if (It == VL.end())
     return std::nullopt;
-  auto *EI0 = cast<ExtractElementInst>(*It);
-  if (isa<ScalableVectorType>(EI0->getVectorOperandType()))
-    return std::nullopt;
   unsigned Size =
-      cast<FixedVectorType>(EI0->getVectorOperandType())->getNumElements();
+      std::accumulate(VL.begin(), VL.end(), 0u, [](unsigned S, Value *V) {
+        auto *EI = dyn_cast<ExtractElementInst>(V);
+        if (!EI)
+          return S;
+        auto *VTy = dyn_cast<FixedVectorType>(EI->getVectorOperandType());
+        if (!VTy)
+          return S;
+        return std::max(S, VTy->getNumElements());
+      });
+
   Value *Vec1 = nullptr;
   Value *Vec2 = nullptr;
   bool HasNonUndefVec = any_of(VL, [](Value *V) {
@@ -544,8 +550,6 @@ isFixedVectorShuffle(ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) {
     if (isa<UndefValue>(Vec)) {
       Mask[I] = I;
     } else {
-      if (cast<FixedVectorType>(Vec->getType())->getNumElements() != Size)
-        return std::nullopt;
       if (isa<UndefValue>(EI->getIndexOperand()))
         continue;
       auto *Idx = dyn_cast<ConstantInt>(EI->getIndexOperand());
@@ -10639,36 +10643,20 @@ BoUpSLP::tryToGatherSingleRegisterExtractElements(
     VectorOpToIdx[EI->getVectorOperand()].push_back(I);
   }
   // Sort the vector operands by the maximum number of uses in extractelements.
-  MapVector<unsigned, SmallVector<Value *>> VFToVector;
-  for (const auto &Data : VectorOpToIdx)
-    VFToVector[cast<FixedVectorType>(Data.first->getType())->getNumElements()]
-        .push_back(Data.first);
-  for (auto &Data : VFToVector) {
-    stable_sort(Data.second, [&VectorOpToIdx](Value *V1, Value *V2) {
-      return VectorOpToIdx.find(V1)->second.size() >
-             VectorOpToIdx.find(V2)->second.size();
-    });
-  }
-  // Find the best pair of the vectors with the same number of elements or a
-  // single vector.
+  SmallVector<std::pair<Value *, SmallVector<int>>> Vectors =
+      VectorOpToIdx.takeVector();
+  stable_sort(Vectors, [](const auto &P1, const auto &P2) {
+    return P1.second.size() > P2.second.size();
+  });
+  // Find the best pair of the vectors or a single vector.
   const int UndefSz = UndefVectorExtracts.size();
   unsigned SingleMax = 0;
-  Value *SingleVec = nullptr;
   unsigned PairMax = 0;
-  std::pair<Value *, Value *> PairVec(nullptr, nullptr);
-  for (auto &Data : VFToVector) {
-    Value *V1 = Data.second.front();
-    if (SingleMax < VectorOpToIdx[V1].size() + UndefSz) {
-      SingleMax = VectorOpToIdx[V1].size() + UndefSz;
-      SingleVec = V1;
-    }
-    Value *V2 = nullptr;
-    if (Data.second.size() > 1)
-      V2 = *std::next(Data.second.begin());
-    if (V2 && PairMax < VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() +
-                            UndefSz) {
-      PairMax = VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() + UndefSz;
-      PairVec = std::make_pair(V1, V2);
+  if (!Vectors.empty()) {
+    SingleMax = Vectors.front().second.size() + UndefSz;
+    if (Vectors.size() > 1) {
+      auto *ItNext = std::next(Vectors.begin());
+      PairMax = SingleMax + ItNext->second.size();
     }
   }
   if (SingleMax == 0 && PairMax == 0 && UndefSz == 0)
@@ -10679,11 +10667,11 @@ BoUpSLP::tryToGatherSingleRegisterExtractElements(
   SmallVector<Value *> GatheredExtracts(
       VL.size(), PoisonValue::get(VL.front()->getType()));
   if (SingleMax >= PairMax && SingleMax) {
-    for (int Idx : VectorOpToIdx[SingleVec])
+    for (int Idx : Vectors.front().second)
       std::swap(GatheredExtracts[Idx], VL[Idx]);
-  } else {
-    for (Value *V : {PairVec.first, PairVec.second})
-      for (int Idx : VectorOpToIdx[V])
+  } else if (!Vectors.empty()) {
+    for (unsigned Idx : {0, 1})
+      for (int Idx : Vectors[Idx].second)
         std::swap(GatheredExtracts[Idx], VL[Idx]);
   }
   // Add extracts from undefs too.
@@ -11752,25 +11740,29 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
       MutableArrayRef<int> SubMask = Mask.slice(Part * SliceSize, Limit);
       constexpr int MaxBases = 2;
       SmallVector<Value *, MaxBases> Bases(MaxBases);
-#ifndef NDEBUG
-      int PrevSize = 0;
-#endif // NDEBUG
-      for (const auto [I, V]: enumerate(VL)) {
-        if (SubMask[I] == PoisonMaskElem)
+      auto VLMask = zip(VL, SubMask);
+      const unsigned VF = std::accumulate(
+          VLMask.begin(), VLMask.end(), 0U, [&](unsigned S, const auto &D) {
+            if (std::get<1>(D) == PoisonMaskElem)
+              return S;
+            Value *VecOp =
+                cast<ExtractElementInst>(std::get<0>(D))->getVectorOperand();
+            if (const TreeEntry *TE = R.getTreeEntry(VecOp))
+              VecOp = TE->VectorizedValue;
+            assert(VecOp && "Expected vectorized value.");
+            const unsigned Size =
+                cast<FixedVectorType>(VecOp->getType())->getNumElements();
+            return std::max(S, Size);
+          });
+      for (const auto [V, I] : VLMask) {
+        if (I == PoisonMaskElem)
           continue;
         Value *VecOp = cast<ExtractElementInst>(V)->getVectorOperand();
         if (const TreeEntry *TE = R.getTreeEntry(VecOp))
           VecOp = TE->VectorizedValue;
         assert(VecOp && "Expected vectorized value.");
-        const int Size =
-            cast<FixedVectorType>(VecOp->getType())->getNumElements();
-#ifndef NDEBUG
-        assert((PrevSize == Size || PrevSize == 0) &&
-               "Expected vectors of the same size.");
-        PrevSize = Size;
-#endif // NDEBUG
         VecOp = castToScalarTyElem(VecOp);
-        Bases[SubMask[I] < Size ? 0 : 1] = VecOp;
+        Bases[I / VF] = VecOp;
       }
       if (!Bases.front())
         continue;
@@ -11796,16 +11788,17 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
                "Expected first part or all previous parts masked.");
         copy(SubMask, std::next(VecMask.begin(), Part * SliceSize));
       } else {
-        unsigned VF = cast<FixedVectorType>(Vec->getType())->getNumElements();
+        unsigned NewVF =
+            cast<FixedVectorType>(Vec->getType())->getNumElements();
         if (Vec->getType() != SubVec->getType()) {
           unsigned SubVecVF =
               cast<FixedVectorType>(SubVec->getType())->getNumElements();
-          VF = std::max(VF, SubVecVF);
+          NewVF = std::max(NewVF, SubVecVF);
         }
         // Adjust SubMask.
         for (int &Idx : SubMask)
           if (Idx != PoisonMaskElem)
-            Idx += VF;
+            Idx += NewVF;
         copy(SubMask, std::next(VecMask.begin(), Part * SliceSize));
         Vec = createShuffle(Vec, SubVec, VecMask);
         TransformToIdentity(VecMask);
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/extractelement-single-use-many-nodes.ll b/llvm/test/Transforms/SLPVectorizer/X86/extractelement-single-use-many-nodes.ll
index 24b95c4e6ff2f..4e6ed4bce6588 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/extractelement-single-use-many-nodes.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/extractelement-single-use-many-nodes.ll
@@ -9,11 +9,10 @@ define void @foo(double %i) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = fsub <4 x double> zeroinitializer, [[TMP0]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <2 x double> poison, double [[I]], i32 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = fsub <2 x double> zeroinitializer, [[TMP3]]
-; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x double> [[TMP4]], i32 1
-; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <4 x double> [[TMP1]], <4 x double> poison, <8 x i32> <i32 0, i32 poison, i32 poison, i32 1, i32 poison, i32 0, i32 poison, i32 1>
-; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <8 x double> [[TMP9]], <8 x double> <double poison, double 0.000000e+00, double poison, double poison, double 0.000000e+00, double poison, double poison, double poison>, <8 x i32> <i32 0, i32 9, i32 poison, i32 3, i32 12, i32 5, i32 poison, i32 7>
-; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <8 x double> [[TMP6]], double [[TMP5]], i32 2
-; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <8 x double> [[TMP7]], <8 x double> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 2, i32 7>
+; CHECK-NEXT:    [[TMP22:%.*]] = shufflevector <2 x double> [[TMP4]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <4 x double> [[TMP1]], <4 x double> [[TMP22]], <4 x i32> <i32 0, i32 0, i32 5, i32 1>
+; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <4 x double> [[TMP5]], <4 x double> [[TMP5]], <8 x i32> <i32 0, i32 poison, i32 2, i32 3, i32 poison, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <8 x double> [[TMP6]], <8 x double> <double poison, double 0.000000e+00, double poison, double poison, double 0.000000e+00, double poison, double poison, double poison>, <8 x i32> <i32 0, i32 9, i32 2, i32 3, i32 12, i32 5, i32 6, i32 7>
 ; CHECK-NEXT:    [[TMP12:%.*]] = fmul <8 x double> <double 0.000000e+00, double poison, double 0.000000e+00, double 0.000000e+00, double poison, double 0.000000e+00, double 0.000000e+00, double 0.000000e+00>, [[TMP8]]
 ; CHECK-NEXT:    [[TMP13:%.*]] = fadd <8 x double> zeroinitializer, [[TMP12]]
 ; CHECK-NEXT:    [[TMP14:%.*]] = fadd <8 x double> [[TMP13]], zeroinitializer
@@ -27,7 +26,6 @@ define void @foo(double %i) {
 ; CHECK-NEXT:    [[TMP20:%.*]] = extractelement <2 x double> [[TMP18]], i32 1
 ; CHECK-NEXT:    [[I118:%.*]] = fadd double [[TMP19]], [[TMP20]]
 ; CHECK-NEXT:    [[TMP21:%.*]] = fmul <4 x double> zeroinitializer, [[TMP1]]
-; CHECK-NEXT:    [[TMP22:%.*]] = shufflevector <2 x double> [[TMP4]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP23:%.*]] = shufflevector <4 x double> <double 0.000000e+00, double 0.000000e+00, double 0.000000e+00, double poison>, <4 x double> [[TMP22]], <4 x i32> <i32 0, i32 1, i32 2, i32 5>
 ; CHECK-NEXT:    [[TMP24:%.*]] = fadd <4 x double> [[TMP21]], [[TMP23]]
 ; CHECK-NEXT:    [[TMP25:%.*]] = fadd <4 x double> [[TMP24]], zeroinitializer
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/extractlements-gathered-first-node.ll b/llvm/test/Transforms/SLPVectorizer/X86/extractlements-gathered-first-node.ll
index 57fa83b1ccdd6..d5f2cf7fc28c4 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/extractlements-gathered-first-node.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/extractlements-gathered-first-node.ll
@@ -6,7 +6,6 @@ define void @test() {
 ; CHECK-NEXT:  bb:
 ; CHECK-NEXT:    [[TMP0:%.*]] = extractelement <4 x i32> zeroinitializer, i32 0
 ; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x i32> zeroinitializer, i32 0
-; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <2 x i32> <i32 0, i32 undef>, i32 [[TMP1]], i32 1
 ; CHECK-NEXT:    [[ICMP:%.*]] = icmp ult i32 [[TMP0]], [[TMP1]]
 ; CHECK-NEXT:    ret void
 ;



More information about the llvm-commits mailing list