[llvm] [SLP]Unify getNumberOfParts use (PR #124774)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 28 07:55:56 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-vectorizers
Author: Alexey Bataev (alexey-bataev)
Changes:
Adds a getNumberOfParts helper and uses it in place of similar code across
the code base; also fixes the analysis of non-vectorizable types in
computeMinimumValueSizes.
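For readers skimming the diff, here is a minimal standalone sketch of the contract the new helper enforces, under a hypothetical target model. `queryNumParts`, `EltsPerReg`, and `fullVectorsOrPow2` are stand-ins for TTI.getNumberOfParts and the real hasFullVectorsOrPowerOf2 check, not LLVM APIs:

```cpp
// Sketch only, not the patch itself: the raw part count is trusted only when
// the vector type splits evenly into whole registers (or power-of-two-sized
// pieces); anything else collapses to a single "scalarized" part.
#include <cstdio>
#include <limits>

constexpr unsigned EltsPerReg = 4; // pretend: 128-bit registers, i32 elements

// Hypothetical stand-in for TTI.getNumberOfParts.
static unsigned queryNumParts(unsigned NumElts) {
  return (NumElts + EltsPerReg - 1) / EltsPerReg;
}

// Simplified stand-in for hasFullVectorsOrPowerOf2: each part must span whole
// registers or contain a power-of-two number of elements.
static bool fullVectorsOrPow2(unsigned PerPart) {
  return PerPart % EltsPerReg == 0 || (PerPart & (PerPart - 1)) == 0;
}

static unsigned
getNumberOfPartsSketch(unsigned NumElts,
                       unsigned Limit = std::numeric_limits<unsigned>::max()) {
  unsigned NumParts = queryNumParts(NumElts);
  // Scalarized (0 parts), at or above the caller's limit, more parts than
  // elements, an uneven split, or partial registers: treat as one part.
  if (NumParts == 0 || NumParts >= Limit || NumParts >= NumElts ||
      NumElts % NumParts != 0 || !fullVectorsOrPow2(NumElts / NumParts))
    return 1;
  return NumParts;
}

int main() {
  std::printf("%u\n", getNumberOfPartsSketch(8)); // 2: two full registers
  std::printf("%u\n", getNumberOfPartsSketch(6)); // 1: 3 elts/part, partial reg
}
```

The call sites below previously duplicated this clamping logic inline; the patch routes them all through the one helper.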
---
Full diff: https://github.com/llvm/llvm-project/pull/124774.diff
2 Files Affected:
- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+52-52)
- (modified) llvm/test/Transforms/SLPVectorizer/RISCV/partial-vec-invalid-cost.ll (+10-2)
``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index f73ad1b15891a3..7f86651256e577 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1314,6 +1314,22 @@ static bool hasFullVectorsOrPowerOf2(const TargetTransformInfo &TTI, Type *Ty,
Sz % NumParts == 0;
}
+/// Returns the number of parts the type \p VecTy will be split into at the
+/// codegen phase. If the type is going to be scalarized or does not use
+/// whole registers, returns 1.
+static unsigned
+getNumberOfParts(const TargetTransformInfo &TTI, VectorType *VecTy,
+ const unsigned Limit = std::numeric_limits<unsigned>::max()) {
+ unsigned NumParts = TTI.getNumberOfParts(VecTy);
+ if (NumParts == 0 || NumParts >= Limit)
+ return 1;
+ unsigned Sz = getNumElements(VecTy);
+ if (NumParts >= Sz || Sz % NumParts != 0 ||
+ !hasFullVectorsOrPowerOf2(TTI, VecTy->getElementType(), Sz / NumParts))
+ return 1;
+ return NumParts;
+}
+
namespace slpvectorizer {
/// Bottom Up SLP Vectorizer.
@@ -4618,12 +4634,7 @@ BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) {
if (!isValidElementType(ScalarTy))
return std::nullopt;
auto *VecTy = getWidenedType(ScalarTy, NumScalars);
- int NumParts = TTI->getNumberOfParts(VecTy);
- if (NumParts == 0 || NumParts >= NumScalars ||
- VecTy->getNumElements() % NumParts != 0 ||
- !hasFullVectorsOrPowerOf2(*TTI, VecTy->getElementType(),
- VecTy->getNumElements() / NumParts))
- NumParts = 1;
+ unsigned NumParts = ::getNumberOfParts(*TTI, VecTy, NumScalars);
SmallVector<int> ExtractMask;
SmallVector<int> Mask;
SmallVector<SmallVector<const TreeEntry *>> Entries;
@@ -5574,8 +5585,8 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) {
}
}
if (Sz == 2 && TE.getVectorFactor() == 4 &&
- TTI->getNumberOfParts(getWidenedType(TE.Scalars.front()->getType(),
- 2 * TE.getVectorFactor())) == 1)
+ ::getNumberOfParts(*TTI, getWidenedType(TE.Scalars.front()->getType(),
+ 2 * TE.getVectorFactor())) == 1)
return std::nullopt;
if (!ShuffleVectorInst::isOneUseSingleSourceMask(TE.ReuseShuffleIndices,
Sz)) {
@@ -9847,12 +9858,15 @@ void BoUpSLP::transformNodes() {
// only with the single non-undef element).
bool IsSplat = isSplat(Slice);
if (Slices.empty() || !IsSplat ||
- (VF <= 2 && 2 * std::clamp(TTI->getNumberOfParts(getWidenedType(
- Slice.front()->getType(), VF)),
- 1U, VF - 1) !=
- std::clamp(TTI->getNumberOfParts(getWidenedType(
- Slice.front()->getType(), 2 * VF)),
- 1U, 2 * VF)) ||
+ (VF <= 2 &&
+ 2 * std::clamp(
+ ::getNumberOfParts(
+ *TTI, getWidenedType(Slice.front()->getType(), VF)),
+ 1U, VF - 1) !=
+ std::clamp(::getNumberOfParts(
+ *TTI, getWidenedType(Slice.front()->getType(),
+ 2 * VF)),
+ 1U, 2 * VF)) ||
count(Slice, Slice.front()) ==
static_cast<long>(isa<UndefValue>(Slice.front()) ? VF - 1
: 1)) {
@@ -10793,12 +10807,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
}
assert(!CommonMask.empty() && "Expected non-empty common mask.");
auto *MaskVecTy = getWidenedType(ScalarTy, Mask.size());
- unsigned NumParts = TTI.getNumberOfParts(MaskVecTy);
- if (NumParts == 0 || NumParts >= Mask.size() ||
- MaskVecTy->getNumElements() % NumParts != 0 ||
- !hasFullVectorsOrPowerOf2(TTI, MaskVecTy->getElementType(),
- MaskVecTy->getNumElements() / NumParts))
- NumParts = 1;
+ unsigned NumParts = ::getNumberOfParts(TTI, MaskVecTy, Mask.size());
unsigned SliceSize = getPartNumElems(Mask.size(), NumParts);
const auto *It =
find_if(Mask, [](int Idx) { return Idx != PoisonMaskElem; });
@@ -10813,12 +10822,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
}
assert(!CommonMask.empty() && "Expected non-empty common mask.");
auto *MaskVecTy = getWidenedType(ScalarTy, Mask.size());
- unsigned NumParts = TTI.getNumberOfParts(MaskVecTy);
- if (NumParts == 0 || NumParts >= Mask.size() ||
- MaskVecTy->getNumElements() % NumParts != 0 ||
- !hasFullVectorsOrPowerOf2(TTI, MaskVecTy->getElementType(),
- MaskVecTy->getNumElements() / NumParts))
- NumParts = 1;
+ unsigned NumParts = ::getNumberOfParts(TTI, MaskVecTy, Mask.size());
unsigned SliceSize = getPartNumElems(Mask.size(), NumParts);
const auto *It =
find_if(Mask, [](int Idx) { return Idx != PoisonMaskElem; });
@@ -11351,7 +11355,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
unsigned const NumElts = SrcVecTy->getNumElements();
unsigned const NumScalars = VL.size();
- unsigned NumOfParts = TTI->getNumberOfParts(SrcVecTy);
+ unsigned NumOfParts = ::getNumberOfParts(*TTI, SrcVecTy);
SmallVector<int> InsertMask(NumElts, PoisonMaskElem);
unsigned OffsetBeg = *getElementIndex(VL.front());
@@ -14862,12 +14866,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
SmallVector<SmallVector<const TreeEntry *>> Entries;
Type *OrigScalarTy = GatheredScalars.front()->getType();
auto *VecTy = getWidenedType(ScalarTy, GatheredScalars.size());
- unsigned NumParts = TTI->getNumberOfParts(VecTy);
- if (NumParts == 0 || NumParts >= GatheredScalars.size() ||
- VecTy->getNumElements() % NumParts != 0 ||
- !hasFullVectorsOrPowerOf2(*TTI, VecTy->getElementType(),
- VecTy->getNumElements() / NumParts))
- NumParts = 1;
+ unsigned NumParts = ::getNumberOfParts(*TTI, VecTy, GatheredScalars.size());
if (!all_of(GatheredScalars, IsaPred<UndefValue>)) {
// Check for gathered extracts.
bool Resized = false;
@@ -14899,12 +14898,8 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
Resized = true;
GatheredScalars.append(VF - GatheredScalars.size(),
PoisonValue::get(OrigScalarTy));
- NumParts = TTI->getNumberOfParts(getWidenedType(OrigScalarTy, VF));
- if (NumParts == 0 || NumParts >= GatheredScalars.size() ||
- VecTy->getNumElements() % NumParts != 0 ||
- !hasFullVectorsOrPowerOf2(*TTI, VecTy->getElementType(),
- VecTy->getNumElements() / NumParts))
- NumParts = 1;
+ NumParts =
+ ::getNumberOfParts(*TTI, getWidenedType(OrigScalarTy, VF), VF);
}
}
}
@@ -17049,10 +17044,10 @@ void BoUpSLP::optimizeGatherSequence() {
// Check if the last undefs actually change the final number of used vector
// registers.
return SM1.size() - LastUndefsCnt > 1 &&
- TTI->getNumberOfParts(SI1->getType()) ==
- TTI->getNumberOfParts(
- getWidenedType(SI1->getType()->getElementType(),
- SM1.size() - LastUndefsCnt));
+ ::getNumberOfParts(*TTI, SI1->getType()) ==
+ ::getNumberOfParts(
+ *TTI, getWidenedType(SI1->getType()->getElementType(),
+ SM1.size() - LastUndefsCnt));
};
// Perform O(N^2) search over the gather/shuffle sequences and merge identical
// instructions. TODO: We can further optimize this scan if we split the
@@ -17829,9 +17824,12 @@ bool BoUpSLP::collectValuesToDemote(
const unsigned VF = E.Scalars.size();
Type *OrigScalarTy = E.Scalars.front()->getType();
if (UniqueBases.size() <= 2 ||
- TTI->getNumberOfParts(getWidenedType(OrigScalarTy, VF)) ==
- TTI->getNumberOfParts(getWidenedType(
- IntegerType::get(OrigScalarTy->getContext(), BitWidth), VF)))
+ ::getNumberOfParts(*TTI, getWidenedType(OrigScalarTy, VF)) ==
+ ::getNumberOfParts(
+ *TTI,
+ getWidenedType(
+ IntegerType::get(OrigScalarTy->getContext(), BitWidth),
+ VF)))
ToDemote.push_back(E.Idx);
}
return Res;
@@ -18241,8 +18239,8 @@ void BoUpSLP::computeMinimumValueSizes() {
[&](Value *V) { return AnalyzedMinBWVals.contains(V); }))
return 0u;
- unsigned NumParts = TTI->getNumberOfParts(
- getWidenedType(TreeRootIT, VF * ScalarTyNumElements));
+ unsigned NumParts = ::getNumberOfParts(
+ *TTI, getWidenedType(TreeRootIT, VF * ScalarTyNumElements));
// The maximum bit width required to represent all the values that can be
// demoted without loss of precision. It would be safe to truncate the roots
@@ -18302,8 +18300,10 @@ void BoUpSLP::computeMinimumValueSizes() {
// use - ignore it.
if (NumParts > 1 &&
NumParts ==
- TTI->getNumberOfParts(getWidenedType(
- IntegerType::get(F->getContext(), bit_ceil(MaxBitWidth)), VF)))
+ ::getNumberOfParts(
+ *TTI, getWidenedType(IntegerType::get(F->getContext(),
+ bit_ceil(MaxBitWidth)),
+ VF)))
return 0u;
unsigned Opcode = E.getOpcode();
@@ -20086,14 +20086,14 @@ class HorizontalReduction {
ReduxWidth =
getFloorFullVectorNumberOfElements(TTI, ScalarTy, ReduxWidth);
VectorType *Tp = getWidenedType(ScalarTy, ReduxWidth);
- NumParts = TTI.getNumberOfParts(Tp);
+ NumParts = ::getNumberOfParts(TTI, Tp);
NumRegs =
TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true, Tp));
while (NumParts > NumRegs) {
assert(ReduxWidth > 0 && "ReduxWidth is unexpectedly 0.");
ReduxWidth = bit_floor(ReduxWidth - 1);
VectorType *Tp = getWidenedType(ScalarTy, ReduxWidth);
- NumParts = TTI.getNumberOfParts(Tp);
+ NumParts = ::getNumberOfParts(TTI, Tp);
NumRegs =
TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true, Tp));
}
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/partial-vec-invalid-cost.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/partial-vec-invalid-cost.ll
index 6388cc2dedc73a..085d7a64fc9ac9 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/partial-vec-invalid-cost.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/partial-vec-invalid-cost.ll
@@ -7,9 +7,17 @@ define void @partial_vec_invalid_cost() #0 {
; CHECK-LABEL: define void @partial_vec_invalid_cost(
; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> zeroinitializer)
+; CHECK-NEXT: [[LSHR_1:%.*]] = lshr i96 0, 0
+; CHECK-NEXT: [[LSHR_2:%.*]] = lshr i96 0, 0
+; CHECK-NEXT: [[TRUNC_I96_1:%.*]] = trunc i96 [[LSHR_1]] to i32
+; CHECK-NEXT: [[TRUNC_I96_2:%.*]] = trunc i96 [[LSHR_2]] to i32
+; CHECK-NEXT: [[TRUNC_I96_3:%.*]] = trunc i96 0 to i32
+; CHECK-NEXT: [[TRUNC_I96_4:%.*]] = trunc i96 0 to i32
; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> zeroinitializer)
-; CHECK-NEXT: [[OP_RDX3:%.*]] = or i32 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[OP_RDX:%.*]] = or i32 [[TMP1]], [[TRUNC_I96_1]]
+; CHECK-NEXT: [[OP_RDX1:%.*]] = or i32 [[TRUNC_I96_2]], [[TRUNC_I96_3]]
+; CHECK-NEXT: [[OP_RDX2:%.*]] = or i32 [[OP_RDX]], [[OP_RDX1]]
+; CHECK-NEXT: [[OP_RDX3:%.*]] = or i32 [[OP_RDX2]], [[TRUNC_I96_4]]
; CHECK-NEXT: [[STORE_THIS:%.*]] = zext i32 [[OP_RDX3]] to i96
; CHECK-NEXT: store i96 [[STORE_THIS]], ptr null, align 16
; CHECK-NEXT: ret void
``````````
https://github.com/llvm/llvm-project/pull/124774
More information about the llvm-commits mailing list