[llvm] Reland: [LV]: Teach LV to recursively (de)interleave. (PR #122989)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 15 10:49:09 PST 2025


================
@@ -2997,22 +3004,48 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
     ArrayRef<VPValue *> VPDefs = definedValues();
     const DataLayout &DL = State.CFG.PrevBB->getDataLayout();
     if (VecTy->isScalableTy()) {
-      assert(InterleaveFactor == 2 &&
+      assert(isPowerOf2_32(InterleaveFactor) &&
              "Unsupported deinterleave factor for scalable vectors");
 
-        // Scalable vectors cannot use arbitrary shufflevectors (only splats),
-        // so must use intrinsics to deinterleave.
-      Value *DI = State.Builder.CreateIntrinsic(
-          Intrinsic::vector_deinterleave2, VecTy, NewLoad,
-          /*FMFSource=*/nullptr, "strided.vec");
-      unsigned J = 0;
-      for (unsigned I = 0; I < InterleaveFactor; ++I) {
-        Instruction *Member = Group->getMember(I);
+      // Scalable vectors cannot use arbitrary shufflevectors (only splats),
+      // so must use intrinsics to deinterleave.
+      SmallVector<Value *> DeinterleavedValues(InterleaveFactor);
+      DeinterleavedValues[0] = NewLoad;
+      // For the case of InterleaveFactor > 2, we will have to do recursive
+      // deinterleaving, because the current available deinterleave intrinsic
+      // supports only Factor of 2, otherwise it will bailout after first
+      // iteration.
+      // When deinterleaving, the number of values will double until we
+      // have "InterleaveFactor".
+      for (unsigned NumVectors = 1; NumVectors < InterleaveFactor;
+           NumVectors *= 2) {
+        // Deinterleave the elements within the vector
+        SmallVector<Value *> TempDeinterleavedValues(NumVectors);
+        for (unsigned I = 0; I < NumVectors; ++I) {
+          auto *DiTy = DeinterleavedValues[I]->getType();
+          TempDeinterleavedValues[I] = State.Builder.CreateIntrinsic(
+              Intrinsic::vector_deinterleave2, DiTy, DeinterleavedValues[I],
+              /*FMFSource=*/nullptr, "strided.vec");
+        }
+        // Extract the deinterleaved values:
+        for (unsigned I = 0; I < 2; ++I)
+          for (unsigned J = 0; J < NumVectors; ++J)
+            DeinterleavedValues[NumVectors * I + J] =
+                State.Builder.CreateExtractValue(TempDeinterleavedValues[J], I);
+      }
 
-        if (!Member)
+#ifndef NDEBUG
+      for (Value *Val : DeinterleavedValues)
+        assert(Val && "NULL Deinterleaved Value");
+#endif
+      for (unsigned I = 0, J = 0; I < InterleaveFactor; ++I) {
+        Instruction *Member = Group->getMember(I);
+        Value *StridedVec = DeinterleavedValues[I];
+        if (!Member) {
+          // This value is not needed as it's not used
+          static_cast<Instruction *>(StridedVec)->eraseFromParent();
----------------
fhahn wrote:

This seems odd; shouldn't this use LLVM's `cast` instead, i.e. `cast<Instruction>`?

https://github.com/llvm/llvm-project/pull/122989


More information about the llvm-commits mailing list