[llvm] [LV]: Teach LV to recursively (de)interleave. (PR #89018)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 17 08:22:14 PDT 2024


================
@@ -2291,23 +2329,69 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
     ArrayRef<VPValue *> VPDefs = definedValues();
     const DataLayout &DL = State.CFG.PrevBB->getDataLayout();
     if (VecTy->isScalableTy()) {
-      assert(InterleaveFactor == 2 &&
+      assert(isPowerOf2_32(InterleaveFactor) &&
              "Unsupported deinterleave factor for scalable vectors");
 
       for (unsigned Part = 0; Part < State.UF; ++Part) {
         // Scalable vectors cannot use arbitrary shufflevectors (only splats),
         // so must use intrinsics to deinterleave.
-        Value *DI = State.Builder.CreateIntrinsic(
-            Intrinsic::vector_deinterleave2, VecTy, NewLoads[Part],
-            /*FMFSource=*/nullptr, "strided.vec");
-        unsigned J = 0;
-        for (unsigned I = 0; I < InterleaveFactor; ++I) {
-          Instruction *Member = Group->getMember(I);
 
-          if (!Member)
-            continue;
+        SmallVector<Value *> DeinterleavedValues;
+        // If the InterleaveFactor is > 2, so we will have to do recursive
+        // deinterleaving, because the current available deinterleave intrinsic
+        // supports only Factor of 2. DeinterleaveCount represent how many times
+        // we will do deinterleaving, we will do deinterleave on all nonleaf
+        // nodes in the deinterleave tree.
+        unsigned DeinterleaveCount = InterleaveFactor - 1;
+        std::vector<Value *> TempDeinterleavedValues;
+        TempDeinterleavedValues.push_back(NewLoads[Part]);
+        for (unsigned I = 0; I < DeinterleaveCount; ++I) {
+          auto *DiTy = TempDeinterleavedValues[I]->getType();
+          Value *DI = State.Builder.CreateIntrinsic(
+              Intrinsic::vector_deinterleave2, DiTy, TempDeinterleavedValues[I],
+              /*FMFSource=*/nullptr, "strided.vec");
+          Value *StridedVec = State.Builder.CreateExtractValue(DI, 0);
+          TempDeinterleavedValues.push_back(StridedVec);
+          StridedVec = State.Builder.CreateExtractValue(DI, 1);
+          TempDeinterleavedValues.push_back(StridedVec);
+          // Perform sorting at the start of each new level in the tree.
+          // A new level begins when the number of remaining values is a power
+          // of 2 and greater than 2. If a level has only 2 nodes, no sorting is
+          // needed as they are already in order. Number of remaining values to
+          // be processed:
+          unsigned NumRemainingValues = TempDeinterleavedValues.size() - I - 1;
+          if (NumRemainingValues > 2 && isPowerOf2_32(NumRemainingValues)) {
+            // these remaining values represent a new level in the tree,
+            // Reorder the values to match the correct deinterleaving order.
+            std::vector<Value *> RemainingValues(
+                TempDeinterleavedValues.begin() + I + 1,
+                TempDeinterleavedValues.end());
+            unsigned Middle = NumRemainingValues / 2;
+            for (unsigned J = 0, K = I + 1; J < NumRemainingValues;
+                 J += 2, K++) {
----------------
paulwalker-arm wrote:

Similar to my interleave comment, I think you've backed yourself into a corner by not nesting where you create the `vector_deinterleave2` call. I knocked together the following, where splitting the deinterleave into two phases (elements, then vectors) also looks to simplify the implementation (assuming I've got my logic correct).

```
if (VecTy->isScalableTy()) {
  assert(Factor_is_a_power_of_2)
  
  vector<Value *> DEs(Vals);
  DEs.resize(Factor);

  for (int NumVectors = 1; NumVectors < Factor; NumVectors *= 2) {
    // deinterleave the elements within each vector
    Value *R[NumVectors]
    for (int i = 0; i < NumVectors; ++i)
      R[i] = builder.create_vector_deinterleave2(DEs[i]);
    
    // deinterleave the vectors themselves
    for (int i = 0; i < 2; ++i)
      for (int j = 0; j < NumVectors; ++j)
        DEs[NumVectors * i + j].push_back(builder.create_extract(R[j], i));
  }

  // DEs is ready to use
}
```

https://github.com/llvm/llvm-project/pull/89018


More information about the llvm-commits mailing list