[llvm] [LV] Convert gather loads with invariant stride into strided loads (PR #147297)

Mel Chen via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 22 02:44:13 PDT 2025


================
@@ -3767,3 +3776,183 @@ void VPlanTransforms::addBranchWeightToMiddleTerminator(
       MDB.createBranchWeights({1, VectorStep - 1}, /*IsExpected=*/false);
   MiddleTerm->addMetadata(LLVMContext::MD_prof, BranchWeights);
 }
+
+/// Try to express \p CurIndex, an index that varies across vector lanes, as
+/// an arithmetic progression Start + Lane * Stride.
+///
+/// Recognized forms:
+///  - a widened (non-truncated) induction: Start is the induction recipe
+///    itself and Stride is its step value;
+///  - shl(X, C) or mul(X, C) / mul(C, X), where C is a single-scalar operand
+///    and X is recursively recognized.
+///
+/// Callers consume Start through uniform (single-scalar) recipes, so only its
+/// first lane is used. For every shl/mul level that matches, new recipes
+/// computing that level's Start and Stride are inserted right before the
+/// matched recipe. They are created even if the caller later decides not to
+/// use them, so the caller is responsible for cleaning up unused ones.
+///
+/// \returns {Start, Stride} on success, or {nullptr, nullptr} if no supported
+/// pattern matches.
+static std::pair<VPValue *, VPValue *> matchStridedStart(VPValue *CurIndex) {
+  if (auto *WidenIV = dyn_cast<VPWidenIntOrFpInductionRecipe>(CurIndex)) {
+    // NOTE(review): for a truncated induction, the step value presumably
+    // keeps the wider type of the original induction and is only truncated
+    // when the recipe is executed. Using it as-is would give a stride whose
+    // type differs from the (truncated) index type, so reject this case.
+    // TODO: Support truncated inductions by truncating the step.
+    if (WidenIV->getTruncInst())
+      return {nullptr, nullptr};
+    return {WidenIV, WidenIV->getStepValue()};
+  }
+
+  auto *WidenR = dyn_cast<VPWidenRecipe>(CurIndex);
+  // The start recipe created below replicates the underlying IR instruction,
+  // so one must exist.
+  if (!WidenR || !CurIndex->getUnderlyingValue())
+    return {nullptr, nullptr};
+
+  unsigned Opcode = WidenR->getOpcode();
+  // TODO: Support Instruction::Add and Instruction::Or.
+  if (Opcode != Instruction::Shl && Opcode != Instruction::Mul)
+    return {nullptr, nullptr};
+
+  // Match the pattern binop(variant, invariant), or binop(invariant, variant)
+  // if the binary operator is commutative. shl(invariant, variant) does not
+  // form an arithmetic progression and is rejected.
+  bool IsLHSUniform = vputils::isSingleScalar(WidenR->getOperand(0));
+  if (IsLHSUniform == vputils::isSingleScalar(WidenR->getOperand(1)) ||
+      (IsLHSUniform && !Instruction::isCommutative(Opcode)))
+    return {nullptr, nullptr};
+  unsigned VarIdx = IsLHSUniform ? 1 : 0;
+  unsigned InvIdx = VarIdx == 0 ? 1 : 0;
+
+  auto [Start, Stride] = matchStridedStart(WidenR->getOperand(VarIdx));
+  if (!Start)
+    return {nullptr, nullptr};
+
+  // Start of this level: the same operation applied to the start of the
+  // variant operand, computed once as a single scalar.
+  // NOTE(review): StartR is built from the underlying IR instruction, so its
+  // IR flags presumably come from that instruction rather than from WidenR;
+  // confirm this is intended if WidenR's poison-generating flags were dropped.
+  SmallVector<VPValue *> StartOps(WidenR->operands());
+  StartOps[VarIdx] = Start;
+  auto *StartR = new VPReplicateRecipe(WidenR->getUnderlyingInstr(), StartOps,
+                                       /*IsUniform*/ true);
+  StartR->insertBefore(WidenR);
+
+  // Stride of this level: for op in {mul, shl} and a lane-uniform C,
+  // (X + Lane * S) op C == (X op C) + Lane * (S op C) in wrapping arithmetic.
+  auto *StrideR =
+      new VPInstruction(Opcode, {Stride, WidenR->getOperand(InvIdx)});
+  StrideR->insertBefore(WidenR);
+  return {StartR, StrideR};
+}
+
+/// Try to split a widened GEP into a lane-uniform base pointer plus a
+/// per-lane stride.
+///
+/// The GEP must have a loop-invariant pointer operand and exactly one index
+/// that varies across lanes, and that index must be recognized by
+/// matchStridedStart.
+///
+/// On success, a uniform replica of the GEP is inserted right before
+/// \p WidenGEP. In that replica the varying index is replaced by its start
+/// value, so the replica computes the address used by the first lane.
+///
+/// \returns {BasePtr, Stride, ElementTy} on success. Stride counts elements of
+/// ElementTy, the type indexed by the varying index. Returns
+/// {nullptr, nullptr, nullptr} when the GEP does not have this shape.
+static std::tuple<VPValue *, VPValue *, Type *>
+determineBaseAndStride(VPWidenGEPRecipe *WidenGEP) {
+  const std::tuple<VPValue *, VPValue *, Type *> NoMatch{nullptr, nullptr,
+                                                         nullptr};
+
+  // TODO: Check if the base pointer is strided.
+  if (!WidenGEP->isPointerLoopInvariant())
+    return NoMatch;
+
+  // Exactly one index is allowed to vary across lanes.
+  std::optional<unsigned> VaryingIdx = WidenGEP->getUniqueVariantIndex();
+  if (!VaryingIdx)
+    return NoMatch;
+
+  // Scalable, struct and vector element types are not supported.
+  Type *ElementTy = WidenGEP->getIndexedType(*VaryingIdx);
+  const bool IsUnsupportedTy = ElementTy->isScalableTy() ||
+                               ElementTy->isStructTy() ||
+                               ElementTy->isVectorTy();
+  if (IsUnsupportedTy)
+    return NoMatch;
+
+  // Recipe operand 0 is the pointer; the indices follow it.
+  const unsigned VaryingOpIdx = *VaryingIdx + 1;
+  auto [IndexStart, IndexStride] =
+      matchStridedStart(WidenGEP->getOperand(VaryingOpIdx));
+  if (!IndexStart)
+    return NoMatch;
+
+  SmallVector<VPValue *> BaseOps(WidenGEP->operands());
+  BaseOps[VaryingOpIdx] = IndexStart;
+  auto *BasePtr = new VPReplicateRecipe(WidenGEP->getUnderlyingInstr(),
+                                        BaseOps, /*IsUniform*/ true);
+  BasePtr->insertBefore(WidenGEP);
+
+  return {BasePtr, IndexStride, ElementTy};
+}
+
+void VPlanTransforms::convertToStridedAccesses(VPlan &Plan, VPCostContext &Ctx,
+                                               VFRange &Range) {
+  if (Plan.hasScalarVFOnly())
+    return;
+
+  VPTypeAnalysis TypeInfo(Plan);
+  DenseMap<VPWidenGEPRecipe *, std::tuple<VPValue *, VPValue *, Type *>>
+      StrideCache;
+  SmallVector<VPRecipeBase *> ToErase;
+  SmallPtrSet<VPValue *, 4> PossiblyDead;
+  for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
+           vp_depth_first_shallow(Plan.getVectorLoopRegion()->getEntry()))) {
+    for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
+      auto *MemR = dyn_cast<VPWidenMemoryRecipe>(&R);
+      // TODO: support strided store
+      // TODO: support reverse access
+      // TODO: transform interleave access into multiple strided accesses
+      if (!MemR || !isa<VPWidenLoadRecipe>(MemR) || MemR->isConsecutive())
+        continue;
+
+      auto *Ptr = dyn_cast<VPWidenGEPRecipe>(MemR->getAddr());
+      if (!Ptr)
+        continue;
+
+      // Memory cost model requires the pointer operand of memory access
+      // instruction.
+      Value *PtrUV = Ptr->getUnderlyingValue();
+      if (!PtrUV)
+        continue;
+
+      // Try to get base and stride here.
+      VPValue *BasePtr, *StrideInElement;
+      Type *ElementTy;
+      auto It = StrideCache.find(Ptr);
+      if (It != StrideCache.end())
+        std::tie(BasePtr, StrideInElement, ElementTy) = It->second;
+      else
+        std::tie(BasePtr, StrideInElement, ElementTy) = StrideCache[Ptr] =
+            determineBaseAndStride(Ptr);
+
+      // Skip if the memory access is not a strided access.
+      if (!BasePtr) {
+        assert(!StrideInElement && !ElementTy);
+        continue;
+      }
+      assert(StrideInElement && ElementTy);
+
+      Instruction &Ingredient = MemR->getIngredient();
+      auto IsProfitable = [&](ElementCount VF) -> bool {
+        Type *DataTy = toVectorTy(getLoadStoreType(&Ingredient), VF);
+        const Align Alignment = getLoadStoreAlignment(&Ingredient);
+        if (!Ctx.TTI.isLegalStridedLoadStore(DataTy, Alignment))
+          return false;
+        const InstructionCost CurrentCost = MemR->computeCost(VF, Ctx);
+        const InstructionCost StridedLoadStoreCost =
+            Ctx.TTI.getStridedMemoryOpCost(Instruction::Load, DataTy, PtrUV,
+                                           MemR->isMasked(), Alignment,
+                                           Ctx.CostKind, &Ingredient);
+        // FIXME: Fix the cost of gather/scatter and strided access.
+        return StridedLoadStoreCost <= CurrentCost;
+      };
+
+      if (!LoopVectorizationPlanner::getDecisionAndClampRange(IsProfitable,
+                                                              Range)) {
+        PossiblyDead.insert(BasePtr);
+        PossiblyDead.insert(StrideInElement);
+        continue;
+      }
+      PossiblyDead.insert(Ptr);
+
+      // Create a new vector pointer for strided access.
+      auto *GEP = dyn_cast<GetElementPtrInst>(PtrUV->stripPointerCasts());
+      auto *NewPtr = new VPVectorPointerRecipe(
+          BasePtr, ElementTy, StrideInElement,
+          GEP ? GEP->getNoWrapFlags() : GEPNoWrapFlags::none(),
+          Ptr->getDebugLoc());
----------------
Mel-Chen wrote:

This needs more test cases with UF > 1 (unroll factor greater than one).

https://github.com/llvm/llvm-project/pull/147297


More information about the llvm-commits mailing list