[llvm] [IR][RISCV] Add llvm.vector.(de)interleave3/5/7 (PR #124825)

Min-Yih Hsu via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 4 18:00:18 PST 2025


================
@@ -10972,115 +10978,285 @@ SDValue RISCVTargetLowering::lowerVECTOR_DEINTERLEAVE(SDValue Op,
   SDLoc DL(Op);
   MVT VecVT = Op.getSimpleValueType();
 
-  assert(VecVT.isScalableVector() &&
-         "vector_interleave on non-scalable vector!");
+  const unsigned Factor = Op->getNumValues();
 
   // 1 bit element vectors need to be widened to e8
   if (VecVT.getVectorElementType() == MVT::i1)
     return widenVectorOpsToi8(Op, DL, DAG);
 
-  // If the VT is LMUL=8, we need to split and reassemble.
-  if (VecVT.getSizeInBits().getKnownMinValue() ==
+  // Convert to scalable vectors first.
+  if (VecVT.isFixedLengthVector()) {
+    MVT ContainerVT = getContainerForFixedLengthVector(VecVT);
+    SmallVector<SDValue, 8> Ops(Factor);
+    for (unsigned i = 0U; i < Factor; ++i)
+      Ops[i] = convertToScalableVector(ContainerVT, Op.getOperand(i), DAG,
+                                       Subtarget);
+
+    SmallVector<EVT, 8> VTs(Factor, ContainerVT);
+    SDValue NewDeinterleave =
+        DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL, VTs, Ops);
+
+    SmallVector<SDValue, 8> Res(Factor);
+    for (unsigned i = 0U; i < Factor; ++i)
+      Res[i] = convertFromScalableVector(VecVT, NewDeinterleave.getValue(i),
+                                         DAG, Subtarget);
+    return DAG.getMergeValues(Res, DL);
+  }
+
+  // If concatenating would exceed LMUL=8, we need to split.
+  if ((VecVT.getSizeInBits().getKnownMinValue() * Factor) >
       (8 * RISCV::RVVBitsPerBlock)) {
-    auto [Op0Lo, Op0Hi] = DAG.SplitVectorOperand(Op.getNode(), 0);
-    auto [Op1Lo, Op1Hi] = DAG.SplitVectorOperand(Op.getNode(), 1);
-    EVT SplitVT = Op0Lo.getValueType();
-
-    SDValue ResLo = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL,
-                                DAG.getVTList(SplitVT, SplitVT), Op0Lo, Op0Hi);
-    SDValue ResHi = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL,
-                                DAG.getVTList(SplitVT, SplitVT), Op1Lo, Op1Hi);
-
-    SDValue Even = DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT,
-                               ResLo.getValue(0), ResHi.getValue(0));
-    SDValue Odd = DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT, ResLo.getValue(1),
-                              ResHi.getValue(1));
-    return DAG.getMergeValues({Even, Odd}, DL);
+    SmallVector<SDValue, 8> Ops(Factor * 2);
+    for (unsigned i = 0; i != Factor; ++i) {
+      auto [OpLo, OpHi] = DAG.SplitVectorOperand(Op.getNode(), i);
+      Ops[i * 2] = OpLo;
+      Ops[i * 2 + 1] = OpHi;
+    }
+
+    SmallVector<EVT, 8> VTs(Factor, Ops[0].getValueType());
+
+    SDValue Lo = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL, VTs,
+                             ArrayRef(Ops).slice(0, Factor));
+    SDValue Hi = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL, VTs,
+                             ArrayRef(Ops).slice(Factor, Factor));
+
+    SmallVector<SDValue, 8> Res(Factor);
+    for (unsigned i = 0; i != Factor; ++i)
+      Res[i] = DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT, Lo.getValue(i),
+                           Hi.getValue(i));
+
+    return DAG.getMergeValues(Res, DL);
   }
 
-  // Concatenate the two vectors as one vector to deinterleave
+  SmallVector<SDValue, 8> Ops(Op->op_values());
+
+  // Concatenate the vectors as one vector to deinterleave
   MVT ConcatVT =
       MVT::getVectorVT(VecVT.getVectorElementType(),
-                       VecVT.getVectorElementCount().multiplyCoefficientBy(2));
-  SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT,
-                               Op.getOperand(0), Op.getOperand(1));
+                       VecVT.getVectorElementCount().multiplyCoefficientBy(
+                           PowerOf2Ceil(Factor)));
+  if (Ops.size() < PowerOf2Ceil(Factor))
+    Ops.append(PowerOf2Ceil(Factor) - Factor, DAG.getUNDEF(VecVT));
+  SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT, Ops);
+
+  if (Factor == 2) {
+    // We can deinterleave through vnsrl.wi if the element type is smaller than
+    // ELEN
+    if (VecVT.getScalarSizeInBits() < Subtarget.getELen()) {
+      SDValue Even = getDeinterleaveShiftAndTrunc(DL, VecVT, Concat, 2, 0, DAG);
+      SDValue Odd = getDeinterleaveShiftAndTrunc(DL, VecVT, Concat, 2, 1, DAG);
+      return DAG.getMergeValues({Even, Odd}, DL);
+    }
+
+    // For the indices, use the vmv.v.x of an i8 constant to fill the largest
+    // possibly mask vector, then extract the required subvector.  Doing this
+    // (instead of a vid, vmsne sequence) reduces LMUL, and allows the mask
+    // creation to be rematerialized during register allocation to reduce
+    // register pressure if needed.
+
+    MVT MaskVT = ConcatVT.changeVectorElementType(MVT::i1);
+
+    SDValue EvenSplat = DAG.getConstant(0b01010101, DL, MVT::nxv8i8);
+    EvenSplat = DAG.getBitcast(MVT::nxv64i1, EvenSplat);
+    SDValue EvenMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MaskVT,
+                                   EvenSplat, DAG.getVectorIdxConstant(0, DL));
+
+    SDValue OddSplat = DAG.getConstant(0b10101010, DL, MVT::nxv8i8);
+    OddSplat = DAG.getBitcast(MVT::nxv64i1, OddSplat);
+    SDValue OddMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MaskVT, OddSplat,
+                                  DAG.getVectorIdxConstant(0, DL));
+
+    // vcompress the even and odd elements into two separate vectors
+    SDValue EvenWide = DAG.getNode(ISD::VECTOR_COMPRESS, DL, ConcatVT, Concat,
+                                   EvenMask, DAG.getUNDEF(ConcatVT));
+    SDValue OddWide = DAG.getNode(ISD::VECTOR_COMPRESS, DL, ConcatVT, Concat,
+                                  OddMask, DAG.getUNDEF(ConcatVT));
+
+    // Extract the result half of the gather for even and odd
+    SDValue Even = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, EvenWide,
+                               DAG.getVectorIdxConstant(0, DL));
+    SDValue Odd = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, OddWide,
+                              DAG.getVectorIdxConstant(0, DL));
 
-  // We can deinterleave through vnsrl.wi if the element type is smaller than
-  // ELEN
-  if (VecVT.getScalarSizeInBits() < Subtarget.getELen()) {
-    SDValue Even = getDeinterleaveShiftAndTrunc(DL, VecVT, Concat, 2, 0, DAG);
-    SDValue Odd = getDeinterleaveShiftAndTrunc(DL, VecVT, Concat, 2, 1, DAG);
     return DAG.getMergeValues({Even, Odd}, DL);
   }
 
-  // For the indices, use the vmv.v.x of an i8 constant to fill the largest
-  // possibly mask vector, then extract the required subvector.  Doing this
-  // (instead of a vid, vmsne sequence) reduces LMUL, and allows the mask
-  // creation to be rematerialized during register allocation to reduce
-  // register pressure if needed.
+  // Store with unit-stride store and load it back with segmented load.
+  MVT XLenVT = Subtarget.getXLenVT();
+  SDValue VL = getDefaultScalableVLOps(ConcatVT, DL, DAG, Subtarget).second;
+  SDValue Passthru = DAG.getUNDEF(ConcatVT);
 
-  MVT MaskVT = ConcatVT.changeVectorElementType(MVT::i1);
+  // Allocate a stack slot.
+  Align Alignment = DAG.getReducedAlign(VecVT, /*UseABI=*/false);
+  SDValue StackPtr =
+      DAG.CreateStackTemporary(ConcatVT.getStoreSize(), Alignment);
+  auto &MF = DAG.getMachineFunction();
+  auto FrameIndex = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
+  auto PtrInfo = MachinePointerInfo::getFixedStack(MF, FrameIndex);
 
-  SDValue EvenSplat = DAG.getConstant(0b01010101, DL, MVT::nxv8i8);
-  EvenSplat = DAG.getBitcast(MVT::nxv64i1, EvenSplat);
-  SDValue EvenMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MaskVT, EvenSplat,
-                                 DAG.getVectorIdxConstant(0, DL));
+  SDValue StoreOps[] = {DAG.getEntryNode(),
+                        DAG.getTargetConstant(Intrinsic::riscv_vse, DL, XLenVT),
+                        Concat, StackPtr, VL};
 
-  SDValue OddSplat = DAG.getConstant(0b10101010, DL, MVT::nxv8i8);
-  OddSplat = DAG.getBitcast(MVT::nxv64i1, OddSplat);
-  SDValue OddMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MaskVT, OddSplat,
-                                DAG.getVectorIdxConstant(0, DL));
+  SDValue Chain = DAG.getMemIntrinsicNode(
+      ISD::INTRINSIC_VOID, DL, DAG.getVTList(MVT::Other), StoreOps,
+      ConcatVT.getVectorElementType(), PtrInfo, Alignment,
+      MachineMemOperand::MOStore, MemoryLocation::UnknownSize);
 
-  // vcompress the even and odd elements into two separate vectors
-  SDValue EvenWide = DAG.getNode(ISD::VECTOR_COMPRESS, DL, ConcatVT, Concat,
-                                 EvenMask, DAG.getUNDEF(ConcatVT));
-  SDValue OddWide = DAG.getNode(ISD::VECTOR_COMPRESS, DL, ConcatVT, Concat,
-                                OddMask, DAG.getUNDEF(ConcatVT));
+  static const Intrinsic::ID VlsegIntrinsicsIds[] = {
+      Intrinsic::riscv_vlseg2, Intrinsic::riscv_vlseg3, Intrinsic::riscv_vlseg4,
+      Intrinsic::riscv_vlseg5, Intrinsic::riscv_vlseg6, Intrinsic::riscv_vlseg7,
+      Intrinsic::riscv_vlseg8};
 
-  // Extract the result half of the gather for even and odd
-  SDValue Even = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, EvenWide,
-                             DAG.getVectorIdxConstant(0, DL));
-  SDValue Odd = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, OddWide,
-                            DAG.getVectorIdxConstant(0, DL));
+  SDValue LoadOps[] = {
----------------
mshockwave wrote:

Done

https://github.com/llvm/llvm-project/pull/124825


More information about the llvm-commits mailing list