[llvm] 6a74b0e - [SLP] Use early-return in canVectorizeLoads [nfc]

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 27 12:30:52 PDT 2024


Author: Philip Reames
Date: 2024-08-27T12:30:15-07:00
New Revision: 6a74b0ee591db817543988c3ce6c346741eddd54

URL: https://github.com/llvm/llvm-project/commit/6a74b0ee591db817543988c3ce6c346741eddd54
DIFF: https://github.com/llvm/llvm-project/commit/6a74b0ee591db817543988c3ce6c346741eddd54.diff

LOG: [SLP] Use early-return in canVectorizeLoads [nfc]

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 8309caa3ba1afa..0bb3f1d202308f 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -4707,209 +4707,211 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
       TTI->isLegalStridedLoadStore(VecTy, CommonAlignment) &&
       calculateRtStride(PointerOps, ScalarTy, *DL, *SE, Order))
     return LoadsState::StridedVectorize;
-  if (IsSorted || all_of(PointerOps, [&](Value *P) {
+
+  if (!IsSorted && !all_of(PointerOps, [&](Value *P) {
         return arePointersCompatible(P, PointerOps.front(), *TLI);
-      })) {
-    if (IsSorted) {
-      Value *Ptr0;
-      Value *PtrN;
-      if (Order.empty()) {
-        Ptr0 = PointerOps.front();
-        PtrN = PointerOps.back();
-      } else {
-        Ptr0 = PointerOps[Order.front()];
-        PtrN = PointerOps[Order.back()];
-      }
-      std::optional<int> Diff =
-          getPointersDiff(ScalarTy, Ptr0, ScalarTy, PtrN, *DL, *SE);
-      // Check that the sorted loads are consecutive.
-      if (static_cast<unsigned>(*Diff) == Sz - 1)
-        return LoadsState::Vectorize;
-      // Simple check if not a strided access - clear order.
-      bool IsPossibleStrided = *Diff % (Sz - 1) == 0;
-      // Try to generate strided load node if:
-      // 1. Target with strided load support is detected.
-      // 2. The number of loads is greater than MinProfitableStridedLoads,
-      // or the potential stride <= MaxProfitableLoadStride and the
-      // potential stride is power-of-2 (to avoid perf regressions for the very
-      // small number of loads) and max distance > number of loads, or potential
-      // stride is -1.
-      // 3. The loads are ordered, or number of unordered loads <=
-      // MaxProfitableUnorderedLoads, or loads are in reversed order.
-      // (this check is to avoid extra costs for very expensive shuffles).
-      // 4. Any pointer operand is an instruction with the users outside of the
-      // current graph (for masked gathers extra extractelement instructions
-      // might be required).
-      auto IsAnyPointerUsedOutGraph =
-          IsPossibleStrided && any_of(PointerOps, [&](Value *V) {
-            return isa<Instruction>(V) && any_of(V->users(), [&](User *U) {
-                     return !getTreeEntry(U) && !MustGather.contains(U);
-                   });
-          });
-      const unsigned AbsoluteDiff = std::abs(*Diff);
-      if (IsPossibleStrided &&
-          (IsAnyPointerUsedOutGraph ||
-           ((Sz > MinProfitableStridedLoads ||
-             (AbsoluteDiff <= MaxProfitableLoadStride * Sz &&
-              has_single_bit(AbsoluteDiff))) &&
-            AbsoluteDiff > Sz) ||
-           *Diff == -(static_cast<int>(Sz) - 1))) {
-        int Stride = *Diff / static_cast<int>(Sz - 1);
-        if (*Diff == Stride * static_cast<int>(Sz - 1)) {
-          Align Alignment =
-              cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
-                  ->getAlign();
-          if (TTI->isLegalStridedLoadStore(VecTy, Alignment)) {
-            // Iterate through all pointers and check if all distances are
-            // unique multiple of Dist.
-            SmallSet<int, 4> Dists;
-            for (Value *Ptr : PointerOps) {
-              int Dist = 0;
-              if (Ptr == PtrN)
-                Dist = *Diff;
-              else if (Ptr != Ptr0)
-                Dist =
-                    *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
-              // If the strides are not the same or repeated, we can't
-              // vectorize.
-              if (((Dist / Stride) * Stride) != Dist ||
-                  !Dists.insert(Dist).second)
-                break;
-            }
-            if (Dists.size() == Sz)
-              return LoadsState::StridedVectorize;
+      }))
+    return LoadsState::Gather;
+
+  if (IsSorted) {
+    Value *Ptr0;
+    Value *PtrN;
+    if (Order.empty()) {
+      Ptr0 = PointerOps.front();
+      PtrN = PointerOps.back();
+    } else {
+      Ptr0 = PointerOps[Order.front()];
+      PtrN = PointerOps[Order.back()];
+    }
+    std::optional<int> Diff =
+        getPointersDiff(ScalarTy, Ptr0, ScalarTy, PtrN, *DL, *SE);
+    // Check that the sorted loads are consecutive.
+    if (static_cast<unsigned>(*Diff) == Sz - 1)
+      return LoadsState::Vectorize;
+    // Simple check if not a strided access - clear order.
+    bool IsPossibleStrided = *Diff % (Sz - 1) == 0;
+    // Try to generate strided load node if:
+    // 1. Target with strided load support is detected.
+    // 2. The number of loads is greater than MinProfitableStridedLoads,
+    // or the potential stride <= MaxProfitableLoadStride and the
+    // potential stride is power-of-2 (to avoid perf regressions for the very
+    // small number of loads) and max distance > number of loads, or potential
+    // stride is -1.
+    // 3. The loads are ordered, or number of unordered loads <=
+    // MaxProfitableUnorderedLoads, or loads are in reversed order.
+    // (this check is to avoid extra costs for very expensive shuffles).
+    // 4. Any pointer operand is an instruction with the users outside of the
+    // current graph (for masked gathers extra extractelement instructions
+    // might be required).
+    auto IsAnyPointerUsedOutGraph =
+        IsPossibleStrided && any_of(PointerOps, [&](Value *V) {
+          return isa<Instruction>(V) && any_of(V->users(), [&](User *U) {
+                   return !getTreeEntry(U) && !MustGather.contains(U);
+                 });
+        });
+    const unsigned AbsoluteDiff = std::abs(*Diff);
+    if (IsPossibleStrided &&
+        (IsAnyPointerUsedOutGraph ||
+         ((Sz > MinProfitableStridedLoads ||
+           (AbsoluteDiff <= MaxProfitableLoadStride * Sz &&
+            has_single_bit(AbsoluteDiff))) &&
+          AbsoluteDiff > Sz) ||
+         *Diff == -(static_cast<int>(Sz) - 1))) {
+      int Stride = *Diff / static_cast<int>(Sz - 1);
+      if (*Diff == Stride * static_cast<int>(Sz - 1)) {
+        Align Alignment =
+            cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
+                ->getAlign();
+        if (TTI->isLegalStridedLoadStore(VecTy, Alignment)) {
+          // Iterate through all pointers and check if all distances are
+          // unique multiple of Dist.
+          SmallSet<int, 4> Dists;
+          for (Value *Ptr : PointerOps) {
+            int Dist = 0;
+            if (Ptr == PtrN)
+              Dist = *Diff;
+            else if (Ptr != Ptr0)
+              Dist =
+                  *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
+            // If the strides are not the same or repeated, we can't
+            // vectorize.
+            if (((Dist / Stride) * Stride) != Dist ||
+                !Dists.insert(Dist).second)
+              break;
           }
+          if (Dists.size() == Sz)
+            return LoadsState::StridedVectorize;
         }
       }
     }
-    auto CheckForShuffledLoads = [&, &TTI = *TTI](Align CommonAlignment) {
-      unsigned Sz = DL->getTypeSizeInBits(ScalarTy);
-      unsigned MinVF = getMinVF(Sz);
-      unsigned MaxVF = std::max<unsigned>(bit_floor(VL.size() / 2), MinVF);
-      MaxVF = std::min(getMaximumVF(Sz, Instruction::Load), MaxVF);
-      for (unsigned VF = MaxVF; VF >= MinVF; VF /= 2) {
-        unsigned VectorizedCnt = 0;
-        SmallVector<LoadsState> States;
-        for (unsigned Cnt = 0, End = VL.size(); Cnt + VF <= End;
-             Cnt += VF, ++VectorizedCnt) {
-          ArrayRef<Value *> Slice = VL.slice(Cnt, VF);
-          SmallVector<unsigned> Order;
-          SmallVector<Value *> PointerOps;
-          LoadsState LS =
-              canVectorizeLoads(Slice, Slice.front(), Order, PointerOps,
-                                /*TryRecursiveCheck=*/false);
-          // Check that the sorted loads are consecutive.
-          if (LS == LoadsState::Gather)
+  }
+  auto CheckForShuffledLoads = [&, &TTI = *TTI](Align CommonAlignment) {
+    unsigned Sz = DL->getTypeSizeInBits(ScalarTy);
+    unsigned MinVF = getMinVF(Sz);
+    unsigned MaxVF = std::max<unsigned>(bit_floor(VL.size() / 2), MinVF);
+    MaxVF = std::min(getMaximumVF(Sz, Instruction::Load), MaxVF);
+    for (unsigned VF = MaxVF; VF >= MinVF; VF /= 2) {
+      unsigned VectorizedCnt = 0;
+      SmallVector<LoadsState> States;
+      for (unsigned Cnt = 0, End = VL.size(); Cnt + VF <= End;
+           Cnt += VF, ++VectorizedCnt) {
+        ArrayRef<Value *> Slice = VL.slice(Cnt, VF);
+        SmallVector<unsigned> Order;
+        SmallVector<Value *> PointerOps;
+        LoadsState LS =
+            canVectorizeLoads(Slice, Slice.front(), Order, PointerOps,
+                              /*TryRecursiveCheck=*/false);
+        // Check that the sorted loads are consecutive.
+        if (LS == LoadsState::Gather)
+          break;
+        // If need the reorder - consider as high-cost masked gather for now.
+        if ((LS == LoadsState::Vectorize ||
+             LS == LoadsState::StridedVectorize) &&
+            !Order.empty() && !isReverseOrder(Order))
+          LS = LoadsState::ScatterVectorize;
+        States.push_back(LS);
+      }
+      // Can be vectorized later as a serie of loads/insertelements.
+      if (VectorizedCnt == VL.size() / VF) {
+        // Compare masked gather cost and loads + insersubvector costs.
+        TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+        auto [ScalarGEPCost, VectorGEPCost] = getGEPCosts(
+            TTI, PointerOps, PointerOps.front(), Instruction::GetElementPtr,
+            CostKind, ScalarTy, VecTy);
+        InstructionCost MaskedGatherCost =
+            TTI.getGatherScatterOpCost(
+                Instruction::Load, VecTy,
+                cast<LoadInst>(VL0)->getPointerOperand(),
+                /*VariableMask=*/false, CommonAlignment, CostKind) +
+            VectorGEPCost - ScalarGEPCost;
+        InstructionCost VecLdCost = 0;
+        auto *SubVecTy = getWidenedType(ScalarTy, VF);
+        for (auto [I, LS] : enumerate(States)) {
+          auto *LI0 = cast<LoadInst>(VL[I * VF]);
+          switch (LS) {
+          case LoadsState::Vectorize: {
+            auto [ScalarGEPCost, VectorGEPCost] =
+                getGEPCosts(TTI, ArrayRef(PointerOps).slice(I * VF, VF),
+                            LI0->getPointerOperand(), Instruction::Load,
+                            CostKind, ScalarTy, SubVecTy);
+            VecLdCost += TTI.getMemoryOpCost(
+                             Instruction::Load, SubVecTy, LI0->getAlign(),
+                             LI0->getPointerAddressSpace(), CostKind,
+                             TTI::OperandValueInfo()) +
+                         VectorGEPCost - ScalarGEPCost;
             break;
-          // If need the reorder - consider as high-cost masked gather for now.
-          if ((LS == LoadsState::Vectorize ||
-               LS == LoadsState::StridedVectorize) &&
-              !Order.empty() && !isReverseOrder(Order))
-            LS = LoadsState::ScatterVectorize;
-          States.push_back(LS);
-        }
-        // Can be vectorized later as a serie of loads/insertelements.
-        if (VectorizedCnt == VL.size() / VF) {
-          // Compare masked gather cost and loads + insersubvector costs.
-          TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
-          auto [ScalarGEPCost, VectorGEPCost] = getGEPCosts(
-              TTI, PointerOps, PointerOps.front(), Instruction::GetElementPtr,
-              CostKind, ScalarTy, VecTy);
-          InstructionCost MaskedGatherCost =
-              TTI.getGatherScatterOpCost(
-                  Instruction::Load, VecTy,
-                  cast<LoadInst>(VL0)->getPointerOperand(),
-                  /*VariableMask=*/false, CommonAlignment, CostKind) +
-              VectorGEPCost - ScalarGEPCost;
-          InstructionCost VecLdCost = 0;
-          auto *SubVecTy = getWidenedType(ScalarTy, VF);
-          for (auto [I, LS] : enumerate(States)) {
-            auto *LI0 = cast<LoadInst>(VL[I * VF]);
-            switch (LS) {
-            case LoadsState::Vectorize: {
-              auto [ScalarGEPCost, VectorGEPCost] =
-                  getGEPCosts(TTI, ArrayRef(PointerOps).slice(I * VF, VF),
-                              LI0->getPointerOperand(), Instruction::Load,
-                              CostKind, ScalarTy, SubVecTy);
-              VecLdCost += TTI.getMemoryOpCost(
-                               Instruction::Load, SubVecTy, LI0->getAlign(),
-                               LI0->getPointerAddressSpace(), CostKind,
-                               TTI::OperandValueInfo()) +
-                           VectorGEPCost - ScalarGEPCost;
-              break;
-            }
-            case LoadsState::StridedVectorize: {
-              auto [ScalarGEPCost, VectorGEPCost] =
-                  getGEPCosts(TTI, ArrayRef(PointerOps).slice(I * VF, VF),
-                              LI0->getPointerOperand(), Instruction::Load,
-                              CostKind, ScalarTy, SubVecTy);
-              VecLdCost +=
-                  TTI.getStridedMemoryOpCost(
-                      Instruction::Load, SubVecTy, LI0->getPointerOperand(),
-                      /*VariableMask=*/false, CommonAlignment, CostKind) +
-                  VectorGEPCost - ScalarGEPCost;
-              break;
-            }
-            case LoadsState::ScatterVectorize: {
-              auto [ScalarGEPCost, VectorGEPCost] = getGEPCosts(
-                  TTI, ArrayRef(PointerOps).slice(I * VF, VF),
-                  LI0->getPointerOperand(), Instruction::GetElementPtr,
-                  CostKind, ScalarTy, SubVecTy);
-              VecLdCost +=
-                  TTI.getGatherScatterOpCost(
-                      Instruction::Load, SubVecTy, LI0->getPointerOperand(),
-                      /*VariableMask=*/false, CommonAlignment, CostKind) +
-                  VectorGEPCost - ScalarGEPCost;
-              break;
-            }
-            case LoadsState::Gather:
-              llvm_unreachable(
-                  "Expected only consecutive, strided or masked gather loads.");
-            }
-            SmallVector<int> ShuffleMask(VL.size());
-            for (int Idx : seq<int>(0, VL.size()))
-              ShuffleMask[Idx] = Idx / VF == I ? VL.size() + Idx % VF : Idx;
+          }
+          case LoadsState::StridedVectorize: {
+            auto [ScalarGEPCost, VectorGEPCost] =
+                getGEPCosts(TTI, ArrayRef(PointerOps).slice(I * VF, VF),
+                            LI0->getPointerOperand(), Instruction::Load,
+                            CostKind, ScalarTy, SubVecTy);
             VecLdCost +=
-                ::getShuffleCost(TTI, TTI::SK_InsertSubvector, VecTy,
-                                 ShuffleMask, CostKind, I * VF, SubVecTy);
+                TTI.getStridedMemoryOpCost(
+                    Instruction::Load, SubVecTy, LI0->getPointerOperand(),
+                    /*VariableMask=*/false, CommonAlignment, CostKind) +
+                VectorGEPCost - ScalarGEPCost;
+            break;
           }
-          // If masked gather cost is higher - better to vectorize, so
-          // consider it as a gather node. It will be better estimated
-          // later.
-          if (MaskedGatherCost >= VecLdCost)
-            return true;
+          case LoadsState::ScatterVectorize: {
+            auto [ScalarGEPCost, VectorGEPCost] = getGEPCosts(
+                TTI, ArrayRef(PointerOps).slice(I * VF, VF),
+                LI0->getPointerOperand(), Instruction::GetElementPtr,
+                CostKind, ScalarTy, SubVecTy);
+            VecLdCost +=
+                TTI.getGatherScatterOpCost(
+                    Instruction::Load, SubVecTy, LI0->getPointerOperand(),
+                    /*VariableMask=*/false, CommonAlignment, CostKind) +
+                VectorGEPCost - ScalarGEPCost;
+            break;
+          }
+          case LoadsState::Gather:
+            llvm_unreachable(
+                "Expected only consecutive, strided or masked gather loads.");
+          }
+          SmallVector<int> ShuffleMask(VL.size());
+          for (int Idx : seq<int>(0, VL.size()))
+            ShuffleMask[Idx] = Idx / VF == I ? VL.size() + Idx % VF : Idx;
+          VecLdCost +=
+              ::getShuffleCost(TTI, TTI::SK_InsertSubvector, VecTy,
+                               ShuffleMask, CostKind, I * VF, SubVecTy);
         }
+        // If masked gather cost is higher - better to vectorize, so
+        // consider it as a gather node. It will be better estimated
+        // later.
+        if (MaskedGatherCost >= VecLdCost)
+          return true;
       }
-      return false;
-    };
-    // TODO: need to improve analysis of the pointers, if not all of them are
-    // GEPs or have > 2 operands, we end up with a gather node, which just
-    // increases the cost.
-    Loop *L = LI->getLoopFor(cast<LoadInst>(VL0)->getParent());
-    bool ProfitableGatherPointers =
-        L && Sz > 2 &&
-        static_cast<unsigned>(count_if(PointerOps, [L](Value *V) {
-          return L->isLoopInvariant(V);
-        })) <= Sz / 2;
-    if (ProfitableGatherPointers || all_of(PointerOps, [IsSorted](Value *P) {
-          auto *GEP = dyn_cast<GetElementPtrInst>(P);
-          return (IsSorted && !GEP && doesNotNeedToBeScheduled(P)) ||
-                 (GEP && GEP->getNumOperands() == 2 &&
-                  isa<Constant, Instruction>(GEP->getOperand(1)));
-        })) {
-      Align CommonAlignment = computeCommonAlignment<LoadInst>(VL);
-      if (TTI->isLegalMaskedGather(VecTy, CommonAlignment) &&
-          !TTI->forceScalarizeMaskedGather(VecTy, CommonAlignment)) {
-        // Check if potential masked gather can be represented as series
-        // of loads + insertsubvectors.
-        if (TryRecursiveCheck && CheckForShuffledLoads(CommonAlignment)) {
-          // If masked gather cost is higher - better to vectorize, so
-          // consider it as a gather node. It will be better estimated
-          // later.
-          return LoadsState::Gather;
-        }
-        return LoadsState::ScatterVectorize;
+    }
+    return false;
+  };
+  // TODO: need to improve analysis of the pointers, if not all of them are
+  // GEPs or have > 2 operands, we end up with a gather node, which just
+  // increases the cost.
+  Loop *L = LI->getLoopFor(cast<LoadInst>(VL0)->getParent());
+  bool ProfitableGatherPointers =
+      L && Sz > 2 &&
+      static_cast<unsigned>(count_if(PointerOps, [L](Value *V) {
+        return L->isLoopInvariant(V);
+      })) <= Sz / 2;
+  if (ProfitableGatherPointers || all_of(PointerOps, [IsSorted](Value *P) {
+        auto *GEP = dyn_cast<GetElementPtrInst>(P);
+        return (IsSorted && !GEP && doesNotNeedToBeScheduled(P)) ||
+               (GEP && GEP->getNumOperands() == 2 &&
+                isa<Constant, Instruction>(GEP->getOperand(1)));
+      })) {
+    Align CommonAlignment = computeCommonAlignment<LoadInst>(VL);
+    if (TTI->isLegalMaskedGather(VecTy, CommonAlignment) &&
+        !TTI->forceScalarizeMaskedGather(VecTy, CommonAlignment)) {
+      // Check if potential masked gather can be represented as series
+      // of loads + insertsubvectors.
+      if (TryRecursiveCheck && CheckForShuffledLoads(CommonAlignment)) {
+        // If masked gather cost is higher - better to vectorize, so
+        // consider it as a gather node. It will be better estimated
+        // later.
+        return LoadsState::Gather;
       }
+      return LoadsState::ScatterVectorize;
     }
   }
 


        


More information about the llvm-commits mailing list