[llvm] [SLP]Unify getNumberOfParts use (PR #124774)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 28 07:55:56 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-vectorizers

Author: Alexey Bataev (alexey-bataev)

<details>
<summary>Changes</summary>

Adds a getNumberOfParts helper and uses it in place of similar code across
the code base; also fixes the analysis of non-vectorizable types in
computeMinimumValueSizes.


---
Full diff: https://github.com/llvm/llvm-project/pull/124774.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+52-52) 
- (modified) llvm/test/Transforms/SLPVectorizer/RISCV/partial-vec-invalid-cost.ll (+10-2) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index f73ad1b15891a3..7f86651256e577 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1314,6 +1314,22 @@ static bool hasFullVectorsOrPowerOf2(const TargetTransformInfo &TTI, Type *Ty,
          Sz % NumParts == 0;
 }
 
+/// Returns the number of parts the type \p VecTy will be split into at the
+/// codegen phase. If the type is going to be scalarized or does not use whole
+/// registers, returns 1.
+static unsigned
+getNumberOfParts(const TargetTransformInfo &TTI, VectorType *VecTy,
+                 const unsigned Limit = std::numeric_limits<unsigned>::max()) {
+  unsigned NumParts = TTI.getNumberOfParts(VecTy);
+  if (NumParts == 0 || NumParts >= Limit)
+    return 1;
+  unsigned Sz = getNumElements(VecTy);
+  if (NumParts >= Sz || Sz % NumParts != 0 ||
+      !hasFullVectorsOrPowerOf2(TTI, VecTy->getElementType(), Sz / NumParts))
+    return 1;
+  return NumParts;
+}
+
 namespace slpvectorizer {
 
 /// Bottom Up SLP Vectorizer.
@@ -4618,12 +4634,7 @@ BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) {
   if (!isValidElementType(ScalarTy))
     return std::nullopt;
   auto *VecTy = getWidenedType(ScalarTy, NumScalars);
-  int NumParts = TTI->getNumberOfParts(VecTy);
-  if (NumParts == 0 || NumParts >= NumScalars ||
-      VecTy->getNumElements() % NumParts != 0 ||
-      !hasFullVectorsOrPowerOf2(*TTI, VecTy->getElementType(),
-                                VecTy->getNumElements() / NumParts))
-    NumParts = 1;
+  unsigned NumParts = ::getNumberOfParts(*TTI, VecTy, NumScalars);
   SmallVector<int> ExtractMask;
   SmallVector<int> Mask;
   SmallVector<SmallVector<const TreeEntry *>> Entries;
@@ -5574,8 +5585,8 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) {
       }
     }
     if (Sz == 2 && TE.getVectorFactor() == 4 &&
-        TTI->getNumberOfParts(getWidenedType(TE.Scalars.front()->getType(),
-                                             2 * TE.getVectorFactor())) == 1)
+        ::getNumberOfParts(*TTI, getWidenedType(TE.Scalars.front()->getType(),
+                                                2 * TE.getVectorFactor())) == 1)
       return std::nullopt;
     if (!ShuffleVectorInst::isOneUseSingleSourceMask(TE.ReuseShuffleIndices,
                                                      Sz)) {
@@ -9847,12 +9858,15 @@ void BoUpSLP::transformNodes() {
           // only with the single non-undef element).
           bool IsSplat = isSplat(Slice);
           if (Slices.empty() || !IsSplat ||
-              (VF <= 2 && 2 * std::clamp(TTI->getNumberOfParts(getWidenedType(
-                                             Slice.front()->getType(), VF)),
-                                         1U, VF - 1) !=
-                              std::clamp(TTI->getNumberOfParts(getWidenedType(
-                                             Slice.front()->getType(), 2 * VF)),
-                                         1U, 2 * VF)) ||
+              (VF <= 2 &&
+               2 * std::clamp(
+                       ::getNumberOfParts(
+                           *TTI, getWidenedType(Slice.front()->getType(), VF)),
+                       1U, VF - 1) !=
+                   std::clamp(::getNumberOfParts(
+                                  *TTI, getWidenedType(Slice.front()->getType(),
+                                                       2 * VF)),
+                              1U, 2 * VF)) ||
               count(Slice, Slice.front()) ==
                   static_cast<long>(isa<UndefValue>(Slice.front()) ? VF - 1
                                                                    : 1)) {
@@ -10793,12 +10807,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
     }
     assert(!CommonMask.empty() && "Expected non-empty common mask.");
     auto *MaskVecTy = getWidenedType(ScalarTy, Mask.size());
-    unsigned NumParts = TTI.getNumberOfParts(MaskVecTy);
-    if (NumParts == 0 || NumParts >= Mask.size() ||
-        MaskVecTy->getNumElements() % NumParts != 0 ||
-        !hasFullVectorsOrPowerOf2(TTI, MaskVecTy->getElementType(),
-                                  MaskVecTy->getNumElements() / NumParts))
-      NumParts = 1;
+    unsigned NumParts = ::getNumberOfParts(TTI, MaskVecTy, Mask.size());
     unsigned SliceSize = getPartNumElems(Mask.size(), NumParts);
     const auto *It =
         find_if(Mask, [](int Idx) { return Idx != PoisonMaskElem; });
@@ -10813,12 +10822,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
     }
     assert(!CommonMask.empty() && "Expected non-empty common mask.");
     auto *MaskVecTy = getWidenedType(ScalarTy, Mask.size());
-    unsigned NumParts = TTI.getNumberOfParts(MaskVecTy);
-    if (NumParts == 0 || NumParts >= Mask.size() ||
-        MaskVecTy->getNumElements() % NumParts != 0 ||
-        !hasFullVectorsOrPowerOf2(TTI, MaskVecTy->getElementType(),
-                                  MaskVecTy->getNumElements() / NumParts))
-      NumParts = 1;
+    unsigned NumParts = ::getNumberOfParts(TTI, MaskVecTy, Mask.size());
     unsigned SliceSize = getPartNumElems(Mask.size(), NumParts);
     const auto *It =
         find_if(Mask, [](int Idx) { return Idx != PoisonMaskElem; });
@@ -11351,7 +11355,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     unsigned const NumElts = SrcVecTy->getNumElements();
     unsigned const NumScalars = VL.size();
 
-    unsigned NumOfParts = TTI->getNumberOfParts(SrcVecTy);
+    unsigned NumOfParts = ::getNumberOfParts(*TTI, SrcVecTy);
 
     SmallVector<int> InsertMask(NumElts, PoisonMaskElem);
     unsigned OffsetBeg = *getElementIndex(VL.front());
@@ -14862,12 +14866,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
   SmallVector<SmallVector<const TreeEntry *>> Entries;
   Type *OrigScalarTy = GatheredScalars.front()->getType();
   auto *VecTy = getWidenedType(ScalarTy, GatheredScalars.size());
-  unsigned NumParts = TTI->getNumberOfParts(VecTy);
-  if (NumParts == 0 || NumParts >= GatheredScalars.size() ||
-      VecTy->getNumElements() % NumParts != 0 ||
-      !hasFullVectorsOrPowerOf2(*TTI, VecTy->getElementType(),
-                                VecTy->getNumElements() / NumParts))
-    NumParts = 1;
+  unsigned NumParts = ::getNumberOfParts(*TTI, VecTy, GatheredScalars.size());
   if (!all_of(GatheredScalars, IsaPred<UndefValue>)) {
     // Check for gathered extracts.
     bool Resized = false;
@@ -14899,12 +14898,8 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
             Resized = true;
             GatheredScalars.append(VF - GatheredScalars.size(),
                                    PoisonValue::get(OrigScalarTy));
-            NumParts = TTI->getNumberOfParts(getWidenedType(OrigScalarTy, VF));
-            if (NumParts == 0 || NumParts >= GatheredScalars.size() ||
-                VecTy->getNumElements() % NumParts != 0 ||
-                !hasFullVectorsOrPowerOf2(*TTI, VecTy->getElementType(),
-                                          VecTy->getNumElements() / NumParts))
-              NumParts = 1;
+            NumParts =
+                ::getNumberOfParts(*TTI, getWidenedType(OrigScalarTy, VF), VF);
           }
       }
     }
@@ -17049,10 +17044,10 @@ void BoUpSLP::optimizeGatherSequence() {
     // Check if the last undefs actually change the final number of used vector
     // registers.
     return SM1.size() - LastUndefsCnt > 1 &&
-           TTI->getNumberOfParts(SI1->getType()) ==
-               TTI->getNumberOfParts(
-                   getWidenedType(SI1->getType()->getElementType(),
-                                  SM1.size() - LastUndefsCnt));
+           ::getNumberOfParts(*TTI, SI1->getType()) ==
+               ::getNumberOfParts(
+                   *TTI, getWidenedType(SI1->getType()->getElementType(),
+                                        SM1.size() - LastUndefsCnt));
   };
   // Perform O(N^2) search over the gather/shuffle sequences and merge identical
   // instructions. TODO: We can further optimize this scan if we split the
@@ -17829,9 +17824,12 @@ bool BoUpSLP::collectValuesToDemote(
       const unsigned VF = E.Scalars.size();
       Type *OrigScalarTy = E.Scalars.front()->getType();
       if (UniqueBases.size() <= 2 ||
-          TTI->getNumberOfParts(getWidenedType(OrigScalarTy, VF)) ==
-              TTI->getNumberOfParts(getWidenedType(
-                  IntegerType::get(OrigScalarTy->getContext(), BitWidth), VF)))
+          ::getNumberOfParts(*TTI, getWidenedType(OrigScalarTy, VF)) ==
+              ::getNumberOfParts(
+                  *TTI,
+                  getWidenedType(
+                      IntegerType::get(OrigScalarTy->getContext(), BitWidth),
+                      VF)))
         ToDemote.push_back(E.Idx);
     }
     return Res;
@@ -18241,8 +18239,8 @@ void BoUpSLP::computeMinimumValueSizes() {
                [&](Value *V) { return AnalyzedMinBWVals.contains(V); }))
       return 0u;
 
-    unsigned NumParts = TTI->getNumberOfParts(
-        getWidenedType(TreeRootIT, VF * ScalarTyNumElements));
+    unsigned NumParts = ::getNumberOfParts(
+        *TTI, getWidenedType(TreeRootIT, VF * ScalarTyNumElements));
 
     // The maximum bit width required to represent all the values that can be
     // demoted without loss of precision. It would be safe to truncate the roots
@@ -18302,8 +18300,10 @@ void BoUpSLP::computeMinimumValueSizes() {
     // use - ignore it.
     if (NumParts > 1 &&
         NumParts ==
-            TTI->getNumberOfParts(getWidenedType(
-                IntegerType::get(F->getContext(), bit_ceil(MaxBitWidth)), VF)))
+            ::getNumberOfParts(
+                *TTI, getWidenedType(IntegerType::get(F->getContext(),
+                                                      bit_ceil(MaxBitWidth)),
+                                     VF)))
       return 0u;
 
     unsigned Opcode = E.getOpcode();
@@ -20086,14 +20086,14 @@ class HorizontalReduction {
         ReduxWidth =
             getFloorFullVectorNumberOfElements(TTI, ScalarTy, ReduxWidth);
         VectorType *Tp = getWidenedType(ScalarTy, ReduxWidth);
-        NumParts = TTI.getNumberOfParts(Tp);
+        NumParts = ::getNumberOfParts(TTI, Tp);
         NumRegs =
             TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true, Tp));
         while (NumParts > NumRegs) {
           assert(ReduxWidth > 0 && "ReduxWidth is unexpectedly 0.");
           ReduxWidth = bit_floor(ReduxWidth - 1);
           VectorType *Tp = getWidenedType(ScalarTy, ReduxWidth);
-          NumParts = TTI.getNumberOfParts(Tp);
+          NumParts = ::getNumberOfParts(TTI, Tp);
           NumRegs =
               TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true, Tp));
         }
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/partial-vec-invalid-cost.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/partial-vec-invalid-cost.ll
index 6388cc2dedc73a..085d7a64fc9ac9 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/partial-vec-invalid-cost.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/partial-vec-invalid-cost.ll
@@ -7,9 +7,17 @@ define void @partial_vec_invalid_cost() #0 {
 ; CHECK-LABEL: define void @partial_vec_invalid_cost(
 ; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> zeroinitializer)
+; CHECK-NEXT:    [[LSHR_1:%.*]] = lshr i96 0, 0
+; CHECK-NEXT:    [[LSHR_2:%.*]] = lshr i96 0, 0
+; CHECK-NEXT:    [[TRUNC_I96_1:%.*]] = trunc i96 [[LSHR_1]] to i32
+; CHECK-NEXT:    [[TRUNC_I96_2:%.*]] = trunc i96 [[LSHR_2]] to i32
+; CHECK-NEXT:    [[TRUNC_I96_3:%.*]] = trunc i96 0 to i32
+; CHECK-NEXT:    [[TRUNC_I96_4:%.*]] = trunc i96 0 to i32
 ; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> zeroinitializer)
-; CHECK-NEXT:    [[OP_RDX3:%.*]] = or i32 [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[OP_RDX:%.*]] = or i32 [[TMP1]], [[TRUNC_I96_1]]
+; CHECK-NEXT:    [[OP_RDX1:%.*]] = or i32 [[TRUNC_I96_2]], [[TRUNC_I96_3]]
+; CHECK-NEXT:    [[OP_RDX2:%.*]] = or i32 [[OP_RDX]], [[OP_RDX1]]
+; CHECK-NEXT:    [[OP_RDX3:%.*]] = or i32 [[OP_RDX2]], [[TRUNC_I96_4]]
 ; CHECK-NEXT:    [[STORE_THIS:%.*]] = zext i32 [[OP_RDX3]] to i96
 ; CHECK-NEXT:    store i96 [[STORE_THIS]], ptr null, align 16
 ; CHECK-NEXT:    ret void

``````````

</details>


https://github.com/llvm/llvm-project/pull/124774


More information about the llvm-commits mailing list