[llvm] [LV]: Teach LV to recursively (de)interleave. (PR #89018)

Hassnaa Hamdi via llvm-commits llvm-commits at lists.llvm.org
Sat Nov 23 22:28:29 PST 2024


================
@@ -2910,22 +2919,48 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
     ArrayRef<VPValue *> VPDefs = definedValues();
     const DataLayout &DL = State.CFG.PrevBB->getDataLayout();
     if (VecTy->isScalableTy()) {
-      assert(InterleaveFactor == 2 &&
+      assert(isPowerOf2_32(InterleaveFactor) &&
              "Unsupported deinterleave factor for scalable vectors");
 
-        // Scalable vectors cannot use arbitrary shufflevectors (only splats),
-        // so must use intrinsics to deinterleave.
-      Value *DI = State.Builder.CreateIntrinsic(
-          Intrinsic::vector_deinterleave2, VecTy, NewLoad,
-          /*FMFSource=*/nullptr, "strided.vec");
-      unsigned J = 0;
-      for (unsigned I = 0; I < InterleaveFactor; ++I) {
-        Instruction *Member = Group->getMember(I);
+      // Scalable vectors cannot use arbitrary shufflevectors (only splats),
+      // so must use intrinsics to deinterleave.
+
+      SmallVector<Value *> DeinterleavedValues(InterleaveFactor);
+      DeinterleavedValues[0] = NewLoad;
+      // For InterleaveFactor > 2 we have to deinterleave recursively, because
+      // the currently available deinterleave intrinsic supports only a factor
+      // of 2. For InterleaveFactor == 2 the loop below exits after its first
+      // iteration.
+      // Each iteration doubles the number of deinterleaved values until it
+      // reaches InterleaveFactor.
+      for (int NumVectors = 1; NumVectors < InterleaveFactor; NumVectors *= 2) {
+        // deinterleave the elements within the vector
+        std::vector<Value *> TempDeinterleavedValues(NumVectors);
+        for (int I = 0; I < NumVectors; ++I) {
+          auto *DiTy = DeinterleavedValues[I]->getType();
+          TempDeinterleavedValues[I] = State.Builder.CreateIntrinsic(
+              Intrinsic::vector_deinterleave2, DiTy, DeinterleavedValues[I],
+              /*FMFSource=*/nullptr, "strided.vec");
+        }
+        // Extract the deinterleaved values:
+        for (int I = 0; I < 2; ++I)
+          for (int J = 0; J < NumVectors; ++J)
+            DeinterleavedValues[NumVectors * I + J] =
+                State.Builder.CreateExtractValue(TempDeinterleavedValues[J], I);
+      }
 
-        if (!Member)
+#ifndef NDEBUG
+      for (Value *Val : DeinterleavedValues)
+        assert(Val && "NULL Deinterleaved Value");
+#endif
+      for (unsigned I = 0, J = 0; I < InterleaveFactor; ++I) {
----------------
hassnaaHamdi wrote:

It's used at the end of the loop to update the final state (i.e., the interleaved values that are actually used).

https://github.com/llvm/llvm-project/pull/89018


More information about the llvm-commits mailing list