[llvm] [LV]: Teach LV to recursively (de)interleave. (PR #89018)
Hassnaa Hamdi via llvm-commits
llvm-commits at lists.llvm.org
Mon Nov 4 03:35:49 PST 2024
================
@@ -2291,23 +2329,69 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
ArrayRef<VPValue *> VPDefs = definedValues();
const DataLayout &DL = State.CFG.PrevBB->getDataLayout();
if (VecTy->isScalableTy()) {
- assert(InterleaveFactor == 2 &&
+ assert(isPowerOf2_32(InterleaveFactor) &&
"Unsupported deinterleave factor for scalable vectors");
for (unsigned Part = 0; Part < State.UF; ++Part) {
// Scalable vectors cannot use arbitrary shufflevectors (only splats),
// so must use intrinsics to deinterleave.
- Value *DI = State.Builder.CreateIntrinsic(
- Intrinsic::vector_deinterleave2, VecTy, NewLoads[Part],
- /*FMFSource=*/nullptr, "strided.vec");
- unsigned J = 0;
- for (unsigned I = 0; I < InterleaveFactor; ++I) {
- Instruction *Member = Group->getMember(I);
- if (!Member)
- continue;
+ SmallVector<Value *> DeinterleavedValues;
+ // If the InterleaveFactor is > 2, we have to deinterleave recursively,
+ // because the currently available deinterleave intrinsic supports only
+ // a factor of 2. DeinterleaveCount is the number of deinterleave
+ // operations we perform: one for each non-leaf node in the
+ // deinterleave tree.
+ unsigned DeinterleaveCount = InterleaveFactor - 1;
+ std::vector<Value *> TempDeinterleavedValues;
+ TempDeinterleavedValues.push_back(NewLoads[Part]);
+ for (unsigned I = 0; I < DeinterleaveCount; ++I) {
+ auto *DiTy = TempDeinterleavedValues[I]->getType();
+ Value *DI = State.Builder.CreateIntrinsic(
+ Intrinsic::vector_deinterleave2, DiTy, TempDeinterleavedValues[I],
+ /*FMFSource=*/nullptr, "strided.vec");
+ Value *StridedVec = State.Builder.CreateExtractValue(DI, 0);
+ TempDeinterleavedValues.push_back(StridedVec);
+ StridedVec = State.Builder.CreateExtractValue(DI, 1);
+ TempDeinterleavedValues.push_back(StridedVec);
+ // Perform sorting at the start of each new level in the tree.
+ // A new level begins when the number of remaining values is a power
+ // of 2 and greater than 2. If a level has only 2 nodes, no sorting is
+ // needed as they are already in order. Number of remaining values to
+ // be processed:
+ unsigned NumRemainingValues = TempDeinterleavedValues.size() - I - 1;
+ if (NumRemainingValues > 2 && isPowerOf2_32(NumRemainingValues)) {
+ // These remaining values represent a new level in the tree.
+ // Reorder the values to match the correct deinterleaving order.
+ std::vector<Value *> RemainingValues(
+ TempDeinterleavedValues.begin() + I + 1,
+ TempDeinterleavedValues.end());
+ unsigned Middle = NumRemainingValues / 2;
+ for (unsigned J = 0, K = I + 1; J < NumRemainingValues;
+ J += 2, K++) {
----------------
hassnaaHamdi wrote:
done
https://github.com/llvm/llvm-project/pull/89018
More information about the llvm-commits
mailing list