[llvm] f020bf1 - [SLP]Initial support for non-power-of-2 (but whole reg) vectorization for stores

via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 9 12:22:50 PDT 2024


Author: Alexey Bataev
Date: 2024-10-09T15:22:44-04:00
New Revision: f020bf15263f71e76e8b64fd0c333fff9744beae

URL: https://github.com/llvm/llvm-project/commit/f020bf15263f71e76e8b64fd0c333fff9744beae
DIFF: https://github.com/llvm/llvm-project/commit/f020bf15263f71e76e8b64fd0c333fff9744beae.diff

LOG: [SLP]Initial support for non-power-of-2 (but whole reg) vectorization for stores

Allows non-power-of-2 vectorization for stores, but still requires that
the number of vectorized elements forms full vector registers.

Reviewers: RKSimon

Reviewed By: RKSimon

Pull Request: https://github.com/llvm/llvm-project/pull/111194

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Transforms/SLPVectorizer/X86/long-full-reg-stores.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 553df1c08f3aeb..94de520a2715ff 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -278,6 +278,22 @@ static unsigned getFullVectorNumberOfElements(const TargetTransformInfo &TTI,
   return bit_ceil(divideCeil(Sz, NumParts)) * NumParts;
 }
 
+/// Returns the largest multiple of the per-register element count that does
+/// not exceed \p Sz, so a vector of that many \p Ty elements legalizes into
+/// whole registers per \p TTI. NOTE(review): yields 0 if RegVF exceeds \p Sz.
+static unsigned
+getFloorFullVectorNumberOfElements(const TargetTransformInfo &TTI, Type *Ty,
+                                   unsigned Sz) {
+  if (!isValidElementType(Ty))
+    return bit_floor(Sz);
+  // Elements per part, rounded up to a power of 2, give the full-register VF.
+  unsigned NumParts = TTI.getNumberOfParts(getWidenedType(Ty, Sz));
+  if (NumParts == 0 || NumParts >= Sz)
+    return bit_floor(Sz);
+  unsigned RegVF = bit_ceil(divideCeil(Sz, NumParts));
+  return (Sz / RegVF) * RegVF;
+}
+
 static void transformScalarShuffleIndiciesToVector(unsigned VecTyNumElements,
                                                    SmallVectorImpl<int> &Mask) {
   // The ShuffleBuilder implementation use shufflevector to splat an "element".
@@ -7716,7 +7732,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
     }
     size_t NumUniqueScalarValues = UniqueValues.size();
     bool IsFullVectors = hasFullVectorsOrPowerOf2(
-        *TTI, UniqueValues.front()->getType(), NumUniqueScalarValues);
+        *TTI, getValueType(UniqueValues.front()), NumUniqueScalarValues);
     if (NumUniqueScalarValues == VL.size() &&
         (VectorizeNonPowerOf2 || IsFullVectors)) {
       ReuseShuffleIndices.clear();
@@ -17466,7 +17482,11 @@ SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R,
   const unsigned Sz = R.getVectorElementSize(Chain[0]);
   unsigned VF = Chain.size();
 
-  if (!has_single_bit(Sz) || !has_single_bit(VF) || VF < 2 || VF < MinVF) {
+  if (!has_single_bit(Sz) ||
+      !hasFullVectorsOrPowerOf2(
+          *TTI, cast<StoreInst>(Chain.front())->getValueOperand()->getType(),
+          VF) ||
+      VF < 2 || VF < MinVF) {
     // Check if vectorizing with a non-power-of-2 VF should be considered. At
     // the moment, only consider cases where VF + 1 is a power-of-2, i.e. almost
     // all vector lanes are used.
@@ -17484,10 +17504,12 @@ SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R,
   InstructionsState S = getSameOpcode(ValOps.getArrayRef(), *TLI);
   if (all_of(ValOps, IsaPred<Instruction>) && ValOps.size() > 1) {
     DenseSet<Value *> Stores(Chain.begin(), Chain.end());
-    bool IsPowerOf2 =
-        has_single_bit(ValOps.size()) ||
+    bool IsAllowedSize =
+        hasFullVectorsOrPowerOf2(*TTI, ValOps.front()->getType(),
+                                 ValOps.size()) ||
         (VectorizeNonPowerOf2 && has_single_bit(ValOps.size() + 1));
-    if ((!IsPowerOf2 && S.getOpcode() && S.getOpcode() != Instruction::Load &&
+    if ((!IsAllowedSize && S.getOpcode() &&
+         S.getOpcode() != Instruction::Load &&
          (!S.MainOp->isSafeToRemove() ||
           any_of(ValOps.getArrayRef(),
                  [&](Value *V) {
@@ -17498,7 +17520,7 @@ SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R,
                            }));
                  }))) ||
         (ValOps.size() > Chain.size() / 2 && !S.getOpcode())) {
-      Size = (!IsPowerOf2 && S.getOpcode()) ? 1 : 2;
+      Size = (!IsAllowedSize && S.getOpcode()) ? 1 : 2;
       return false;
     }
   }
@@ -17626,15 +17648,11 @@ bool SLPVectorizerPass::vectorizeStores(
 
       unsigned MaxVF =
           std::min(R.getMaximumVF(EltSize, Instruction::Store), MaxElts);
-      unsigned MaxRegVF = MaxVF;
       auto *Store = cast<StoreInst>(Operands[0]);
       Type *StoreTy = Store->getValueOperand()->getType();
       Type *ValueTy = StoreTy;
       if (auto *Trunc = dyn_cast<TruncInst>(Store->getValueOperand()))
         ValueTy = Trunc->getSrcTy();
-      if (ValueTy == StoreTy &&
-          R.getVectorElementSize(Store->getValueOperand()) <= EltSize)
-        MaxVF = std::min<unsigned>(MaxVF, bit_floor(Operands.size()));
       unsigned MinVF = std::max<unsigned>(
           2, PowerOf2Ceil(TTI->getStoreMinimumVF(
                  R.getMinVF(DL->getTypeStoreSizeInBits(StoreTy)), StoreTy,
@@ -17652,10 +17670,21 @@ bool SLPVectorizerPass::vectorizeStores(
         // First try vectorizing with a non-power-of-2 VF. At the moment, only
         // consider cases where VF + 1 is a power-of-2, i.e. almost all vector
         // lanes are used.
-        unsigned CandVF =
-            std::clamp<unsigned>(Operands.size(), MaxVF, MaxRegVF);
-        if (has_single_bit(CandVF + 1))
+        unsigned CandVF = std::clamp<unsigned>(Operands.size(), MinVF, MaxVF);
+        if (has_single_bit(CandVF + 1)) {
           NonPowerOf2VF = CandVF;
+          assert(NonPowerOf2VF != MaxVF &&
+                 "Non-power-of-2 VF should not be equal to MaxVF");
+        }
+      }
+
+      unsigned MaxRegVF = MaxVF;
+      MaxVF = std::min<unsigned>(MaxVF, bit_floor(Operands.size()));
+      if (MaxVF < MinVF) {
+        LLVM_DEBUG(dbgs() << "SLP: Vectorization infeasible as MaxVF (" << MaxVF
+                          << ") < "
+                          << "MinVF (" << MinVF << ")\n");
+        continue;
       }
 
       unsigned Sz = 1 + Log2_32(MaxVF) - Log2_32(MinVF);
@@ -17810,7 +17839,7 @@ bool SLPVectorizerPass::vectorizeStores(
                         std::bind(IsNotVectorized, Size >= MaxRegVF,
                                   std::placeholders::_1)));
           }
-          if (!AnyProfitableGraph && Size >= MaxRegVF)
+          if (!AnyProfitableGraph && Size >= MaxRegVF && has_single_bit(Size))
             break;
         }
         // All values vectorized - exit.
@@ -17823,7 +17852,7 @@ bool SLPVectorizerPass::vectorizeStores(
             (Repeat > 1 && (RepeatChanged || !AnyProfitableGraph)))
           break;
         constexpr unsigned StoresLimit = 64;
-        const unsigned MaxTotalNum = bit_floor(std::min<unsigned>(
+        const unsigned MaxTotalNum = std::min<unsigned>(
             Operands.size(),
             static_cast<unsigned>(
                 End -
@@ -17831,8 +17860,13 @@ bool SLPVectorizerPass::vectorizeStores(
                     RangeSizes.begin(),
                     find_if(RangeSizes, std::bind(IsNotVectorized, true,
                                                   std::placeholders::_1))) +
-                1)));
-        unsigned VF = PowerOf2Ceil(CandidateVFs.front()) * 2;
+                1));
+        unsigned VF = bit_ceil(CandidateVFs.front()) * 2;
+        unsigned Limit =
+            getFloorFullVectorNumberOfElements(*TTI, StoreTy, MaxTotalNum);
+        CandidateVFs.clear();
+        if (bit_floor(Limit) == VF)
+          CandidateVFs.push_back(Limit);
         if (VF > MaxTotalNum || VF >= StoresLimit)
           break;
         for_each(RangeSizes, [&](std::pair<unsigned, unsigned> &P) {
@@ -17841,7 +17875,6 @@ bool SLPVectorizerPass::vectorizeStores(
         });
         // Last attempt to vectorize max number of elements, if all previous
         // attempts were unsuccessful because of the cost issues.
-        CandidateVFs.clear();
         CandidateVFs.push_back(VF);
       }
     }

diff  --git a/llvm/test/Transforms/SLPVectorizer/X86/long-full-reg-stores.ll b/llvm/test/Transforms/SLPVectorizer/X86/long-full-reg-stores.ll
index b5f993f986c7cc..aff66dd7c10ea7 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/long-full-reg-stores.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/long-full-reg-stores.ll
@@ -4,30 +4,16 @@
 define void @test(ptr noalias %0, ptr noalias %1) {
 ; CHECK-LABEL: define void @test(
 ; CHECK-SAME: ptr noalias [[TMP0:%.*]], ptr noalias [[TMP1:%.*]]) {
-; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i64 24
-; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr i8, ptr [[TMP1]], i64 48
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr i8, ptr [[TMP1]], i64 8
-; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[TMP1]], i64 16
-; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP0]], i64 24
-; CHECK-NEXT:    [[TMP8:%.*]] = load double, ptr [[TMP7]], align 8
-; CHECK-NEXT:    store double [[TMP8]], ptr [[TMP5]], align 8
 ; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr i8, ptr [[TMP0]], i64 48
-; CHECK-NEXT:    [[TMP10:%.*]] = load double, ptr [[TMP9]], align 16
-; CHECK-NEXT:    store double [[TMP10]], ptr [[TMP6]], align 16
 ; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr i8, ptr [[TMP0]], i64 8
-; CHECK-NEXT:    [[TMP12:%.*]] = load double, ptr [[TMP11]], align 8
-; CHECK-NEXT:    store double [[TMP12]], ptr [[TMP3]], align 8
-; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[TMP0]], i64 32
-; CHECK-NEXT:    [[TMP14:%.*]] = load double, ptr [[TMP13]], align 16
-; CHECK-NEXT:    [[TMP15:%.*]] = getelementptr i8, ptr [[TMP1]], i64 32
-; CHECK-NEXT:    store double [[TMP14]], ptr [[TMP15]], align 16
-; CHECK-NEXT:    [[TMP16:%.*]] = getelementptr i8, ptr [[TMP0]], i64 56
-; CHECK-NEXT:    [[TMP17:%.*]] = load double, ptr [[TMP16]], align 8
-; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr i8, ptr [[TMP1]], i64 40
-; CHECK-NEXT:    store double [[TMP17]], ptr [[TMP18]], align 8
-; CHECK-NEXT:    [[TMP19:%.*]] = getelementptr i8, ptr [[TMP0]], i64 16
-; CHECK-NEXT:    [[TMP20:%.*]] = load double, ptr [[TMP19]], align 16
-; CHECK-NEXT:    store double [[TMP20]], ptr [[TMP4]], align 16
+; CHECK-NEXT:    [[TMP6:%.*]] = load <2 x double>, ptr [[TMP9]], align 16
+; CHECK-NEXT:    [[TMP7:%.*]] = load <4 x double>, ptr [[TMP11]], align 8
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <4 x double> [[TMP7]], <4 x double> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP12:%.*]] = shufflevector <2 x double> [[TMP6]], <2 x double> poison, <6 x i32> <i32 0, i32 1, i32 poison, i32 poison, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <2 x double> [[TMP6]], <2 x double> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP13:%.*]] = shufflevector <4 x double> [[TMP7]], <4 x double> [[TMP10]], <6 x i32> <i32 2, i32 4, i32 0, i32 3, i32 5, i32 1>
+; CHECK-NEXT:    store <6 x double> [[TMP13]], ptr [[TMP5]], align 8
 ; CHECK-NEXT:    [[TMP21:%.*]] = getelementptr i8, ptr [[TMP0]], i64 40
 ; CHECK-NEXT:    [[TMP22:%.*]] = load double, ptr [[TMP21]], align 8
 ; CHECK-NEXT:    [[TMP23:%.*]] = getelementptr i8, ptr [[TMP1]], i64 56


        


More information about the llvm-commits mailing list