[llvm] [LV] Unify interleaved load handling for fixed and scalable VFs. NFC (PR #146914)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jul 3 08:37:01 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-vectorizers
Author: Mel Chen (Mel-Chen)
<details>
<summary>Changes</summary>
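This patch modifies `VPInterleaveRecipe::execute` so that interleaved loads with fixed and scalable vectorization factors (VFs) are handled by a single member-extraction loop. For scalable VFs, each group member is extracted from the result of the deinterleave intrinsic. For fixed VFs, each member is shuffled out of the wide load with a stride mask. Gaps in the group are skipped in both cases.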
---
Full diff: https://github.com/llvm/llvm-project/pull/146914.diff
1 Files Affected:
- (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+17-35)
``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 06511b61a67c3..2fc2447deb3fb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -3442,7 +3442,6 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
VPValue *BlockInMask = getMask();
VPValue *Addr = getAddr();
Value *ResAddr = State.get(Addr, VPLane(0));
- Value *PoisonVec = PoisonValue::get(VecTy);
auto CreateGroupMask = [&BlockInMask, &State,
&InterleaveFactor](Value *MaskForGaps) -> Value * {
@@ -3481,6 +3480,7 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
Instruction *NewLoad;
if (BlockInMask || MaskForGaps) {
Value *GroupMask = CreateGroupMask(MaskForGaps);
+ Value *PoisonVec = PoisonValue::get(VecTy);
NewLoad = State.Builder.CreateMaskedLoad(VecTy, ResAddr,
Group->getAlign(), GroupMask,
PoisonVec, "wide.masked.vec");
@@ -3490,57 +3490,39 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
Group->addMetadata(NewLoad);
ArrayRef<VPValue *> VPDefs = definedValues();
- const DataLayout &DL = State.CFG.PrevBB->getDataLayout();
if (VecTy->isScalableTy()) {
// Scalable vectors cannot use arbitrary shufflevectors (only splats),
// so must use intrinsics to deinterleave.
assert(InterleaveFactor <= 8 &&
"Unsupported deinterleave factor for scalable vectors");
- Value *Deinterleave = State.Builder.CreateIntrinsic(
+ NewLoad = State.Builder.CreateIntrinsic(
getDeinterleaveIntrinsicID(InterleaveFactor), NewLoad->getType(),
NewLoad,
/*FMFSource=*/nullptr, "strided.vec");
+ }
- for (unsigned I = 0, J = 0; I < InterleaveFactor; ++I) {
- Instruction *Member = Group->getMember(I);
- Value *StridedVec = State.Builder.CreateExtractValue(Deinterleave, I);
- if (!Member) {
- // This value is not needed as it's not used
- cast<Instruction>(StridedVec)->eraseFromParent();
- continue;
- }
- // If this member has different type, cast the result type.
- if (Member->getType() != ScalarTy) {
- VectorType *OtherVTy = VectorType::get(Member->getType(), State.VF);
- StridedVec =
- createBitOrPointerCast(State.Builder, StridedVec, OtherVTy, DL);
- }
-
- if (Group->isReverse())
- StridedVec = State.Builder.CreateVectorReverse(StridedVec, "reverse");
-
- State.set(VPDefs[J], StridedVec);
- ++J;
- }
+ auto CreateStridedVector = [&InterleaveFactor, &State,
+ &NewLoad](unsigned Index) -> Value * {
+ assert(Index < InterleaveFactor && "Illegal group index");
+ if (State.VF.isScalable())
+ return State.Builder.CreateExtractValue(NewLoad, Index);
- return;
- }
- assert(!State.VF.isScalable() && "VF is assumed to be non scalable.");
+ // For fixed length VF, use shuffle to extract the sub-vectors from the
+ // wide load.
+ auto StrideMask =
+ createStrideMask(Index, InterleaveFactor, State.VF.getFixedValue());
+ return State.Builder.CreateShuffleVector(NewLoad, StrideMask,
+ "strided.vec");
+ };
- // For each member in the group, shuffle out the appropriate data from the
- // wide loads.
- unsigned J = 0;
- for (unsigned I = 0; I < InterleaveFactor; ++I) {
+ for (unsigned I = 0, J = 0; I < InterleaveFactor; ++I) {
Instruction *Member = Group->getMember(I);
// Skip the gaps in the group.
if (!Member)
continue;
- auto StrideMask =
- createStrideMask(I, InterleaveFactor, State.VF.getFixedValue());
- Value *StridedVec =
- State.Builder.CreateShuffleVector(NewLoad, StrideMask, "strided.vec");
+ Value *StridedVec = CreateStridedVector(I);
// If this member has different type, cast the result type.
if (Member->getType() != ScalarTy) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/146914
More information about the llvm-commits mailing list.