[llvm] [SLP][NFC] Simplify type checks with isa predicates (PR #87182)
Jakub Kuderski via llvm-commits
llvm-commits at lists.llvm.org
Sat Mar 30 22:18:48 PDT 2024
https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/87182
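This NFC change replaces lambdas that only wrap an `isa<...>` check, as well as `X::classof` function pointers, with the equivalent `IsaPred<...>` predicate objects (and a `find_if` with a negated check with `find_if_not`). For more context on isa predicates, see: https://github.com/llvm/llvm-project/pull/83753.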
From bca1f3446a503c5779780972b8a37242af3128b8 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 31 Mar 2024 01:16:10 -0400
Subject: [PATCH] [SLPVectorizer][NFC] Simplify type checks with isa predicates
For more context on isa predicates, see: https://github.com/llvm/llvm-project/pull/83753.
---
.../Transforms/Vectorize/SLPVectorizer.cpp | 95 +++++++------------
1 file changed, 33 insertions(+), 62 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 2875e71081d928..2bc0c5dcc6069d 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -458,8 +458,7 @@ static SmallBitVector isUndefVector(const Value *V,
/// ShuffleVectorInst/getShuffleCost?
static std::optional<TargetTransformInfo::ShuffleKind>
isFixedVectorShuffle(ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) {
- const auto *It =
- find_if(VL, [](Value *V) { return isa<ExtractElementInst>(V); });
+ const auto *It = find_if(VL, IsaPred<ExtractElementInst>);
if (It == VL.end())
return std::nullopt;
auto *EI0 = cast<ExtractElementInst>(*It);
@@ -4695,12 +4694,8 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) {
// TODO: add analysis of other gather nodes with extractelement
// instructions and other values/instructions, not only undefs.
if ((TE.getOpcode() == Instruction::ExtractElement ||
- (all_of(TE.Scalars,
- [](Value *V) {
- return isa<UndefValue, ExtractElementInst>(V);
- }) &&
- any_of(TE.Scalars,
- [](Value *V) { return isa<ExtractElementInst>(V); }))) &&
+ (all_of(TE.Scalars, IsaPred<UndefValue, ExtractElementInst>) &&
+ any_of(TE.Scalars, IsaPred<ExtractElementInst>))) &&
all_of(TE.Scalars, [](Value *V) {
auto *EE = dyn_cast<ExtractElementInst>(V);
return !EE || isa<FixedVectorType>(EE->getVectorOperandType());
@@ -4721,7 +4716,7 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) {
// might be transformed.
int Sz = TE.Scalars.size();
if (isSplat(TE.Scalars) && !allConstant(TE.Scalars) &&
- count_if(TE.Scalars, UndefValue::classof) == Sz - 1) {
+ count_if(TE.Scalars, IsaPred<UndefValue>) == Sz - 1) {
const auto *It =
find_if(TE.Scalars, [](Value *V) { return !isConstant(V); });
if (It == TE.Scalars.begin())
@@ -6345,11 +6340,10 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize &&
!(S.getOpcode() && allSameBlock(VL))) {
assert(S.OpValue->getType()->isPointerTy() &&
- count_if(VL, [](Value *V) { return isa<GetElementPtrInst>(V); }) >=
- 2 &&
+ count_if(VL, IsaPred<GetElementPtrInst>) >= 2 &&
"Expected pointers only.");
// Reset S to make it GetElementPtr kind of node.
- const auto *It = find_if(VL, [](Value *V) { return isa<GetElementPtrInst>(V); });
+ const auto *It = find_if(VL, IsaPred<GetElementPtrInst>);
assert(It != VL.end() && "Expected at least one GEP.");
S = getSameOpcode(*It, *TLI);
}
@@ -6893,17 +6887,12 @@ unsigned BoUpSLP::canMapToVector(Type *T) const {
bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, Value *OpValue,
SmallVectorImpl<unsigned> &CurrentOrder,
bool ResizeAllowed) const {
- const auto *It = find_if(VL, [](Value *V) {
- return isa<ExtractElementInst, ExtractValueInst>(V);
- });
+ const auto *It = find_if(VL, IsaPred<ExtractElementInst, ExtractValueInst>);
assert(It != VL.end() && "Expected at least one extract instruction.");
auto *E0 = cast<Instruction>(*It);
- assert(all_of(VL,
- [](Value *V) {
- return isa<UndefValue, ExtractElementInst, ExtractValueInst>(
- V);
- }) &&
- "Invalid opcode");
+ assert(
+ all_of(VL, IsaPred<UndefValue, ExtractElementInst, ExtractValueInst>) &&
+ "Invalid opcode");
// Check if all of the extracts come from the same vector and from the
// correct offset.
Value *Vec = E0->getOperand(0);
@@ -7575,7 +7564,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
}
InstructionCost getBuildVectorCost(ArrayRef<Value *> VL, Value *Root) {
- if ((!Root && allConstant(VL)) || all_of(VL, UndefValue::classof))
+ if ((!Root && allConstant(VL)) || all_of(VL, IsaPred<UndefValue>))
return TTI::TCC_Free;
auto *VecTy = FixedVectorType::get(VL.front()->getType(), VL.size());
InstructionCost GatherCost = 0;
@@ -7743,13 +7732,12 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
} else if (!Root && isSplat(VL)) {
// Found the broadcasting of the single scalar, calculate the cost as
// the broadcast.
- const auto *It =
- find_if(VL, [](Value *V) { return !isa<UndefValue>(V); });
+ const auto *It = find_if_not(VL, IsaPred<UndefValue>);
assert(It != VL.end() && "Expected at least one non-undef value.");
// Add broadcast for non-identity shuffle only.
bool NeedShuffle =
count(VL, *It) > 1 &&
- (VL.front() != *It || !all_of(VL.drop_front(), UndefValue::classof));
+ (VL.front() != *It || !all_of(VL.drop_front(), IsaPred<UndefValue>));
if (!NeedShuffle)
return TTI.getVectorInstrCost(Instruction::InsertElement, VecTy,
CostKind, std::distance(VL.begin(), It),
@@ -7757,7 +7745,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
SmallVector<int> ShuffleMask(VL.size(), PoisonMaskElem);
transform(VL, ShuffleMask.begin(), [](Value *V) {
- return isa<PoisonValue>(V) ? PoisonMaskElem : 0;
+ return isa<PoisonValue>(V) ? PoisonMaskElem : 0;
});
InstructionCost InsertCost = TTI.getVectorInstrCost(
Instruction::InsertElement, VecTy, CostKind, 0,
@@ -7768,7 +7756,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
/*SubTp=*/nullptr, /*Args=*/*It);
}
return GatherCost +
- (all_of(Gathers, UndefValue::classof)
+ (all_of(Gathers, IsaPred<UndefValue>)
? TTI::TCC_Free
: R.getGatherCost(Gathers, !Root && VL.equals(Gathers)));
};
@@ -8178,9 +8166,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
// Take credit for instruction that will become dead.
if (EE->hasOneUse() || !PrevNodeFound) {
Instruction *Ext = EE->user_back();
- if (isa<SExtInst, ZExtInst>(Ext) && all_of(Ext->users(), [](User *U) {
- return isa<GetElementPtrInst>(U);
- })) {
+ if (isa<SExtInst, ZExtInst>(Ext) &&
+ all_of(Ext->users(), IsaPred<GetElementPtrInst>)) {
// Use getExtractWithExtendCost() to calculate the cost of
// extractelement/ext pair.
Cost -=
@@ -8645,8 +8632,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
if (I->hasOneUse()) {
Instruction *Ext = I->user_back();
if ((isa<SExtInst>(Ext) || isa<ZExtInst>(Ext)) &&
- all_of(Ext->users(),
- [](User *U) { return isa<GetElementPtrInst>(U); })) {
+ all_of(Ext->users(), IsaPred<GetElementPtrInst>)) {
// Use getExtractWithExtendCost() to calculate the cost of
// extractelement/ext pair.
InstructionCost Cost = TTI->getExtractWithExtendCost(
@@ -9130,10 +9116,7 @@ bool BoUpSLP::isFullyVectorizableTinyTree(bool ForReduction) const {
(allConstant(TE->Scalars) || isSplat(TE->Scalars) ||
TE->Scalars.size() < Limit ||
((TE->getOpcode() == Instruction::ExtractElement ||
- all_of(TE->Scalars,
- [](Value *V) {
- return isa<ExtractElementInst, UndefValue>(V);
- })) &&
+ all_of(TE->Scalars, IsaPred<ExtractElementInst, UndefValue>)) &&
isFixedVectorShuffle(TE->Scalars, Mask)) ||
(TE->State == TreeEntry::NeedToGather &&
TE->getOpcode() == Instruction::Load && !TE->isAltShuffle()));
@@ -9254,9 +9237,7 @@ bool BoUpSLP::isTreeTinyAndNotFullyVectorizable(bool ForReduction) const {
all_of(VectorizableTree, [&](const std::unique_ptr<TreeEntry> &TE) {
return (TE->State == TreeEntry::NeedToGather &&
TE->getOpcode() != Instruction::ExtractElement &&
- count_if(TE->Scalars,
- [](Value *V) { return isa<ExtractElementInst>(V); }) <=
- Limit) ||
+ count_if(TE->Scalars, IsaPred<ExtractElementInst>) <= Limit) ||
TE->getOpcode() == Instruction::PHI;
}))
return true;
@@ -9285,9 +9266,7 @@ bool BoUpSLP::isTreeTinyAndNotFullyVectorizable(bool ForReduction) const {
return isa<ExtractElementInst, UndefValue>(V) ||
(IsAllowedSingleBVNode &&
!V->hasNUsesOrMore(UsesLimit) &&
- any_of(V->users(), [](User *U) {
- return isa<InsertElementInst>(U);
- }));
+ any_of(V->users(), IsaPred<InsertElementInst>));
});
}))
return false;
@@ -10284,7 +10263,7 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
}
}
- bool IsSplatOrUndefs = isSplat(VL) || all_of(VL, UndefValue::classof);
+ bool IsSplatOrUndefs = isSplat(VL) || all_of(VL, IsaPred<UndefValue>);
// Checks if the 2 PHIs are compatible in terms of high possibility to be
// vectorized.
auto AreCompatiblePHIs = [&](Value *V, Value *V1) {
@@ -11261,8 +11240,7 @@ Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx,
InstructionsState S = getSameOpcode(VL, *TLI);
// Special processing for GEPs bundle, which may include non-gep values.
if (!S.getOpcode() && VL.front()->getType()->isPointerTy()) {
- const auto *It =
- find_if(VL, [](Value *V) { return isa<GetElementPtrInst>(V); });
+ const auto *It = find_if(VL, IsaPred<GetElementPtrInst>);
if (It != VL.end())
S = getSameOpcode(*It, *TLI);
}
@@ -11432,7 +11410,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) {
unsigned NumParts = TTI->getNumberOfParts(VecTy);
if (NumParts == 0 || NumParts >= GatheredScalars.size())
NumParts = 1;
- if (!all_of(GatheredScalars, UndefValue::classof)) {
+ if (!all_of(GatheredScalars, IsaPred<UndefValue>)) {
// Check for gathered extracts.
bool Resized = false;
ExtractShuffles =
@@ -11757,7 +11735,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) {
GatheredScalars[I] = PoisonValue::get(ScalarTy);
}
// Generate constants for final shuffle and build a mask for them.
- if (!all_of(GatheredScalars, PoisonValue::classof)) {
+ if (!all_of(GatheredScalars, IsaPred<PoisonValue>)) {
SmallVector<int> BVMask(GatheredScalars.size(), PoisonMaskElem);
TryPackScalars(GatheredScalars, BVMask, /*IsRootPoison=*/true);
Value *BV = ShuffleBuilder.gather(GatheredScalars, BVMask.size());
@@ -14509,7 +14487,7 @@ void BoUpSLP::computeMinimumValueSizes() {
return SIt != DemotedConsts.end() &&
is_contained(SIt->getSecond(), Idx);
}) ||
- all_of(CTE->Scalars, Constant::classof))
+ all_of(CTE->Scalars, IsaPred<Constant>))
MinBWs.try_emplace(CTE, MaxBitWidth, IsSigned);
}
}
@@ -15257,12 +15235,10 @@ class HorizontalReduction {
static Value *createOp(IRBuilderBase &Builder, RecurKind RdxKind, Value *LHS,
Value *RHS, const Twine &Name,
const ReductionOpsListType &ReductionOps) {
- bool UseSelect =
- ReductionOps.size() == 2 ||
- // Logical or/and.
- (ReductionOps.size() == 1 && any_of(ReductionOps.front(), [](Value *V) {
- return isa<SelectInst>(V);
- }));
+ bool UseSelect = ReductionOps.size() == 2 ||
+ // Logical or/and.
+ (ReductionOps.size() == 1 &&
+ any_of(ReductionOps.front(), IsaPred<SelectInst>));
assert((!UseSelect || ReductionOps.size() != 2 ||
isa<SelectInst>(ReductionOps[1][0])) &&
"Expected cmp + select pairs for reduction");
@@ -15501,7 +15477,7 @@ class HorizontalReduction {
!hasRequiredNumberOfUses(IsCmpSelMinMax, EdgeInst) ||
!isVectorizable(RdxKind, EdgeInst) ||
(R.isAnalyzedReductionRoot(EdgeInst) &&
- all_of(EdgeInst->operands(), Constant::classof))) {
+ all_of(EdgeInst->operands(), IsaPred<Constant>))) {
PossibleReducedVals.push_back(EdgeVal);
continue;
}
@@ -16857,9 +16833,7 @@ bool SLPVectorizerPass::vectorizeInsertElementInst(InsertElementInst *IEI,
SmallVector<Value *, 16> BuildVectorOpds;
SmallVector<int> Mask;
if (!findBuildAggregate(IEI, TTI, BuildVectorOpds, BuildVectorInsts) ||
- (llvm::all_of(
- BuildVectorOpds,
- [](Value *V) { return isa<ExtractElementInst, UndefValue>(V); }) &&
+ (llvm::all_of(BuildVectorOpds, IsaPred<ExtractElementInst, UndefValue>) &&
isFixedVectorShuffle(BuildVectorOpds, Mask)))
return false;
@@ -17080,10 +17054,7 @@ bool SLPVectorizerPass::vectorizeCmpInsts(iterator_range<ItT> CmpInsts,
bool SLPVectorizerPass::vectorizeInserts(InstSetVector &Instructions,
BasicBlock *BB, BoUpSLP &R) {
- assert(all_of(Instructions,
- [](auto *I) {
- return isa<InsertElementInst, InsertValueInst>(I);
- }) &&
+ assert(all_of(Instructions, IsaPred<InsertElementInst, InsertValueInst>) &&
"This function only accepts Insert instructions");
bool OpsChanged = false;
SmallVector<WeakTrackingVH> PostponedInsts;
More information about the llvm-commits mailing list