[llvm] [SLP][TTI]Add support for strided loads. (PR #80310)

via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 1 09:23:13 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Alexey Bataev (alexey-bataev)

<details>
<summary>Changes</summary>

Adds basic support for strided memory operations to TTI, and for strided
loads in particular to the SLP vectorizer. Both runtime and constant
strides are supported. If the strided load must be reversed, the negated
stride (-stride) is used to avoid an extra reverse shuffle.


---

The patch is 93.55 KiB and is truncated to 20.00 KiB below. Full version: https://github.com/llvm/llvm-project/pull/80310.diff


12 Files Affected:

- (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+34) 
- (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+13) 
- (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+14) 
- (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp (+23) 
- (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h (+23) 
- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+278-119) 
- (modified) llvm/test/Transforms/SLPVectorizer/RISCV/complex-loads.ll (+65-67) 
- (modified) llvm/test/Transforms/SLPVectorizer/RISCV/strided-loads-vectorized.ll (+19-190) 
- (modified) llvm/test/Transforms/SLPVectorizer/RISCV/strided-loads-with-external-use-ptr.ll (+2-2) 
- (modified) llvm/test/Transforms/SLPVectorizer/RISCV/strided-loads.ll (+5-8) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/gep-nodes-with-non-gep-inst.ll (+1-1) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/remark_gather-load-redux-cost.ll (+1-1) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 3b615bc700bbb..b0b6dab03fa38 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -781,6 +781,9 @@ class TargetTransformInfo {
   /// Return true if the target supports masked expand load.
   bool isLegalMaskedExpandLoad(Type *DataType) const;
 
+  /// Return true if the target supports strided load.
+  bool isLegalStridedLoad(Type *DataType, Align Alignment) const;
+
   /// Return true if this is an alternating opcode pattern that can be lowered
   /// to a single instruction on the target. In X86 this is for the addsub
   /// instruction which corrsponds to a Shuffle + Fadd + FSub pattern in IR.
@@ -1412,6 +1415,20 @@ class TargetTransformInfo {
       Align Alignment, TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
       const Instruction *I = nullptr) const;
 
+  /// \return The cost of strided memory operations.
+  /// \p Opcode - is a type of memory access Load or Store
+  /// \p DataTy - a vector type of the data to be loaded or stored
+  /// \p Ptr - pointer [or vector of pointers] - address[es] in memory
+  /// \p VariableMask - true when the memory access is predicated with a mask
+  ///                   that is not a compile-time constant
+  /// \p Alignment - alignment of single element
+  /// \p I - the optional original context instruction, if one exists, e.g. the
+  ///        load/store to transform or the call to the gather/scatter intrinsic
+  InstructionCost getStridedMemoryOpCost(
+      unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
+      Align Alignment, TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput,
+      const Instruction *I = nullptr) const;
+
   /// \return The cost of the interleaved memory operation.
   /// \p Opcode is the memory operation code
   /// \p VecTy is the vector type of the interleaved access.
@@ -1848,6 +1865,7 @@ class TargetTransformInfo::Concept {
                                            Align Alignment) = 0;
   virtual bool isLegalMaskedCompressStore(Type *DataType) = 0;
   virtual bool isLegalMaskedExpandLoad(Type *DataType) = 0;
+  virtual bool isLegalStridedLoad(Type *DataType, Align Alignment) = 0;
   virtual bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0,
                                unsigned Opcode1,
                                const SmallBitVector &OpcodeMask) const = 0;
@@ -2023,6 +2041,11 @@ class TargetTransformInfo::Concept {
                          bool VariableMask, Align Alignment,
                          TTI::TargetCostKind CostKind,
                          const Instruction *I = nullptr) = 0;
+  virtual InstructionCost
+  getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
+                         bool VariableMask, Align Alignment,
+                         TTI::TargetCostKind CostKind,
+                         const Instruction *I = nullptr) = 0;
 
   virtual InstructionCost getInterleavedMemoryOpCost(
       unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
@@ -2341,6 +2364,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   bool isLegalMaskedExpandLoad(Type *DataType) override {
     return Impl.isLegalMaskedExpandLoad(DataType);
   }
+  bool isLegalStridedLoad(Type *DataType, Align Alignment) override {
+    return Impl.isLegalStridedLoad(DataType, Alignment);
+  }
   bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1,
                        const SmallBitVector &OpcodeMask) const override {
     return Impl.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask);
@@ -2671,6 +2697,14 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
                                        Alignment, CostKind, I);
   }
+  InstructionCost
+  getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr,
+                         bool VariableMask, Align Alignment,
+                         TTI::TargetCostKind CostKind,
+                         const Instruction *I = nullptr) override {
+    return Impl.getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
+                                       Alignment, CostKind, I);
+  }
   InstructionCost getInterleavedMemoryOpCost(
       unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
       Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 9958b4daa6ed8..2a7e7b364ac40 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -304,6 +304,10 @@ class TargetTransformInfoImplBase {
 
   bool isLegalMaskedExpandLoad(Type *DataType) const { return false; }
 
+  bool isLegalStridedLoad(Type *DataType, Align Alignment) const {
+    return false;
+  }
+
   bool enableOrderedReductions() const { return false; }
 
   bool hasDivRemOp(Type *DataType, bool IsSigned) const { return false; }
@@ -687,6 +691,15 @@ class TargetTransformInfoImplBase {
     return 1;
   }
 
+  InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy,
+                                         const Value *Ptr, bool VariableMask,
+                                         Align Alignment,
+                                         TTI::TargetCostKind CostKind,
+                                         const Instruction *I = nullptr) const {
+    return CostKind == TTI::TCK_RecipThroughput ? TTI::TCC_Expensive
+                                                : TTI::TCC_Basic;
+  }
+
   unsigned getInterleavedMemoryOpCost(
       unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
       Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 8902dde37cbca..b86397ae7d267 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -500,6 +500,11 @@ bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType) const {
   return TTIImpl->isLegalMaskedExpandLoad(DataType);
 }
 
+bool TargetTransformInfo::isLegalStridedLoad(Type *DataType,
+                                             Align Alignment) const {
+  return TTIImpl->isLegalStridedLoad(DataType, Alignment);
+}
+
 bool TargetTransformInfo::enableOrderedReductions() const {
   return TTIImpl->enableOrderedReductions();
 }
@@ -1041,6 +1046,15 @@ InstructionCost TargetTransformInfo::getGatherScatterOpCost(
   return Cost;
 }
 
+InstructionCost TargetTransformInfo::getStridedMemoryOpCost(
+    unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
+    Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
+  InstructionCost Cost = TTIImpl->getStridedMemoryOpCost(
+      Opcode, DataTy, Ptr, VariableMask, Alignment, CostKind, I);
+  assert(Cost >= 0 && "TTI should not produce negative costs!");
+  return Cost;
+}
+
 InstructionCost TargetTransformInfo::getInterleavedMemoryOpCost(
     unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
     Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index fe1cdb2dfa423..9cec8ee4cb7f2 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -658,6 +658,29 @@ InstructionCost RISCVTTIImpl::getGatherScatterOpCost(
   return NumLoads * MemOpCost;
 }
 
+InstructionCost RISCVTTIImpl::getStridedMemoryOpCost(
+    unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
+    Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
+  if (CostKind != TTI::TCK_RecipThroughput)
+    return BaseT::getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
+                                         Alignment, CostKind, I);
+
+  if ((Opcode == Instruction::Load && !isLegalStridedLoad(DataTy, Alignment)) ||
+      Opcode != Instruction::Load)
+    return BaseT::getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask,
+                                         Alignment, CostKind, I);
+
+  // Cost is proportional to the number of memory operations implied.  For
+  // scalable vectors, we use an estimate on that number since we don't
+  // know exactly what VL will be.
+  auto &VTy = *cast<VectorType>(DataTy);
+  InstructionCost MemOpCost =
+      getMemoryOpCost(Opcode, VTy.getElementType(), Alignment, 0, CostKind,
+                      {TTI::OK_AnyValue, TTI::OP_None}, I);
+  unsigned NumLoads = getEstimatedVLFor(&VTy);
+  return NumLoads * MemOpCost;
+}
+
 // Currently, these represent both throughput and codesize costs
 // for the respective intrinsics.  The costs in this table are simply
 // instruction counts with the following adjustments made:
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 0747a778fe9a2..742b1aadf00bd 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -143,6 +143,12 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
                                          TTI::TargetCostKind CostKind,
                                          const Instruction *I);
 
+  InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy,
+                                         const Value *Ptr, bool VariableMask,
+                                         Align Alignment,
+                                         TTI::TargetCostKind CostKind,
+                                         const Instruction *I);
+
   InstructionCost getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
                                    TTI::CastContextHint CCH,
                                    TTI::TargetCostKind CostKind,
@@ -250,6 +256,23 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
     return ST->is64Bit() && !ST->hasVInstructionsI64();
   }
 
+  bool isLegalStridedLoad(Type *DataType, Align Alignment) {
+    if (!ST->hasVInstructions())
+      return false;
+
+    EVT DataTypeVT = TLI->getValueType(DL, DataType);
+
+    // Only support fixed vectors if we know the minimum vector size.
+    if (DataTypeVT.isFixedLengthVector() && !ST->useRVVForFixedLengthVectors())
+      return false;
+
+    EVT ElemType = DataTypeVT.getScalarType();
+    if (!ST->hasFastUnalignedAccess() && Alignment < ElemType.getStoreSize())
+      return false;
+
+    return TLI->isLegalElementTypeForRVV(ElemType);
+  }
+
   bool isVScaleKnownToBeAPowerOfTwo() const {
     return TLI->isVScaleKnownToBeAPowerOfTwo();
   }
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index a8aea112bc28e..90b9b51c470bf 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -87,6 +87,7 @@
 #include "llvm/Transforms/Utils/InjectTLIMappings.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
 #include <algorithm>
 #include <cassert>
 #include <cstdint>
@@ -175,6 +176,15 @@ static cl::opt<int> RootLookAheadMaxDepth(
     "slp-max-root-look-ahead-depth", cl::init(2), cl::Hidden,
     cl::desc("The maximum look-ahead depth for searching best rooting option"));
 
+static cl::opt<unsigned> MinProfitableStridedLoads(
+    "slp-min-strided-loads", cl::init(2), cl::Hidden,
+    cl::desc("The minimum number of loads, which should be considered strided, "
+             "if the stride is > 1 or is runtime value"));
+
+static cl::opt<unsigned> MaxProfitableLoadStride(
+    "slp-max-stride", cl::init(8), cl::Hidden,
+    cl::desc("The maximum stride, considered to be profitable."));
+
 static cl::opt<bool>
     ViewSLPTree("view-slp-tree", cl::Hidden,
                 cl::desc("Display the SLP trees with Graphviz"));
@@ -2575,7 +2585,7 @@ class BoUpSLP {
     enum EntryState {
       Vectorize,
       ScatterVectorize,
-      PossibleStridedVectorize,
+      StridedVectorize,
       NeedToGather
     };
     EntryState State;
@@ -2753,8 +2763,8 @@ class BoUpSLP {
       case ScatterVectorize:
         dbgs() << "ScatterVectorize\n";
         break;
-      case PossibleStridedVectorize:
-        dbgs() << "PossibleStridedVectorize\n";
+      case StridedVectorize:
+        dbgs() << "StridedVectorize\n";
         break;
       case NeedToGather:
         dbgs() << "NeedToGather\n";
@@ -3680,7 +3690,7 @@ template <> struct DOTGraphTraits<BoUpSLP *> : public DefaultDOTGraphTraits {
     if (Entry->State == TreeEntry::NeedToGather)
       return "color=red";
     if (Entry->State == TreeEntry::ScatterVectorize ||
-        Entry->State == TreeEntry::PossibleStridedVectorize)
+        Entry->State == TreeEntry::StridedVectorize)
       return "color=blue";
     return "";
   }
@@ -3846,7 +3856,7 @@ enum class LoadsState {
   Gather,
   Vectorize,
   ScatterVectorize,
-  PossibleStridedVectorize
+  StridedVectorize
 };
 } // anonymous namespace
 
@@ -3878,6 +3888,130 @@ static Align computeCommonAlignment(ArrayRef<Value *> VL) {
   return CommonAlignment;
 }
 
+/// Check if \p Order represents reverse order.
+static bool isReverseOrder(ArrayRef<unsigned> Order) {
+  unsigned Sz = Order.size();
+  return !Order.empty() && all_of(enumerate(Order), [&](const auto &Pair) {
+    return Pair.value() == Sz || Sz - Pair.index() - 1 == Pair.value();
+  });
+}
+
+/// Checks if the provided list of pointers \p Pointers represents the strided
+/// pointers for type ElemTy. If they are not, std::nullopt is returned.
+/// Otherwise, if \p Inst is not specified, just initialized optional value is
+/// returned to show that the pointers represent strided pointers. If \p Inst
+/// specified, the runtime stride is materialized before the given \p Inst.
+/// \returns std::nullopt if the pointers are not pointers with the runtime
+/// stride, nullptr or actual stride value, otherwise.
+static std::optional<Value *>
+calculateRtStride(ArrayRef<Value *> PointerOps, Type *ElemTy,
+                  const DataLayout &DL, ScalarEvolution &SE,
+                  SmallVectorImpl<unsigned> &SortedIndices,
+                  Instruction *Inst = nullptr) {
+  SmallVector<const SCEV *> SCEVs;
+  const SCEV *PtrSCEVA = nullptr;
+  const SCEV *PtrSCEVB = nullptr;
+  for (Value *Ptr : PointerOps) {
+    const SCEV *PtrSCEV = SE.getSCEV(Ptr);
+    if (!PtrSCEV)
+      return std::nullopt;
+    SCEVs.push_back(PtrSCEV);
+    if (!PtrSCEVA && !PtrSCEVB) {
+      PtrSCEVA = PtrSCEVB = PtrSCEV;
+      continue;
+    }
+    const SCEV *Diff = SE.getMinusSCEV(PtrSCEV, PtrSCEVA);
+    if (!Diff || isa<SCEVCouldNotCompute>(Diff))
+      return std::nullopt;
+    if (Diff->isNonConstantNegative()) {
+      PtrSCEVA = PtrSCEV;
+      continue;
+    }
+    const SCEV *Diff1 = SE.getMinusSCEV(PtrSCEVB, PtrSCEV);
+    if (!Diff1 || isa<SCEVCouldNotCompute>(Diff1))
+      return std::nullopt;
+    if (Diff1->isNonConstantNegative()) {
+      PtrSCEVB = PtrSCEV;
+      continue;
+    }
+  }
+  const SCEV *Stride = SE.getMinusSCEV(PtrSCEVB, PtrSCEVA);
+  if (!Stride)
+    return std::nullopt;
+  int Size = DL.getTypeStoreSize(ElemTy);
+  auto TryGetStride = [&](const SCEV *Dist,
+                          const SCEV *Multiplier) -> const SCEV * {
+    if (const auto *M = dyn_cast<SCEVMulExpr>(Dist)) {
+      if (M->getOperand(0) == Multiplier)
+        return M->getOperand(1);
+      if (M->getOperand(1) == Multiplier)
+        return M->getOperand(0);
+      return nullptr;
+    }
+    if (Multiplier == Dist)
+      return SE.getConstant(Dist->getType(), 1);
+    return SE.getUDivExactExpr(Dist, Multiplier);
+  };
+  if (Size != 1 || SCEVs.size() > 2) {
+    const SCEV *Sz =
+        SE.getConstant(Stride->getType(), Size * (SCEVs.size() - 1));
+    Stride = TryGetStride(Stride, Sz);
+    if (!Stride)
+      return std::nullopt;
+  }
+  if (!Stride || isa<SCEVConstant>(Stride))
+    return std::nullopt;
+  // Iterate through all pointers and check if all distances are
+  // unique multiple of Dist.
+  using DistOrdPair = std::pair<int64_t, int>;
+  auto Compare = llvm::less_first();
+  std::set<DistOrdPair, decltype(Compare)> Offsets(Compare);
+  int Cnt = 0;
+  bool IsConsecutive = true;
+  for (const SCEV *PtrSCEV : SCEVs) {
+    unsigned Dist = 0;
+    if (PtrSCEV != PtrSCEVA) {
+      const SCEV *Diff = SE.getMinusSCEV(PtrSCEV, PtrSCEVA);
+      const SCEV *Coeff = TryGetStride(Diff, Stride);
+      if (!Coeff)
+        return std::nullopt;
+      const auto *SC = dyn_cast<SCEVConstant>(Coeff);
+      if (!SC || isa<SCEVCouldNotCompute>(SC))
+        return std::nullopt;
+      if (!SE.getMinusSCEV(PtrSCEV,
+                           SE.getAddExpr(PtrSCEVA, SE.getMulExpr(Stride, SC)))
+               ->isZero())
+        return std::nullopt;
+      Dist = SC->getAPInt().getZExtValue();
+    }
+    // If the strides are not the same or repeated, we can't vectorize.
+    if ((Dist / Size) * Size != Dist || (Dist / Size) >= SCEVs.size())
+      return std::nullopt;
+    auto Res = Offsets.emplace(Dist, Cnt);
+    if (!Res.second)
+      return std::nullopt;
+    // Consecutive order if the inserted element is the last one.
+    IsConsecutive = IsConsecutive && std::next(Res.first) == Offsets.end();
+    ++Cnt;
+  }
+  if (Offsets.size() != SCEVs.size())
+    return std::nullopt;
+  SortedIndices.clear();
+  if (!IsConsecutive) {
+    // Fill SortedIndices array only if it is non-consecutive.
+    SortedIndices.resize(PointerOps.size());
+    Cnt = 0;
+    for (const std::pair<int64_t, int> &Pair : Offsets) {
+      SortedIndices[Cnt] = Pair.second;
+      ++Cnt;
+    }
+  }
+  if (!Inst)
+    return nullptr;
+  SCEVExpander Expander(SE, DL, "strided-load-vec");
+  return Expander.expandCodeFor(Stride, Stride->getType(), Inst);
+}
+
 /// Checks if the given array of loads can be represented as a vectorized,
 /// scatter or just simple gather.
 static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
@@ -3900,7 +4034,8 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
   // Make sure all loads in the bundle are simple - we can't vectorize
   // atomic or volatile loads.
   PointerOps.clear();
-  PointerOps.resize(VL.size());
+  const unsigned Sz = VL.size();
+  PointerOps.resize(Sz);
   auto *POIter = PointerOps.begin();
   for (Value *V : VL) {
     auto *L = cast<LoadInst>(V);
@@ -3913,10 +4048,15 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
   Order.clear();
   // Check the order of pointer operands or that all pointers are the same.
   bool IsSorted = sortPtrAccesses(PointerOps, ScalarTy, DL, SE, Order);
+  Align CommonAlignment = computeCommonAlignment<LoadInst>(VL);
+  auto *VecTy = FixedVectorType::get(ScalarTy, Sz);
+  if (!IsSorted && Sz > MinProfitableStridedLoads && TTI.isTypeLegal(VecTy) &&
+      TTI.isLegalStridedLoad(VecTy, CommonAlignment) &&
+      calculateRtStride(PointerOps, ScalarTy, DL, SE, Order))
+    return LoadsState::StridedVectorize;
   if (IsSorted || all_of(PointerOps, [&](Value *P) {
         return arePointersCompatible(P, PointerOps.front(), TLI);
       })) {
-    bool IsPossibleStrided = false;
     if (IsSorted) {
       Value *Ptr0;
       Value *PtrN;
@@ -3930,30 +4070,68 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
       std::optional<int> Diff =
           getPointersDiff(ScalarTy, Ptr0, ScalarTy, PtrN, DL, SE);
       // Check that the sorted loads are consecutive.
-      if (static_cast<unsigned>(*Diff) == VL.size() - 1)
+      if (static_cast<unsigned>(*Diff) == Sz - 1)
         return LoadsState::Vectorize;
       // Simple check if not a strided access - clear order.
-      IsPossibleStrided = *Diff % (VL.size() - 1) == 0;
+      bool IsPossibleStrided = *Diff % (Sz - 1) == 0;
+      // Try to generate strided load node if:
+      // 1. Targ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/80310


More information about the llvm-commits mailing list