[llvm] [LV]: Teach LV to recursively (de)interleave. (PR #89018)

Hassnaa Hamdi via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 4 03:35:49 PST 2024


================
@@ -2291,23 +2329,69 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
     ArrayRef<VPValue *> VPDefs = definedValues();
     const DataLayout &DL = State.CFG.PrevBB->getDataLayout();
     if (VecTy->isScalableTy()) {
-      assert(InterleaveFactor == 2 &&
+      assert(isPowerOf2_32(InterleaveFactor) &&
              "Unsupported deinterleave factor for scalable vectors");
 
       for (unsigned Part = 0; Part < State.UF; ++Part) {
         // Scalable vectors cannot use arbitrary shufflevectors (only splats),
         // so must use intrinsics to deinterleave.
-        Value *DI = State.Builder.CreateIntrinsic(
-            Intrinsic::vector_deinterleave2, VecTy, NewLoads[Part],
-            /*FMFSource=*/nullptr, "strided.vec");
-        unsigned J = 0;
-        for (unsigned I = 0; I < InterleaveFactor; ++I) {
-          Instruction *Member = Group->getMember(I);
 
-          if (!Member)
-            continue;
+        SmallVector<Value *> DeinterleavedValues;
+        // If InterleaveFactor is > 2, we have to deinterleave recursively,
+        // because the currently available deinterleave intrinsic supports
+        // only a factor of 2. DeinterleaveCount is the number of
+        // deinterleave operations to perform: one for every non-leaf node
+        // in the deinterleave tree.
+        unsigned DeinterleaveCount = InterleaveFactor - 1;
+        std::vector<Value *> TempDeinterleavedValues;
+        TempDeinterleavedValues.push_back(NewLoads[Part]);
+        for (unsigned I = 0; I < DeinterleaveCount; ++I) {
+          auto *DiTy = TempDeinterleavedValues[I]->getType();
+          Value *DI = State.Builder.CreateIntrinsic(
+              Intrinsic::vector_deinterleave2, DiTy, TempDeinterleavedValues[I],
+              /*FMFSource=*/nullptr, "strided.vec");
+          Value *StridedVec = State.Builder.CreateExtractValue(DI, 0);
+          TempDeinterleavedValues.push_back(StridedVec);
+          StridedVec = State.Builder.CreateExtractValue(DI, 1);
+          TempDeinterleavedValues.push_back(StridedVec);
+          // Perform sorting at the start of each new level in the tree.
+          // A new level begins when the number of remaining values is a power
+          // of 2 and greater than 2. If a level has only 2 nodes, no sorting is
+          // needed as they are already in order. Number of remaining values to
+          // be processed:
+          unsigned NumRemainingValues = TempDeinterleavedValues.size() - I - 1;
+          if (NumRemainingValues > 2 && isPowerOf2_32(NumRemainingValues)) {
+            // These remaining values represent a new level in the tree.
+            // Reorder them to match the correct deinterleaving order.
+            std::vector<Value *> RemainingValues(
+                TempDeinterleavedValues.begin() + I + 1,
+                TempDeinterleavedValues.end());
+            unsigned Middle = NumRemainingValues / 2;
+            for (unsigned J = 0, K = I + 1; J < NumRemainingValues;
+                 J += 2, K++) {
----------------
hassnaaHamdi wrote:

done

https://github.com/llvm/llvm-project/pull/89018


More information about the llvm-commits mailing list