[llvm] [SLP][NFC] Simplify type checks with isa predicates (PR #87182)

Jakub Kuderski via llvm-commits llvm-commits at lists.llvm.org
Sat Mar 30 22:18:48 PDT 2024


https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/87182

See https://github.com/llvm/llvm-project/pull/83753 for more context on the `IsaPred` predicates used here.

>From bca1f3446a503c5779780972b8a37242af3128b8 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 31 Mar 2024 01:16:10 -0400
Subject: [PATCH] [SLPVectorizer][NFC] Simplify type checks with isa predicates

See https://github.com/llvm/llvm-project/pull/83753 for more context on the `IsaPred` predicates used here.
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 95 +++++++------------
 1 file changed, 33 insertions(+), 62 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 2875e71081d928..2bc0c5dcc6069d 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -458,8 +458,7 @@ static SmallBitVector isUndefVector(const Value *V,
 /// ShuffleVectorInst/getShuffleCost?
 static std::optional<TargetTransformInfo::ShuffleKind>
 isFixedVectorShuffle(ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) {
-  const auto *It =
-      find_if(VL, [](Value *V) { return isa<ExtractElementInst>(V); });
+  const auto *It = find_if(VL, IsaPred<ExtractElementInst>);
   if (It == VL.end())
     return std::nullopt;
   auto *EI0 = cast<ExtractElementInst>(*It);
@@ -4695,12 +4694,8 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) {
     // TODO: add analysis of other gather nodes with extractelement
     // instructions and other values/instructions, not only undefs.
     if ((TE.getOpcode() == Instruction::ExtractElement ||
-         (all_of(TE.Scalars,
-                 [](Value *V) {
-                   return isa<UndefValue, ExtractElementInst>(V);
-                 }) &&
-          any_of(TE.Scalars,
-                 [](Value *V) { return isa<ExtractElementInst>(V); }))) &&
+         (all_of(TE.Scalars, IsaPred<UndefValue, ExtractElementInst>) &&
+          any_of(TE.Scalars, IsaPred<ExtractElementInst>))) &&
         all_of(TE.Scalars, [](Value *V) {
           auto *EE = dyn_cast<ExtractElementInst>(V);
           return !EE || isa<FixedVectorType>(EE->getVectorOperandType());
@@ -4721,7 +4716,7 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) {
     // might be transformed.
     int Sz = TE.Scalars.size();
     if (isSplat(TE.Scalars) && !allConstant(TE.Scalars) &&
-        count_if(TE.Scalars, UndefValue::classof) == Sz - 1) {
+        count_if(TE.Scalars, IsaPred<UndefValue>) == Sz - 1) {
       const auto *It =
           find_if(TE.Scalars, [](Value *V) { return !isConstant(V); });
       if (It == TE.Scalars.begin())
@@ -6345,11 +6340,10 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
       UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize &&
       !(S.getOpcode() && allSameBlock(VL))) {
     assert(S.OpValue->getType()->isPointerTy() &&
-           count_if(VL, [](Value *V) { return isa<GetElementPtrInst>(V); }) >=
-               2 &&
+           count_if(VL, IsaPred<GetElementPtrInst>) >= 2 &&
            "Expected pointers only.");
     // Reset S to make it GetElementPtr kind of node.
-    const auto *It = find_if(VL, [](Value *V) { return isa<GetElementPtrInst>(V); });
+    const auto *It = find_if(VL, IsaPred<GetElementPtrInst>);
     assert(It != VL.end() && "Expected at least one GEP.");
     S = getSameOpcode(*It, *TLI);
   }
@@ -6893,17 +6887,12 @@ unsigned BoUpSLP::canMapToVector(Type *T) const {
 bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, Value *OpValue,
                               SmallVectorImpl<unsigned> &CurrentOrder,
                               bool ResizeAllowed) const {
-  const auto *It = find_if(VL, [](Value *V) {
-    return isa<ExtractElementInst, ExtractValueInst>(V);
-  });
+  const auto *It = find_if(VL, IsaPred<ExtractElementInst, ExtractValueInst>);
   assert(It != VL.end() && "Expected at least one extract instruction.");
   auto *E0 = cast<Instruction>(*It);
-  assert(all_of(VL,
-                [](Value *V) {
-                  return isa<UndefValue, ExtractElementInst, ExtractValueInst>(
-                      V);
-                }) &&
-         "Invalid opcode");
+  assert(
+      all_of(VL, IsaPred<UndefValue, ExtractElementInst, ExtractValueInst>) &&
+      "Invalid opcode");
   // Check if all of the extracts come from the same vector and from the
   // correct offset.
   Value *Vec = E0->getOperand(0);
@@ -7575,7 +7564,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
   }
 
   InstructionCost getBuildVectorCost(ArrayRef<Value *> VL, Value *Root) {
-    if ((!Root && allConstant(VL)) || all_of(VL, UndefValue::classof))
+    if ((!Root && allConstant(VL)) || all_of(VL, IsaPred<UndefValue>))
       return TTI::TCC_Free;
     auto *VecTy = FixedVectorType::get(VL.front()->getType(), VL.size());
     InstructionCost GatherCost = 0;
@@ -7743,13 +7732,12 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
     } else if (!Root && isSplat(VL)) {
       // Found the broadcasting of the single scalar, calculate the cost as
       // the broadcast.
-      const auto *It =
-          find_if(VL, [](Value *V) { return !isa<UndefValue>(V); });
+      const auto *It = find_if_not(VL, IsaPred<UndefValue>);
       assert(It != VL.end() && "Expected at least one non-undef value.");
       // Add broadcast for non-identity shuffle only.
       bool NeedShuffle =
           count(VL, *It) > 1 &&
-          (VL.front() != *It || !all_of(VL.drop_front(), UndefValue::classof));
+          (VL.front() != *It || !all_of(VL.drop_front(), IsaPred<UndefValue>));
       if (!NeedShuffle)
         return TTI.getVectorInstrCost(Instruction::InsertElement, VecTy,
                                       CostKind, std::distance(VL.begin(), It),
@@ -7757,7 +7745,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
 
       SmallVector<int> ShuffleMask(VL.size(), PoisonMaskElem);
       transform(VL, ShuffleMask.begin(), [](Value *V) {
-        return isa<PoisonValue>(V) ? PoisonMaskElem : 0;   
+        return isa<PoisonValue>(V) ? PoisonMaskElem : 0;
       });
       InstructionCost InsertCost = TTI.getVectorInstrCost(
           Instruction::InsertElement, VecTy, CostKind, 0,
@@ -7768,7 +7756,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
                                 /*SubTp=*/nullptr, /*Args=*/*It);
     }
     return GatherCost +
-           (all_of(Gathers, UndefValue::classof)
+           (all_of(Gathers, IsaPred<UndefValue>)
                 ? TTI::TCC_Free
                 : R.getGatherCost(Gathers, !Root && VL.equals(Gathers)));
   };
@@ -8178,9 +8166,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
         // Take credit for instruction that will become dead.
         if (EE->hasOneUse() || !PrevNodeFound) {
           Instruction *Ext = EE->user_back();
-          if (isa<SExtInst, ZExtInst>(Ext) && all_of(Ext->users(), [](User *U) {
-                return isa<GetElementPtrInst>(U);
-              })) {
+          if (isa<SExtInst, ZExtInst>(Ext) &&
+              all_of(Ext->users(), IsaPred<GetElementPtrInst>)) {
             // Use getExtractWithExtendCost() to calculate the cost of
             // extractelement/ext pair.
             Cost -=
@@ -8645,8 +8632,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       if (I->hasOneUse()) {
         Instruction *Ext = I->user_back();
         if ((isa<SExtInst>(Ext) || isa<ZExtInst>(Ext)) &&
-            all_of(Ext->users(),
-                   [](User *U) { return isa<GetElementPtrInst>(U); })) {
+            all_of(Ext->users(), IsaPred<GetElementPtrInst>)) {
           // Use getExtractWithExtendCost() to calculate the cost of
           // extractelement/ext pair.
           InstructionCost Cost = TTI->getExtractWithExtendCost(
@@ -9130,10 +9116,7 @@ bool BoUpSLP::isFullyVectorizableTinyTree(bool ForReduction) const {
            (allConstant(TE->Scalars) || isSplat(TE->Scalars) ||
             TE->Scalars.size() < Limit ||
             ((TE->getOpcode() == Instruction::ExtractElement ||
-              all_of(TE->Scalars,
-                     [](Value *V) {
-                       return isa<ExtractElementInst, UndefValue>(V);
-                     })) &&
+              all_of(TE->Scalars, IsaPred<ExtractElementInst, UndefValue>)) &&
              isFixedVectorShuffle(TE->Scalars, Mask)) ||
             (TE->State == TreeEntry::NeedToGather &&
              TE->getOpcode() == Instruction::Load && !TE->isAltShuffle()));
@@ -9254,9 +9237,7 @@ bool BoUpSLP::isTreeTinyAndNotFullyVectorizable(bool ForReduction) const {
       all_of(VectorizableTree, [&](const std::unique_ptr<TreeEntry> &TE) {
         return (TE->State == TreeEntry::NeedToGather &&
                 TE->getOpcode() != Instruction::ExtractElement &&
-                count_if(TE->Scalars,
-                         [](Value *V) { return isa<ExtractElementInst>(V); }) <=
-                    Limit) ||
+                count_if(TE->Scalars, IsaPred<ExtractElementInst>) <= Limit) ||
                TE->getOpcode() == Instruction::PHI;
       }))
     return true;
@@ -9285,9 +9266,7 @@ bool BoUpSLP::isTreeTinyAndNotFullyVectorizable(bool ForReduction) const {
                  return isa<ExtractElementInst, UndefValue>(V) ||
                         (IsAllowedSingleBVNode &&
                          !V->hasNUsesOrMore(UsesLimit) &&
-                         any_of(V->users(), [](User *U) {
-                           return isa<InsertElementInst>(U);
-                         }));
+                         any_of(V->users(), IsaPred<InsertElementInst>));
                });
       }))
     return false;
@@ -10284,7 +10263,7 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
     }
   }
 
-  bool IsSplatOrUndefs = isSplat(VL) || all_of(VL, UndefValue::classof);
+  bool IsSplatOrUndefs = isSplat(VL) || all_of(VL, IsaPred<UndefValue>);
   // Checks if the 2 PHIs are compatible in terms of high possibility to be
   // vectorized.
   auto AreCompatiblePHIs = [&](Value *V, Value *V1) {
@@ -11261,8 +11240,7 @@ Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx,
   InstructionsState S = getSameOpcode(VL, *TLI);
   // Special processing for GEPs bundle, which may include non-gep values.
   if (!S.getOpcode() && VL.front()->getType()->isPointerTy()) {
-    const auto *It =
-        find_if(VL, [](Value *V) { return isa<GetElementPtrInst>(V); });
+    const auto *It = find_if(VL, IsaPred<GetElementPtrInst>);
     if (It != VL.end())
       S = getSameOpcode(*It, *TLI);
   }
@@ -11432,7 +11410,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) {
   unsigned NumParts = TTI->getNumberOfParts(VecTy);
   if (NumParts == 0 || NumParts >= GatheredScalars.size())
     NumParts = 1;
-  if (!all_of(GatheredScalars, UndefValue::classof)) {
+  if (!all_of(GatheredScalars, IsaPred<UndefValue>)) {
     // Check for gathered extracts.
     bool Resized = false;
     ExtractShuffles =
@@ -11757,7 +11735,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) {
         GatheredScalars[I] = PoisonValue::get(ScalarTy);
     }
     // Generate constants for final shuffle and build a mask for them.
-    if (!all_of(GatheredScalars, PoisonValue::classof)) {
+    if (!all_of(GatheredScalars, IsaPred<PoisonValue>)) {
       SmallVector<int> BVMask(GatheredScalars.size(), PoisonMaskElem);
       TryPackScalars(GatheredScalars, BVMask, /*IsRootPoison=*/true);
       Value *BV = ShuffleBuilder.gather(GatheredScalars, BVMask.size());
@@ -14509,7 +14487,7 @@ void BoUpSLP::computeMinimumValueSizes() {
                        return SIt != DemotedConsts.end() &&
                               is_contained(SIt->getSecond(), Idx);
                      }) ||
-              all_of(CTE->Scalars, Constant::classof))
+              all_of(CTE->Scalars, IsaPred<Constant>))
             MinBWs.try_emplace(CTE, MaxBitWidth, IsSigned);
         }
       }
@@ -15257,12 +15235,10 @@ class HorizontalReduction {
   static Value *createOp(IRBuilderBase &Builder, RecurKind RdxKind, Value *LHS,
                          Value *RHS, const Twine &Name,
                          const ReductionOpsListType &ReductionOps) {
-    bool UseSelect =
-        ReductionOps.size() == 2 ||
-        // Logical or/and.
-        (ReductionOps.size() == 1 && any_of(ReductionOps.front(), [](Value *V) {
-           return isa<SelectInst>(V);
-         }));
+    bool UseSelect = ReductionOps.size() == 2 ||
+                     // Logical or/and.
+                     (ReductionOps.size() == 1 &&
+                      any_of(ReductionOps.front(), IsaPred<SelectInst>));
     assert((!UseSelect || ReductionOps.size() != 2 ||
             isa<SelectInst>(ReductionOps[1][0])) &&
            "Expected cmp + select pairs for reduction");
@@ -15501,7 +15477,7 @@ class HorizontalReduction {
             !hasRequiredNumberOfUses(IsCmpSelMinMax, EdgeInst) ||
             !isVectorizable(RdxKind, EdgeInst) ||
             (R.isAnalyzedReductionRoot(EdgeInst) &&
-             all_of(EdgeInst->operands(), Constant::classof))) {
+             all_of(EdgeInst->operands(), IsaPred<Constant>))) {
           PossibleReducedVals.push_back(EdgeVal);
           continue;
         }
@@ -16857,9 +16833,7 @@ bool SLPVectorizerPass::vectorizeInsertElementInst(InsertElementInst *IEI,
   SmallVector<Value *, 16> BuildVectorOpds;
   SmallVector<int> Mask;
   if (!findBuildAggregate(IEI, TTI, BuildVectorOpds, BuildVectorInsts) ||
-      (llvm::all_of(
-           BuildVectorOpds,
-           [](Value *V) { return isa<ExtractElementInst, UndefValue>(V); }) &&
+      (llvm::all_of(BuildVectorOpds, IsaPred<ExtractElementInst, UndefValue>) &&
        isFixedVectorShuffle(BuildVectorOpds, Mask)))
     return false;
 
@@ -17080,10 +17054,7 @@ bool SLPVectorizerPass::vectorizeCmpInsts(iterator_range<ItT> CmpInsts,
 
 bool SLPVectorizerPass::vectorizeInserts(InstSetVector &Instructions,
                                          BasicBlock *BB, BoUpSLP &R) {
-  assert(all_of(Instructions,
-                [](auto *I) {
-                  return isa<InsertElementInst, InsertValueInst>(I);
-                }) &&
+  assert(all_of(Instructions, IsaPred<InsertElementInst, InsertValueInst>) &&
          "This function only accepts Insert instructions");
   bool OpsChanged = false;
   SmallVector<WeakTrackingVH> PostponedInsts;



More information about the llvm-commits mailing list