[llvm] [LV] Convert gather loads with invariant stride into strided loads (PR #147297)
Mel Chen via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 9 05:36:47 PST 2025
================
@@ -5106,3 +5133,191 @@ void VPlanTransforms::addExitUsersForFirstOrderRecurrences(VPlan &Plan,
}
}
}
+
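+/// Try to match \p CurIndex as a strided index, i.e. a value whose lanes form
+/// the sequence Start, Start + Stride, Start + 2 * Stride, ... On success,
+/// inserts any recipes needed to compute Start and Stride and returns the
+/// pair {Start, Stride}; otherwise returns {nullptr, nullptr}.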
+static std::pair<VPValue *, VPValue *> matchStridedStart(VPValue *CurIndex) {
+ // TODO: Support VPWidenPointerInductionRecipe.
+ if (auto *WidenIV = dyn_cast<VPWidenIntOrFpInductionRecipe>(CurIndex))
+ return {WidenIV, WidenIV->getStepValue()};
+
+ auto *WidenR = dyn_cast<VPWidenRecipe>(CurIndex);
+ if (!WidenR || !WidenR->getUnderlyingInstr())
+ return {nullptr, nullptr};
+
+ unsigned Opcode = WidenR->getOpcode();
+ // TODO: Support Instruction::Add and Instruction::Or.
+ if (Opcode != Instruction::Shl && Opcode != Instruction::Mul)
+ return {nullptr, nullptr};
+
+ // Match the pattern binop(variant, uniform), or binop(uniform, variant) if
+ // the binary operator is commutative.
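+ // E.g. mul(iv, c) and mul(c, iv) both match for loop-invariant c, while
+ // shl(c, iv) does not, since shl is not commutative.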
+ bool IsLHSUniform = vputils::isSingleScalar(WidenR->getOperand(0));
+ if (IsLHSUniform == vputils::isSingleScalar(WidenR->getOperand(1)) ||
+ (IsLHSUniform && !Instruction::isCommutative(Opcode)))
+ return {nullptr, nullptr};
+ unsigned VarIdx = IsLHSUniform ? 1 : 0;
+
+ auto [Start, Stride] = matchStridedStart(WidenR->getOperand(VarIdx));
+ if (!Start)
+ return {nullptr, nullptr};
+
+ SmallVector<VPValue *> StartOps(WidenR->operands());
+ StartOps[VarIdx] = Start;
+ auto *StartR = new VPReplicateRecipe(WidenR->getUnderlyingInstr(), StartOps,
+ /*IsUniform*/ true);
+ StartR->insertBefore(WidenR);
+
+ unsigned InvIdx = VarIdx == 0 ? 1 : 0;
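+ // The invariant operand folds into the stride: e.g. for mul(iv, c) the
+ // stride becomes step * c, and for shl(iv, c) it becomes step << c.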
+ auto *StrideR =
+ new VPInstruction(Opcode, {Stride, WidenR->getOperand(InvIdx)});
+ StrideR->insertBefore(WidenR);
+ return {StartR, StrideR};
+}
+
+/// Checks if the given VPWidenGEPRecipe \p WidenGEP represents a strided
+/// access. If so, it creates recipes representing the base pointer and stride
+/// in element units, and returns a tuple of {base pointer, stride, element
+/// type}. Otherwise, returns a tuple where all elements are nullptr.
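+/// For example, given a loop-invariant pointer %p and the GEP
+/// getelementptr i32, ptr %p, i64 (3 * iv), where iv starts at 0 with step 1,
+/// this returns a base pointer equal to %p, a stride of 3 in i32 units, and
+/// i32 as the element type.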
+static std::tuple<VPValue *, VPValue *, Type *>
+determineBaseAndStride(VPWidenGEPRecipe *WidenGEP) {
+ // TODO: Check if the base pointer is strided.
+ if (!WidenGEP->isPointerLoopInvariant())
+ return {nullptr, nullptr, nullptr};
+
+ // Find the unique variant index.
+ std::optional<unsigned> VarIndex = WidenGEP->getUniqueVariantIndex();
+ if (!VarIndex)
+ return {nullptr, nullptr, nullptr};
+
+ Type *ElementTy = WidenGEP->getIndexedType(*VarIndex);
+ if (ElementTy->isScalableTy() || ElementTy->isStructTy() ||
+ ElementTy->isVectorTy())
+ return {nullptr, nullptr, nullptr};
+
+ unsigned VarOp = *VarIndex + 1;
+ VPValue *IndexVPV = WidenGEP->getOperand(VarOp);
+ auto [Start, Stride] = matchStridedStart(IndexVPV);
+ if (!Start)
+ return {nullptr, nullptr, nullptr};
+
+ SmallVector<VPValue *> Ops(WidenGEP->operands());
+ Ops[VarOp] = Start;
+ auto *BasePtr = new VPReplicateRecipe(WidenGEP->getUnderlyingInstr(), Ops,
+ /*IsUniform*/ true);
+ BasePtr->insertBefore(WidenGEP);
+
+ return {BasePtr, Stride, ElementTy};
+}
+
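+/// Convert widened gather loads whose address is a GEP with a loop-invariant
+/// stride into strided loads. For example, a gather of a[3 * i] becomes a
+/// strided load of a with a byte stride of 3 * sizeof(a[0]), costed as
+/// llvm.experimental.vp.strided.load, when TTI reports the strided form as
+/// legal and cheaper than the gather.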
+void VPlanTransforms::convertToStridedAccesses(VPlan &Plan, VPCostContext &Ctx,
+ VFRange &Range) {
+ if (Plan.hasScalarVFOnly())
+ return;
+
+ VPTypeAnalysis TypeInfo(Plan);
+ DenseMap<VPWidenGEPRecipe *, std::tuple<VPValue *, VPValue *, Type *>>
+ StrideCache;
+ SmallVector<VPWidenMemoryRecipe *> ToErase;
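+ // i32 version of VF for the EVL operand, created lazily on first use.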
+ VPValue *I32VF = nullptr;
+ for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
+ vp_depth_first_shallow(Plan.getVectorLoopRegion()->getEntry()))) {
+ for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
+ auto *LoadR = dyn_cast<VPWidenLoadRecipe>(&R);
+ // TODO: Support strided store.
+ // TODO: Transform reverse access into strided access with -1 stride.
+ // TODO: Transform gather/scatter with uniform address into strided access
+ // with 0 stride.
+ // TODO: Transform interleave access into multiple strided accesses.
+ if (!LoadR || LoadR->isConsecutive())
+ continue;
+
+ auto *Ptr = dyn_cast<VPWidenGEPRecipe>(LoadR->getAddr());
+ if (!Ptr)
+ continue;
+
+ Instruction &Ingredient = LoadR->getIngredient();
+ auto IsProfitable = [&](ElementCount VF) -> bool {
+ Type *DataTy = toVectorTy(getLoadStoreType(&Ingredient), VF);
+ const Align Alignment = getLoadStoreAlignment(&Ingredient);
+ if (!Ctx.TTI.isLegalStridedLoadStore(DataTy, Alignment))
+ return false;
+ const InstructionCost CurrentCost = LoadR->computeCost(VF, Ctx);
+ const InstructionCost StridedLoadStoreCost =
+ Ctx.TTI.getMemIntrinsicInstrCost(
+ MemIntrinsicCostAttributes(
+ Intrinsic::experimental_vp_strided_load, DataTy,
+ Ptr->getUnderlyingValue(), LoadR->isMasked(), Alignment,
+ &Ingredient),
+ Ctx.CostKind);
+ return StridedLoadStoreCost < CurrentCost;
+ };
+
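+ // Clamp the VF range to the subrange where the strided load is both legal
+ // and cheaper than the widened gather.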
+ if (!LoopVectorizationPlanner::getDecisionAndClampRange(IsProfitable,
+ Range))
+ continue;
+
+ // Get the base and stride, reusing the cached result if available.
+ VPValue *BasePtr, *StrideInElement;
+ Type *ElementTy;
+ auto It = StrideCache.find(Ptr);
+ if (It != StrideCache.end())
+ std::tie(BasePtr, StrideInElement, ElementTy) = It->second;
+ else
+ std::tie(BasePtr, StrideInElement, ElementTy) = StrideCache[Ptr] =
+ determineBaseAndStride(Ptr);
+
+ // Skip if the memory access is not a strided access.
+ if (!BasePtr)
+ continue;
+ assert(StrideInElement && ElementTy &&
+ "Cannot get stride information for a strided access");
+
+ // Create an i32 version of VF for the EVL operand.
+ if (!I32VF) {
+ VPBuilder Builder(Plan.getVectorPreheader());
+ I32VF = Builder.createScalarZExtOrTrunc(
+ &Plan.getVF(), Type::getInt32Ty(Plan.getContext()),
+ TypeInfo.inferScalarType(&Plan.getVF()), DebugLoc::getUnknown());
+ }
+
+ // Create a new vector pointer for strided access.
+ auto *NewPtr = new VPVectorPointerRecipe(
+ BasePtr, ElementTy, StrideInElement, Ptr->getGEPNoWrapFlags(),
+ Ptr->getDebugLoc());
+ NewPtr->insertBefore(LoadR);
+
+ const DataLayout &DL = Ingredient.getDataLayout();
+ TypeSize TS = DL.getTypeAllocSize(ElementTy);
+ unsigned TypeScale = TS.getFixedValue();
+ VPValue *StrideInBytes = StrideInElement;
+ // Scale the stride by the size of the indexed type.
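+ // E.g. an i32 access with an element stride of 3 becomes a byte stride of
+ // 12; a scale of 1 (i8) needs no multiply.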
+ if (TypeScale != 1) {
+ VPValue *ScaleVPV = Plan.getConstantInt(
+ TypeInfo.inferScalarType(StrideInElement), TypeScale);
+ auto *ScaledStride =
+ new VPInstruction(Instruction::Mul, {StrideInElement, ScaleVPV});
----------------
Mel-Chen wrote:
Sure
2a0fe35984703b75388c038320f8cbfc8f7c61e8
https://github.com/llvm/llvm-project/pull/147297
More information about the llvm-commits mailing list