[llvm] [LV]: Teach LV to recursively (de)interleave. (PR #89018)

Hassnaa Hamdi via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 31 12:54:21 PDT 2024


https://github.com/hassnaaHamdi updated https://github.com/llvm/llvm-project/pull/89018

>From ef3a8ea8cc54f5ec34b5ed807c6f703ca6589e6a Mon Sep 17 00:00:00 2001
From: Hassnaa Hamdi <hassnaa.hamdi at arm.com>
Date: Thu, 22 Aug 2024 04:30:05 +0000
Subject: [PATCH 1/3] [LV]: Teach LV to recursively (de)interleave.

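Currently, the only available (de)interleave intrinsics handle an interleave factor of 2 (ld2/st2), so interleave factors greater than 2 are not supported.
This patch teaches the LV to apply these factor-2 intrinsics recursively in order to support higher (power-of-2) interleave factors.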

Change-Id: I96af28dc6aeca0c6929d604176cc9ba29fca17df
---
 .../AArch64/AArch64TargetTransformInfo.cpp    |   6 +-
 .../Transforms/Vectorize/LoopVectorize.cpp    |   6 +-
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 109 ++++++++++---
 .../AArch64/sve-deinterleave4.ll              | 144 +++++++++++++++++-
 .../AArch64/sve-interleaved-accesses.ll       |   6 +-
 5 files changed, 241 insertions(+), 30 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index e85fd73996dd1a..880eccd358b01b 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3558,7 +3558,9 @@ InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
   assert(Factor >= 2 && "Invalid interleave factor");
   auto *VecVTy = cast<VectorType>(VecTy);
 
-  if (VecTy->isScalableTy() && (!ST->hasSVE() || Factor != 2))
+  unsigned MaxFactor = TLI->getMaxSupportedInterleaveFactor();
+  if (VecTy->isScalableTy() &&
+      (!ST->hasSVE() || !isPowerOf2_32(Factor) || Factor > MaxFactor))
     return InstructionCost::getInvalid();
 
   // Vectorization for masked interleaved accesses is only enabled for scalable
@@ -3566,7 +3568,7 @@ InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
   if (!VecTy->isScalableTy() && (UseMaskForCond || UseMaskForGaps))
     return InstructionCost::getInvalid();
 
-  if (!UseMaskForGaps && Factor <= TLI->getMaxSupportedInterleaveFactor()) {
+  if (!UseMaskForGaps && Factor <= MaxFactor) {
     unsigned MinElts = VecVTy->getElementCount().getKnownMinValue();
     auto *SubVecTy =
         VectorType::get(VecVTy->getElementType(),
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index f5337b11edc977..c5067962bc8cd7 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8689,9 +8689,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
                      CM.getWideningDecision(IG->getInsertPos(), VF) ==
                          LoopVectorizationCostModel::CM_Interleave);
       // For scalable vectors, the only interleave factor currently supported
-      // is 2 since we require the (de)interleave2 intrinsics instead of
-      // shufflevectors.
-      assert((!Result || !VF.isScalable() || IG->getFactor() == 2) &&
+      // must be power of 2 since we require the (de)interleave2 intrinsics
+      // instead of shufflevectors.
+      assert((!Result || !VF.isScalable() || isPowerOf2_32(IG->getFactor())) &&
              "Unsupported interleave factor for scalable vectors");
       return Result;
     };
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 911b2fe9e9a1eb..62e5cb4c00c7be 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -31,6 +31,7 @@
 #include "llvm/Transforms/Utils/LoopUtils.h"
 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
 #include <cassert>
+#include <queue>
 
 using namespace llvm;
 
@@ -2126,10 +2127,39 @@ static Value *interleaveVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vals,
   // Scalable vectors cannot use arbitrary shufflevectors (only splats), so
   // must use intrinsics to interleave.
   if (VecTy->isScalableTy()) {
-    VectorType *WideVecTy = VectorType::getDoubleElementsVectorType(VecTy);
-    return Builder.CreateIntrinsic(WideVecTy, Intrinsic::vector_interleave2,
-                                   Vals,
-                                   /*FMFSource=*/nullptr, Name);
+    unsigned InterleaveFactor = Vals.size();
+    SmallVector<Value *> InterleavingValues;
+    unsigned InterleavingValuesCount =
+        InterleaveFactor + (InterleaveFactor - 2);
+    InterleavingValues.resize(InterleaveFactor);
+    // Place the values to be interleaved in the correct order for the
+    // interleaving
+    for (unsigned I = 0, J = InterleaveFactor / 2, K = 0; K < InterleaveFactor;
+         K++) {
+      if (K % 2 == 0) {
+        InterleavingValues[K] = Vals[I];
+        I++;
+      } else {
+        InterleavingValues[K] = Vals[J];
+        J++;
+      }
+    }
+#ifndef NDEBUG
+    for (Value *Val : InterleavingValues)
+      assert(Val && "NULL Interleaving Value");
+#endif
+    for (unsigned I = 1; I < InterleavingValuesCount; I += 2) {
+      VectorType *InterleaveTy =
+          cast<VectorType>(InterleavingValues[I]->getType());
+      VectorType *WideVecTy =
+          VectorType::getDoubleElementsVectorType(InterleaveTy);
+      auto *InterleaveRes = Builder.CreateIntrinsic(
+          WideVecTy, Intrinsic::vector_interleave2,
+          {InterleavingValues[I - 1], InterleavingValues[I]},
+          /*FMFSource=*/nullptr, Name);
+      InterleavingValues.push_back(InterleaveRes);
+    }
+    return InterleavingValues[InterleavingValuesCount];
   }
 
   // Fixed length. Start by concatenating all vectors into a wide vector.
@@ -2238,15 +2268,12 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
                              unsigned Part, Value *MaskForGaps) -> Value * {
     if (State.VF.isScalable()) {
       assert(!MaskForGaps && "Interleaved groups with gaps are not supported.");
-      assert(InterleaveFactor == 2 &&
+      assert(isPowerOf2_32(InterleaveFactor) &&
              "Unsupported deinterleave factor for scalable vectors");
       auto *BlockInMaskPart = State.get(BlockInMask, Part);
-      SmallVector<Value *, 2> Ops = {BlockInMaskPart, BlockInMaskPart};
-      auto *MaskTy = VectorType::get(State.Builder.getInt1Ty(),
-                                     State.VF.getKnownMinValue() * 2, true);
-      return State.Builder.CreateIntrinsic(
-          MaskTy, Intrinsic::vector_interleave2, Ops,
-          /*FMFSource=*/nullptr, "interleaved.mask");
+      SmallVector<Value *> Ops;
+      Ops.resize(InterleaveFactor, BlockInMaskPart);
+      return interleaveVectors(State.Builder, Ops, "interleaved.mask");
     }
 
     if (!BlockInMask)
@@ -2291,23 +2318,63 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
     ArrayRef<VPValue *> VPDefs = definedValues();
     const DataLayout &DL = State.CFG.PrevBB->getDataLayout();
     if (VecTy->isScalableTy()) {
-      assert(InterleaveFactor == 2 &&
+      assert(isPowerOf2_32(InterleaveFactor) &&
              "Unsupported deinterleave factor for scalable vectors");
 
       for (unsigned Part = 0; Part < State.UF; ++Part) {
         // Scalable vectors cannot use arbitrary shufflevectors (only splats),
         // so must use intrinsics to deinterleave.
-        Value *DI = State.Builder.CreateIntrinsic(
-            Intrinsic::vector_deinterleave2, VecTy, NewLoads[Part],
-            /*FMFSource=*/nullptr, "strided.vec");
-        unsigned J = 0;
-        for (unsigned I = 0; I < InterleaveFactor; ++I) {
-          Instruction *Member = Group->getMember(I);
 
-          if (!Member)
-            continue;
+        SmallVector<Value *> DeinterleavedValues;
+        // If the InterleaveFactor is > 2, so we will have to do recursive
+        // deinterleaving, because the current available deinterleave intrinsice
+        // supports only Factor of 2. DeinterleaveCount represent how many times
+        // we will do deinterleaving, we will do deinterleave on all nonleaf
+        // nodes in the deinterleave tree.
+        unsigned DeinterleaveCount = InterleaveFactor - 1;
+        std::queue<Value *> TempDeinterleavedValues;
+        TempDeinterleavedValues.push(NewLoads[Part]);
+        for (unsigned I = 0; I < DeinterleaveCount; ++I) {
+          Value *ValueToDeinterleave = TempDeinterleavedValues.front();
+          auto *DiTy = ValueToDeinterleave->getType();
+          TempDeinterleavedValues.pop();
+          Value *DI = State.Builder.CreateIntrinsic(
+              Intrinsic::vector_deinterleave2, DiTy, ValueToDeinterleave,
+              /*FMFSource=*/nullptr, "strided.vec");
+          Value *StridedVec = State.Builder.CreateExtractValue(DI, 0);
+          TempDeinterleavedValues.push(StridedVec);
+          StridedVec = State.Builder.CreateExtractValue(DI, 1);
+          TempDeinterleavedValues.push(StridedVec);
+        }
 
-          Value *StridedVec = State.Builder.CreateExtractValue(DI, I);
+        assert(TempDeinterleavedValues.size() == InterleaveFactor &&
+               "Num of deinterleaved values must equals to InterleaveFactor");
+        // Sort deinterleaved values
+        DeinterleavedValues.resize(InterleaveFactor);
+        for (unsigned I = 0, J = InterleaveFactor / 2, K = 0;
+             K < InterleaveFactor; K++) {
+          auto *DeinterleavedValue = TempDeinterleavedValues.front();
+          TempDeinterleavedValues.pop();
+          if (K % 2 == 0) {
+            DeinterleavedValues[I] = DeinterleavedValue;
+            I++;
+          } else {
+            DeinterleavedValues[J] = DeinterleavedValue;
+            J++;
+          }
+        }
+#ifndef NDEBUG
+        for (Value *Val : DeinterleavedValues)
+          assert(Val && "NULL Deinterleaved Value");
+#endif
+        for (unsigned I = 0, J = 0; I < InterleaveFactor; ++I) {
+          Instruction *Member = Group->getMember(I);
+          Value *StridedVec = DeinterleavedValues[I];
+          if (!Member) {
+            // This value is not needed as it's not used
+            static_cast<Instruction *>(StridedVec)->eraseFromParent();
+            continue;
+          }
           // If this member has different type, cast the result type.
           if (Member->getType() != ScalarTy) {
             VectorType *OtherVTy = VectorType::get(Member->getType(), State.VF);
diff --git a/llvm/test/Transforms/InterleavedAccess/AArch64/sve-deinterleave4.ll b/llvm/test/Transforms/InterleavedAccess/AArch64/sve-deinterleave4.ll
index 06ecff67298813..f5f6c1878c82ef 100644
--- a/llvm/test/Transforms/InterleavedAccess/AArch64/sve-deinterleave4.ll
+++ b/llvm/test/Transforms/InterleavedAccess/AArch64/sve-deinterleave4.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
-; RUN: opt < %s -passes=interleaved-access -mtriple=aarch64-linux-gnu -mattr=+sve -S | FileCheck %s
+; RUN: opt < %s -passes=loop-vectorize,interleaved-access -mtriple=aarch64-linux-gnu -mattr=+sve -S | FileCheck %s
 
 
 define void @deinterleave4(ptr %src) {
@@ -136,3 +136,145 @@ define void @negative_deinterleave4_test(ptr %src) {
 
   ret void
 }
+
+%struct.xyzt = type { i32, i32, i32, i32 }
+
+define void @interleave_deinterleave(ptr writeonly %dst, ptr readonly %a, ptr readonly %b) {
+; CHECK-LABEL: define void @interleave_deinterleave
+; CHECK-SAME: (ptr writeonly [[DST:%.*]], ptr readonly [[A:%.*]], ptr readonly [[B:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 4
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.umax.i64(i64 8, i64 [[TMP1]])
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1024, [[TMP2]]
+; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_MEMCHECK:%.*]]
+; CHECK:       vector.memcheck:
+; CHECK-NEXT:    [[SCEVGEP:%.*]] = getelementptr i8, ptr [[DST]], i64 16384
+; CHECK-NEXT:    [[SCEVGEP1:%.*]] = getelementptr i8, ptr [[A]], i64 16384
+; CHECK-NEXT:    [[SCEVGEP2:%.*]] = getelementptr i8, ptr [[B]], i64 16384
+; CHECK-NEXT:    [[BOUND0:%.*]] = icmp ult ptr [[DST]], [[SCEVGEP1]]
+; CHECK-NEXT:    [[BOUND1:%.*]] = icmp ult ptr [[A]], [[SCEVGEP]]
+; CHECK-NEXT:    [[FOUND_CONFLICT:%.*]] = and i1 [[BOUND0]], [[BOUND1]]
+; CHECK-NEXT:    [[BOUND03:%.*]] = icmp ult ptr [[DST]], [[SCEVGEP2]]
+; CHECK-NEXT:    [[BOUND14:%.*]] = icmp ult ptr [[B]], [[SCEVGEP]]
+; CHECK-NEXT:    [[FOUND_CONFLICT5:%.*]] = and i1 [[BOUND03]], [[BOUND14]]
+; CHECK-NEXT:    [[CONFLICT_RDX:%.*]] = or i1 [[FOUND_CONFLICT]], [[FOUND_CONFLICT5]]
+; CHECK-NEXT:    br i1 [[CONFLICT_RDX]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP3]], 4
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 1024, [[TMP4]]
+; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 1024, [[N_MOD_VF]]
+; CHECK-NEXT:    [[TMP5:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP6:%.*]] = mul i64 [[TMP5]], 4
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP7:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds [[STRUCT_XYZT:%.*]], ptr [[A]], i64 [[TMP7]]
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds i32, ptr [[TMP8]], i32 0
+; CHECK-NEXT:    [[LDN:%.*]] = call { <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.aarch64.sve.ld4.sret.nxv4i32(<vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer), ptr [[TMP9]])
+; CHECK-NEXT:    [[TMP10:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32> } [[LDN]], 0
+; CHECK-NEXT:    [[TMP11:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32> } [[LDN]], 1
+; CHECK-NEXT:    [[TMP12:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32> } [[LDN]], 2
+; CHECK-NEXT:    [[TMP13:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32> } [[LDN]], 3
+; CHECK-NEXT:    [[TMP14:%.*]] = getelementptr inbounds [[STRUCT_XYZT]], ptr [[B]], i64 [[TMP7]]
+; CHECK-NEXT:    [[TMP15:%.*]] = getelementptr inbounds i32, ptr [[TMP14]], i32 0
+; CHECK-NEXT:    [[LDN14:%.*]] = call { <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.aarch64.sve.ld4.sret.nxv4i32(<vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer), ptr [[TMP15]])
+; CHECK-NEXT:    [[TMP16:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32> } [[LDN14]], 0
+; CHECK-NEXT:    [[TMP17:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32> } [[LDN14]], 1
+; CHECK-NEXT:    [[TMP18:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32> } [[LDN14]], 2
+; CHECK-NEXT:    [[TMP19:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32> } [[LDN14]], 3
+; CHECK-NEXT:    [[TMP20:%.*]] = add nsw <vscale x 4 x i32> [[TMP16]], [[TMP10]]
+; CHECK-NEXT:    [[TMP21:%.*]] = getelementptr inbounds [[STRUCT_XYZT]], ptr [[DST]], i64 [[TMP7]]
+; CHECK-NEXT:    [[TMP22:%.*]] = sub nsw <vscale x 4 x i32> [[TMP11]], [[TMP17]]
+; CHECK-NEXT:    [[TMP23:%.*]] = shl <vscale x 4 x i32> [[TMP12]], [[TMP18]]
+; CHECK-NEXT:    [[TMP24:%.*]] = ashr <vscale x 4 x i32> [[TMP13]], [[TMP19]]
+; CHECK-NEXT:    [[TMP25:%.*]] = getelementptr inbounds i8, ptr [[TMP21]], i64 12
+; CHECK-NEXT:    [[TMP26:%.*]] = getelementptr inbounds i32, ptr [[TMP25]], i32 -3
+; CHECK-NEXT:    call void @llvm.aarch64.sve.st4.nxv4i32(<vscale x 4 x i32> [[TMP20]], <vscale x 4 x i32> [[TMP22]], <vscale x 4 x i32> [[TMP23]], <vscale x 4 x i32> [[TMP24]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer), ptr [[TMP26]])
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP6]]
+; CHECK-NEXT:    [[TMP27:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP27]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 1024, [[N_VEC]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
+; CHECK:       scalar.ph:
+; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ], [ 0, [[VECTOR_MEMCHECK]] ]
+; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
+; CHECK:       for.body:
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ]
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds [[STRUCT_XYZT]], ptr [[A]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    [[TMP28:%.*]] = load i32, ptr [[ARRAYIDX]], align 4
+; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr inbounds [[STRUCT_XYZT]], ptr [[B]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    [[TMP29:%.*]] = load i32, ptr [[ARRAYIDX2]], align 4
+; CHECK-NEXT:    [[ADD:%.*]] = add nsw i32 [[TMP29]], [[TMP28]]
+; CHECK-NEXT:    [[ARRAYIDX5:%.*]] = getelementptr inbounds [[STRUCT_XYZT]], ptr [[DST]], i64 [[INDVARS_IV]]
+; CHECK-NEXT:    store i32 [[ADD]], ptr [[ARRAYIDX5]], align 4
+; CHECK-NEXT:    [[Y:%.*]] = getelementptr inbounds nuw i8, ptr [[ARRAYIDX]], i64 4
+; CHECK-NEXT:    [[TMP30:%.*]] = load i32, ptr [[Y]], align 4
+; CHECK-NEXT:    [[Y11:%.*]] = getelementptr inbounds nuw i8, ptr [[ARRAYIDX2]], i64 4
+; CHECK-NEXT:    [[TMP31:%.*]] = load i32, ptr [[Y11]], align 4
+; CHECK-NEXT:    [[SUB:%.*]] = sub nsw i32 [[TMP30]], [[TMP31]]
+; CHECK-NEXT:    [[Y14:%.*]] = getelementptr inbounds nuw i8, ptr [[ARRAYIDX5]], i64 4
+; CHECK-NEXT:    store i32 [[SUB]], ptr [[Y14]], align 4
+; CHECK-NEXT:    [[Z:%.*]] = getelementptr inbounds nuw i8, ptr [[ARRAYIDX]], i64 8
+; CHECK-NEXT:    [[TMP32:%.*]] = load i32, ptr [[Z]], align 4
+; CHECK-NEXT:    [[Z19:%.*]] = getelementptr inbounds nuw i8, ptr [[ARRAYIDX2]], i64 8
+; CHECK-NEXT:    [[TMP33:%.*]] = load i32, ptr [[Z19]], align 4
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[TMP32]], [[TMP33]]
+; CHECK-NEXT:    [[Z22:%.*]] = getelementptr inbounds nuw i8, ptr [[ARRAYIDX5]], i64 8
+; CHECK-NEXT:    store i32 [[SHL]], ptr [[Z22]], align 4
+; CHECK-NEXT:    [[T:%.*]] = getelementptr inbounds nuw i8, ptr [[ARRAYIDX]], i64 12
+; CHECK-NEXT:    [[TMP34:%.*]] = load i32, ptr [[T]], align 4
+; CHECK-NEXT:    [[T27:%.*]] = getelementptr inbounds nuw i8, ptr [[ARRAYIDX2]], i64 12
+; CHECK-NEXT:    [[TMP35:%.*]] = load i32, ptr [[T27]], align 4
+; CHECK-NEXT:    [[SHR:%.*]] = ashr i32 [[TMP34]], [[TMP35]]
+; CHECK-NEXT:    [[T30:%.*]] = getelementptr inbounds nuw i8, ptr [[ARRAYIDX5]], i64 12
+; CHECK-NEXT:    store i32 [[SHR]], ptr [[T30]], align 4
+; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 1
+; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV_NEXT]], 1024
+; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP]], label [[FOR_BODY]], !llvm.loop [[LOOP3:![0-9]+]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %for.body
+
+for.body:
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %arrayidx = getelementptr inbounds %struct.xyzt, ptr %a, i64 %indvars.iv
+  %0 = load i32, ptr %arrayidx, align 4
+  %arrayidx2 = getelementptr inbounds %struct.xyzt, ptr %b, i64 %indvars.iv
+  %1 = load i32, ptr %arrayidx2, align 4
+  %add = add nsw i32 %1, %0
+  %arrayidx5 = getelementptr inbounds %struct.xyzt, ptr %dst, i64 %indvars.iv
+  store i32 %add, ptr %arrayidx5, align 4
+  %y = getelementptr inbounds nuw i8, ptr %arrayidx, i64 4
+  %2 = load i32, ptr %y, align 4
+  %y11 = getelementptr inbounds nuw i8, ptr %arrayidx2, i64 4
+  %3 = load i32, ptr %y11, align 4
+  %sub = sub nsw i32 %2, %3
+  %y14 = getelementptr inbounds nuw i8, ptr %arrayidx5, i64 4
+  store i32 %sub, ptr %y14, align 4
+  %z = getelementptr inbounds nuw i8, ptr %arrayidx, i64 8
+  %4 = load i32, ptr %z, align 4
+  %z19 = getelementptr inbounds nuw i8, ptr %arrayidx2, i64 8
+  %5 = load i32, ptr %z19, align 4
+  %shl = shl i32 %4, %5
+  %z22 = getelementptr inbounds nuw i8, ptr %arrayidx5, i64 8
+  store i32 %shl, ptr %z22, align 4
+  %t = getelementptr inbounds nuw i8, ptr %arrayidx, i64 12
+  %6 = load i32, ptr %t, align 4
+  %t27 = getelementptr inbounds nuw i8, ptr %arrayidx2, i64 12
+  %7 = load i32, ptr %t27, align 4
+  %shr = ashr i32 %6, %7
+  %t30 = getelementptr inbounds nuw i8, ptr %arrayidx5, i64 12
+  store i32 %shr, ptr %t30, align 4
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 1024
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:
+  ret void
+}
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll
index d6794420c403f9..61bcdbc630eb82 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll
@@ -403,8 +403,8 @@ define void @test_reversed_load2_store2(ptr noalias nocapture readonly %A, ptr n
 ; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <vscale x 8 x i32>, ptr [[TMP9]], align 4
 ; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = call { <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.vector.deinterleave2.nxv8i32(<vscale x 8 x i32> [[WIDE_VEC]])
 ; CHECK-NEXT:    [[TMP10:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32> } [[STRIDED_VEC]], 0
-; CHECK-NEXT:    [[REVERSE:%.*]] = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> [[TMP10]])
 ; CHECK-NEXT:    [[TMP11:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32> } [[STRIDED_VEC]], 1
+; CHECK-NEXT:    [[REVERSE:%.*]] = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> [[TMP10]])
 ; CHECK-NEXT:    [[REVERSE1:%.*]] = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> [[TMP11]])
 ; CHECK-NEXT:    [[TMP12:%.*]] = add nsw <vscale x 4 x i32> [[REVERSE]], [[VEC_IND]]
 ; CHECK-NEXT:    [[TMP13:%.*]] = sub nsw <vscale x 4 x i32> [[REVERSE1]], [[VEC_IND]]
@@ -1521,10 +1521,10 @@ define void @PR34743(ptr %a, ptr %b, i64 %n) #1 {
 ; CHECK-NEXT:    [[SCALAR_RECUR_INIT:%.*]] = phi i16 [ [[VECTOR_RECUR_EXTRACT]], [[MIDDLE_BLOCK]] ], [ [[DOTPRE]], [[ENTRY]] ], [ [[DOTPRE]], [[VECTOR_MEMCHECK]] ]
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
-; CHECK-NEXT:    [[SCALAR_RECUR:%.*]] = phi i16 [ [[SCALAR_RECUR_INIT]], [[SCALAR_PH]] ], [ [[LOAD2:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[TMP33:%.*]] = phi i16 [ [[SCALAR_RECUR_INIT]], [[SCALAR_PH]] ], [ [[LOAD2:%.*]], [[LOOP]] ]
 ; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[IV2:%.*]], [[LOOP]] ]
 ; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[BC_RESUME_VAL3]], [[SCALAR_PH]] ], [ [[I1:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[CONV:%.*]] = sext i16 [[SCALAR_RECUR]] to i32
+; CHECK-NEXT:    [[CONV:%.*]] = sext i16 [[TMP33]] to i32
 ; CHECK-NEXT:    [[I1]] = add nuw nsw i64 [[I]], 1
 ; CHECK-NEXT:    [[IV1:%.*]] = or disjoint i64 [[IV]], 1
 ; CHECK-NEXT:    [[IV2]] = add nuw nsw i64 [[IV]], 2

>From a1c5378a5a017b9eec26c49eed98a825ddc93351 Mon Sep 17 00:00:00 2001
From: Hassnaa Hamdi <hassnaa.hamdi at arm.com>
Date: Tue, 8 Oct 2024 00:08:35 +0000
Subject: [PATCH 2/3] [resolve review comments] reorder values at each level of
 the (de)interleave tree so they are in the correct order for (de)interleaving.

Change-Id: I151a53c459a7f69e35feb428c1dface2fe57e9ce
---
 .../AArch64/AArch64TargetTransformInfo.cpp    |   6 +-
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 109 ++++++++++--------
 2 files changed, 65 insertions(+), 50 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 880eccd358b01b..0d4ab1ffe2f067 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3558,9 +3558,7 @@ InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
   assert(Factor >= 2 && "Invalid interleave factor");
   auto *VecVTy = cast<VectorType>(VecTy);
 
-  unsigned MaxFactor = TLI->getMaxSupportedInterleaveFactor();
-  if (VecTy->isScalableTy() &&
-      (!ST->hasSVE() || !isPowerOf2_32(Factor) || Factor > MaxFactor))
+  if (VecTy->isScalableTy() && (!ST->hasSVE() || (Factor != 2 && Factor != 4)))
     return InstructionCost::getInvalid();
 
   // Vectorization for masked interleaved accesses is only enabled for scalable
@@ -3568,7 +3566,7 @@ InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
   if (!VecTy->isScalableTy() && (UseMaskForCond || UseMaskForGaps))
     return InstructionCost::getInvalid();
 
-  if (!UseMaskForGaps && Factor <= MaxFactor) {
+  if (!UseMaskForGaps && Factor <= TLI->getMaxSupportedInterleaveFactor()) {
     unsigned MinElts = VecVTy->getElementCount().getKnownMinValue();
     auto *SubVecTy =
         VectorType::get(VecVTy->getElementType(),
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 62e5cb4c00c7be..aef85e58c9609a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2127,28 +2127,39 @@ static Value *interleaveVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vals,
   // Scalable vectors cannot use arbitrary shufflevectors (only splats), so
   // must use intrinsics to interleave.
   if (VecTy->isScalableTy()) {
+    if (Vals.size() == 2) {
+      VectorType *WideVecTy = VectorType::getDoubleElementsVectorType(VecTy);
+      return Builder.CreateIntrinsic(WideVecTy, Intrinsic::vector_interleave2,
+                                     Vals,
+                                     /*FMFSource=*/nullptr, Name);
+    }
     unsigned InterleaveFactor = Vals.size();
-    SmallVector<Value *> InterleavingValues;
-    unsigned InterleavingValuesCount =
-        InterleaveFactor + (InterleaveFactor - 2);
-    InterleavingValues.resize(InterleaveFactor);
-    // Place the values to be interleaved in the correct order for the
-    // interleaving
-    for (unsigned I = 0, J = InterleaveFactor / 2, K = 0; K < InterleaveFactor;
-         K++) {
-      if (K % 2 == 0) {
-        InterleavingValues[K] = Vals[I];
-        I++;
-      } else {
-        InterleavingValues[K] = Vals[J];
-        J++;
+    SmallVector<Value *> InterleavingValues(Vals);
+    // The total number of nodes in a balanced binary tree is calculated as 2n -
+    // 1, where `n` is the number of leaf nodes (`InterleaveFactor`). In this
+    // context, we exclude the root node because it will serve as the final
+    // interleaved value. Thus, the number of nodes to be processed/interleaved
+    // is: (2n - 1) - 1 = 2n - 2.
+
+    unsigned NumInterleavingValues = 2 * InterleaveFactor - 2;
+    for (unsigned I = 1; I < NumInterleavingValues; I += 2) {
+      // values that haven't been processed yet:
+      unsigned Remaining = InterleavingValues.size() - I + 1;
+      if (Remaining > 2 && isPowerOf2_32(Remaining)) {
+
+        // The remaining values form a new level in the interleaving tree.
+        // Arrange these values in the correct interleaving order for this
+        // level. The interleaving order places alternating elements from the
+        // first and second halves,
+        std::vector<Value *> RemainingValues(InterleavingValues.begin() + I - 1,
+                                             InterleavingValues.end());
+        unsigned Middle = Remaining / 2;
+        for (unsigned J = I - 1, K = 0; J < InterleavingValues.size();
+             J += 2, K++) {
+          InterleavingValues[J] = RemainingValues[K];
+          InterleavingValues[J + 1] = RemainingValues[Middle + K];
+        }
       }
-    }
-#ifndef NDEBUG
-    for (Value *Val : InterleavingValues)
-      assert(Val && "NULL Interleaving Value");
-#endif
-    for (unsigned I = 1; I < InterleavingValuesCount; I += 2) {
       VectorType *InterleaveTy =
           cast<VectorType>(InterleavingValues[I]->getType());
       VectorType *WideVecTy =
@@ -2159,7 +2170,7 @@ static Value *interleaveVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vals,
           /*FMFSource=*/nullptr, Name);
       InterleavingValues.push_back(InterleaveRes);
     }
-    return InterleavingValues[InterleavingValuesCount];
+    return InterleavingValues[NumInterleavingValues];
   }
 
   // Fixed length. Start by concatenating all vectors into a wide vector.
@@ -2327,42 +2338,48 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
 
         SmallVector<Value *> DeinterleavedValues;
         // If the InterleaveFactor is > 2, so we will have to do recursive
-        // deinterleaving, because the current available deinterleave intrinsice
+        // deinterleaving, because the current available deinterleave intrinsic
         // supports only Factor of 2. DeinterleaveCount represent how many times
         // we will do deinterleaving, we will do deinterleave on all nonleaf
         // nodes in the deinterleave tree.
         unsigned DeinterleaveCount = InterleaveFactor - 1;
-        std::queue<Value *> TempDeinterleavedValues;
-        TempDeinterleavedValues.push(NewLoads[Part]);
+        std::vector<Value *> TempDeinterleavedValues;
+        TempDeinterleavedValues.push_back(NewLoads[Part]);
         for (unsigned I = 0; I < DeinterleaveCount; ++I) {
-          Value *ValueToDeinterleave = TempDeinterleavedValues.front();
-          auto *DiTy = ValueToDeinterleave->getType();
-          TempDeinterleavedValues.pop();
+          auto *DiTy = TempDeinterleavedValues[I]->getType();
           Value *DI = State.Builder.CreateIntrinsic(
-              Intrinsic::vector_deinterleave2, DiTy, ValueToDeinterleave,
+              Intrinsic::vector_deinterleave2, DiTy, TempDeinterleavedValues[I],
               /*FMFSource=*/nullptr, "strided.vec");
           Value *StridedVec = State.Builder.CreateExtractValue(DI, 0);
-          TempDeinterleavedValues.push(StridedVec);
+          TempDeinterleavedValues.push_back(StridedVec);
           StridedVec = State.Builder.CreateExtractValue(DI, 1);
-          TempDeinterleavedValues.push(StridedVec);
-        }
-
-        assert(TempDeinterleavedValues.size() == InterleaveFactor &&
-               "Num of deinterleaved values must equals to InterleaveFactor");
-        // Sort deinterleaved values
-        DeinterleavedValues.resize(InterleaveFactor);
-        for (unsigned I = 0, J = InterleaveFactor / 2, K = 0;
-             K < InterleaveFactor; K++) {
-          auto *DeinterleavedValue = TempDeinterleavedValues.front();
-          TempDeinterleavedValues.pop();
-          if (K % 2 == 0) {
-            DeinterleavedValues[I] = DeinterleavedValue;
-            I++;
-          } else {
-            DeinterleavedValues[J] = DeinterleavedValue;
-            J++;
+          TempDeinterleavedValues.push_back(StridedVec);
+          // Perform sorting at the start of each new level in the tree.
+          // A new level begins when the number of remaining values is a power
+          // of 2 and greater than 2. If a level has only 2 nodes, no sorting is
+          // needed as they are already in order. Number of remaining values to
+          // be processed:
+          unsigned NumRemainingValues = TempDeinterleavedValues.size() - I - 1;
+          if (NumRemainingValues > 2 && isPowerOf2_32(NumRemainingValues)) {
+            // these remaining values represent a new level in the tree,
+            // Reorder the values to match the correct deinterleaving order.
+            std::vector<Value *> RemainingValues(
+                TempDeinterleavedValues.begin() + I + 1,
+                TempDeinterleavedValues.end());
+            unsigned Middle = NumRemainingValues / 2;
+            for (unsigned J = 0, K = I + 1; J < NumRemainingValues;
+                 J += 2, K++) {
+              TempDeinterleavedValues[K] = RemainingValues[J];
+              TempDeinterleavedValues[Middle + K] = RemainingValues[J + 1];
+            }
           }
         }
+        // Final deinterleaved values:
+        DeinterleavedValues.insert(DeinterleavedValues.begin(),
+                                   TempDeinterleavedValues.begin() +
+                                       InterleaveFactor - 1,
+                                   TempDeinterleavedValues.end());
+
 #ifndef NDEBUG
         for (Value *Val : DeinterleavedValues)
           assert(Val && "NULL Deinterleaved Value");

>From 91a1a245b0693e9e47f3d82f2e617ad351af78ac Mon Sep 17 00:00:00 2001
From: Hassnaa Hamdi <hassnaa.hamdi at arm.com>
Date: Thu, 31 Oct 2024 19:53:38 +0000
Subject: [PATCH 3/3] refactoring

Change-Id: If2a3789ed76c98a5f1d1be729f5051a2c54af2a7
---
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 111 +++++-------------
 1 file changed, 32 insertions(+), 79 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index aef85e58c9609a..8029538e24ea01 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -31,7 +31,6 @@
 #include "llvm/Transforms/Utils/LoopUtils.h"
 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
 #include <cassert>
-#include <queue>
 
 using namespace llvm;
 
@@ -2127,50 +2126,22 @@ static Value *interleaveVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vals,
   // Scalable vectors cannot use arbitrary shufflevectors (only splats), so
   // must use intrinsics to interleave.
   if (VecTy->isScalableTy()) {
-    if (Vals.size() == 2) {
-      VectorType *WideVecTy = VectorType::getDoubleElementsVectorType(VecTy);
-      return Builder.CreateIntrinsic(WideVecTy, Intrinsic::vector_interleave2,
-                                     Vals,
-                                     /*FMFSource=*/nullptr, Name);
-    }
     unsigned InterleaveFactor = Vals.size();
     SmallVector<Value *> InterleavingValues(Vals);
-    // The total number of nodes in a balanced binary tree is calculated as 2n -
-    // 1, where `n` is the number of leaf nodes (`InterleaveFactor`). In this
-    // context, we exclude the root node because it will serve as the final
-    // interleaved value. Thus, the number of nodes to be processed/interleaved
-    // is: (2n - 1) - 1 = 2n - 2.
-
-    unsigned NumInterleavingValues = 2 * InterleaveFactor - 2;
-    for (unsigned I = 1; I < NumInterleavingValues; I += 2) {
-      // values that haven't been processed yet:
-      unsigned Remaining = InterleavingValues.size() - I + 1;
-      if (Remaining > 2 && isPowerOf2_32(Remaining)) {
-
-        // The remaining values form a new level in the interleaving tree.
-        // Arrange these values in the correct interleaving order for this
-        // level. The interleaving order places alternating elements from the
-        // first and second halves,
-        std::vector<Value *> RemainingValues(InterleavingValues.begin() + I - 1,
-                                             InterleavingValues.end());
-        unsigned Middle = Remaining / 2;
-        for (unsigned J = I - 1, K = 0; J < InterleavingValues.size();
-             J += 2, K++) {
-          InterleavingValues[J] = RemainingValues[K];
-          InterleavingValues[J + 1] = RemainingValues[Middle + K];
-        }
-      }
+    // As we are interleaving, the number of values will shrink until we have
+    // the single final interleaved value.
+    for (unsigned Midpoint = Factor / 2; Midpoint > 0; Midpoint /= 2) {
       VectorType *InterleaveTy =
-          cast<VectorType>(InterleavingValues[I]->getType());
+          cast<VectorType>(InterleavingValues[0]->getType());
       VectorType *WideVecTy =
           VectorType::getDoubleElementsVectorType(InterleaveTy);
-      auto *InterleaveRes = Builder.CreateIntrinsic(
-          WideVecTy, Intrinsic::vector_interleave2,
-          {InterleavingValues[I - 1], InterleavingValues[I]},
-          /*FMFSource=*/nullptr, Name);
-      InterleavingValues.push_back(InterleaveRes);
+      for (unsigned I = 0; I < Midpoint; ++I)
+        InterleavingValues[I] = Builder.CreateIntrinsic(
+            WideVecTy, Intrinsic::vector_interleave2,
+            {InterleavingValues[I], InterleavingValues[Midpoint + I]},
+            /*FMFSource=*/nullptr, Name);
     }
-    return InterleavingValues[NumInterleavingValues];
+    return InterleavingValues[0];
   }
 
   // Fixed length. Start by concatenating all vectors into a wide vector.
@@ -2336,49 +2307,31 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
         // Scalable vectors cannot use arbitrary shufflevectors (only splats),
         // so must use intrinsics to deinterleave.
 
-        SmallVector<Value *> DeinterleavedValues;
-        // If the InterleaveFactor is > 2, so we will have to do recursive
+        SmallVector<Value *> DeinterleavedValues(InterleaveFactor);
+        DeinterleavedValues[0] = NewLoads[Part];
+        // For the case of InterleaveFactor > 2, we will have to do recursive
         // deinterleaving, because the current available deinterleave intrinsic
-        // supports only Factor of 2. DeinterleaveCount represent how many times
-        // we will do deinterleaving, we will do deinterleave on all nonleaf
-        // nodes in the deinterleave tree.
-        unsigned DeinterleaveCount = InterleaveFactor - 1;
-        std::vector<Value *> TempDeinterleavedValues;
-        TempDeinterleavedValues.push_back(NewLoads[Part]);
-        for (unsigned I = 0; I < DeinterleaveCount; ++I) {
-          auto *DiTy = TempDeinterleavedValues[I]->getType();
-          Value *DI = State.Builder.CreateIntrinsic(
-              Intrinsic::vector_deinterleave2, DiTy, TempDeinterleavedValues[I],
-              /*FMFSource=*/nullptr, "strided.vec");
-          Value *StridedVec = State.Builder.CreateExtractValue(DI, 0);
-          TempDeinterleavedValues.push_back(StridedVec);
-          StridedVec = State.Builder.CreateExtractValue(DI, 1);
-          TempDeinterleavedValues.push_back(StridedVec);
-          // Perform sorting at the start of each new level in the tree.
-          // A new level begins when the number of remaining values is a power
-          // of 2 and greater than 2. If a level has only 2 nodes, no sorting is
-          // needed as they are already in order. Number of remaining values to
-          // be processed:
-          unsigned NumRemainingValues = TempDeinterleavedValues.size() - I - 1;
-          if (NumRemainingValues > 2 && isPowerOf2_32(NumRemainingValues)) {
-            // these remaining values represent a new level in the tree,
-            // Reorder the values to match the correct deinterleaving order.
-            std::vector<Value *> RemainingValues(
-                TempDeinterleavedValues.begin() + I + 1,
-                TempDeinterleavedValues.end());
-            unsigned Middle = NumRemainingValues / 2;
-            for (unsigned J = 0, K = I + 1; J < NumRemainingValues;
-                 J += 2, K++) {
-              TempDeinterleavedValues[K] = RemainingValues[J];
-              TempDeinterleavedValues[Middle + K] = RemainingValues[J + 1];
-            }
+        // supports only a Factor of 2. When the Factor is exactly 2, the loop
+        // below finishes after its first iteration.
+        // As we are deinterleaving, the number of values doubles on each
+        // iteration until it reaches the InterleaveFactor.
+        for (int NumVectors = 1; NumVectors < InterleaveFactor;
+             NumVectors *= 2) {
+          // Deinterleave the elements within each vector.
+          std::vector<Value *> TempDeinterleavedValues(NumVectors);
+          for (int I = 0; I < NumVectors; ++I) {
+            auto *DiTy = DeinterleavedValues[I]->getType();
+            TempDeinterleavedValues[I] = State.Builder.CreateIntrinsic(
+                Intrinsic::vector_deinterleave2, DiTy, DeinterleavedValues[I],
+                /*FMFSource=*/nullptr, "strided.vec");
           }
+          // Extract the deinterleaved values:
+          for (int I = 0; I < 2; ++I)
+            for (int J = 0; J < NumVectors; ++J)
+              DeinterleavedValues[NumVectors * I + J] =
+                  State.Builder.CreateExtractValue(TempDeinterleavedValues[J],
+                                                   I);
         }
-        // Final deinterleaved values:
-        DeinterleavedValues.insert(DeinterleavedValues.begin(),
-                                   TempDeinterleavedValues.begin() +
-                                       InterleaveFactor - 1,
-                                   TempDeinterleavedValues.end());
 
 #ifndef NDEBUG
         for (Value *Val : DeinterleavedValues)



More information about the llvm-commits mailing list