[llvm] [IR][RISCV] Add llvm.vector.(de)interleave3/5/7 (PR #124825)

Min-Yih Hsu via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 30 10:19:40 PST 2025


================
@@ -10975,75 +10975,116 @@ SDValue RISCVTargetLowering::lowerVECTOR_DEINTERLEAVE(SDValue Op,
   assert(VecVT.isScalableVector() &&
          "vector_interleave on non-scalable vector!");
 
+  const unsigned Factor = Op->getNumValues();
+
   // 1 bit element vectors need to be widened to e8
   if (VecVT.getVectorElementType() == MVT::i1)
     return widenVectorOpsToi8(Op, DL, DAG);
 
-  // If the VT is LMUL=8, we need to split and reassemble.
-  if (VecVT.getSizeInBits().getKnownMinValue() ==
+  // If concatenating would exceed LMUL=8, we need to split.
+  if ((VecVT.getSizeInBits().getKnownMinValue() * Factor) >
       (8 * RISCV::RVVBitsPerBlock)) {
-    auto [Op0Lo, Op0Hi] = DAG.SplitVectorOperand(Op.getNode(), 0);
-    auto [Op1Lo, Op1Hi] = DAG.SplitVectorOperand(Op.getNode(), 1);
-    EVT SplitVT = Op0Lo.getValueType();
-
-    SDValue ResLo = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL,
-                                DAG.getVTList(SplitVT, SplitVT), Op0Lo, Op0Hi);
-    SDValue ResHi = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL,
-                                DAG.getVTList(SplitVT, SplitVT), Op1Lo, Op1Hi);
-
-    SDValue Even = DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT,
-                               ResLo.getValue(0), ResHi.getValue(0));
-    SDValue Odd = DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT, ResLo.getValue(1),
-                              ResHi.getValue(1));
-    return DAG.getMergeValues({Even, Odd}, DL);
-  }
+    SmallVector<SDValue, 8> Ops(Factor * 2);
+    for (unsigned i = 0; i != Factor; ++i) {
+      auto [OpLo, OpHi] = DAG.SplitVectorOperand(Op.getNode(), i);
+      Ops[i * 2] = OpLo;
+      Ops[i * 2 + 1] = OpHi;
+    }
 
-  // Concatenate the two vectors as one vector to deinterleave
-  MVT ConcatVT =
-      MVT::getVectorVT(VecVT.getVectorElementType(),
-                       VecVT.getVectorElementCount().multiplyCoefficientBy(2));
-  SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT,
-                               Op.getOperand(0), Op.getOperand(1));
+    SmallVector<EVT, 8> VTs(Factor, Ops[0].getValueType());
 
-  // We can deinterleave through vnsrl.wi if the element type is smaller than
-  // ELEN
-  if (VecVT.getScalarSizeInBits() < Subtarget.getELen()) {
-    SDValue Even = getDeinterleaveShiftAndTrunc(DL, VecVT, Concat, 2, 0, DAG);
-    SDValue Odd = getDeinterleaveShiftAndTrunc(DL, VecVT, Concat, 2, 1, DAG);
-    return DAG.getMergeValues({Even, Odd}, DL);
-  }
+    SDValue Lo = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL, VTs,
+                             ArrayRef(Ops).slice(0, Factor));
+    SDValue Hi = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL, VTs,
+                             ArrayRef(Ops).slice(Factor, Factor));
 
-  // For the indices, use the vmv.v.x of an i8 constant to fill the largest
-  // possibly mask vector, then extract the required subvector.  Doing this
-  // (instead of a vid, vmsne sequence) reduces LMUL, and allows the mask
-  // creation to be rematerialized during register allocation to reduce
-  // register pressure if needed.
+    SmallVector<SDValue, 8> Res(Factor);
+    for (unsigned i = 0; i != Factor; ++i)
+      Res[i] = DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT, Lo.getValue(i),
+                           Hi.getValue(i));
 
-  MVT MaskVT = ConcatVT.changeVectorElementType(MVT::i1);
+    return DAG.getMergeValues(Res, DL);
+  }
 
-  SDValue EvenSplat = DAG.getConstant(0b01010101, DL, MVT::nxv8i8);
-  EvenSplat = DAG.getBitcast(MVT::nxv64i1, EvenSplat);
-  SDValue EvenMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MaskVT, EvenSplat,
-                                 DAG.getVectorIdxConstant(0, DL));
+  SmallVector<SDValue, 8> Ops(Op->op_values());
 
-  SDValue OddSplat = DAG.getConstant(0b10101010, DL, MVT::nxv8i8);
-  OddSplat = DAG.getBitcast(MVT::nxv64i1, OddSplat);
-  SDValue OddMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MaskVT, OddSplat,
-                                DAG.getVectorIdxConstant(0, DL));
+  // Concatenate the vectors as one vector to deinterleave
+  MVT ConcatVT =
+      MVT::getVectorVT(VecVT.getVectorElementType(),
+                       VecVT.getVectorElementCount().multiplyCoefficientBy(
+                           PowerOf2Ceil(Factor)));
+  if (Ops.size() < PowerOf2Ceil(Factor))
+    Ops.append(PowerOf2Ceil(Factor) - Factor, DAG.getUNDEF(VecVT));
+  SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT, Ops);
+
+  if (Factor == 2) {
+    // We can deinterleave through vnsrl.wi if the element type is smaller than
+    // ELEN
+    if (VecVT.getScalarSizeInBits() < Subtarget.getELen()) {
+      SDValue Even = getDeinterleaveShiftAndTrunc(DL, VecVT, Concat, 2, 0, DAG);
+      SDValue Odd = getDeinterleaveShiftAndTrunc(DL, VecVT, Concat, 2, 1, DAG);
+      return DAG.getMergeValues({Even, Odd}, DL);
+    }
+
+    // For the indices, use the vmv.v.x of an i8 constant to fill the largest
+    // possibly mask vector, then extract the required subvector.  Doing this
+    // (instead of a vid, vmsne sequence) reduces LMUL, and allows the mask
+    // creation to be rematerialized during register allocation to reduce
+    // register pressure if needed.
+
+    MVT MaskVT = ConcatVT.changeVectorElementType(MVT::i1);
+
+    SDValue EvenSplat = DAG.getConstant(0b01010101, DL, MVT::nxv8i8);
+    EvenSplat = DAG.getBitcast(MVT::nxv64i1, EvenSplat);
+    SDValue EvenMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MaskVT,
+                                   EvenSplat, DAG.getVectorIdxConstant(0, DL));
+
+    SDValue OddSplat = DAG.getConstant(0b10101010, DL, MVT::nxv8i8);
+    OddSplat = DAG.getBitcast(MVT::nxv64i1, OddSplat);
+    SDValue OddMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MaskVT, OddSplat,
+                                  DAG.getVectorIdxConstant(0, DL));
+
+    // vcompress the even and odd elements into two separate vectors
+    SDValue EvenWide = DAG.getNode(ISD::VECTOR_COMPRESS, DL, ConcatVT, Concat,
+                                   EvenMask, DAG.getUNDEF(ConcatVT));
+    SDValue OddWide = DAG.getNode(ISD::VECTOR_COMPRESS, DL, ConcatVT, Concat,
+                                  OddMask, DAG.getUNDEF(ConcatVT));
+
+    // Extract the result half of the gather for even and odd
+    SDValue Even = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, EvenWide,
+                               DAG.getVectorIdxConstant(0, DL));
+    SDValue Odd = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, OddWide,
+                              DAG.getVectorIdxConstant(0, DL));
 
-  // vcompress the even and odd elements into two separate vectors
-  SDValue EvenWide = DAG.getNode(ISD::VECTOR_COMPRESS, DL, ConcatVT, Concat,
-                                 EvenMask, DAG.getUNDEF(ConcatVT));
-  SDValue OddWide = DAG.getNode(ISD::VECTOR_COMPRESS, DL, ConcatVT, Concat,
-                                OddMask, DAG.getUNDEF(ConcatVT));
+    return DAG.getMergeValues({Even, Odd}, DL);
+  }
 
-  // Extract the result half of the gather for even and odd
-  SDValue Even = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, EvenWide,
-                             DAG.getVectorIdxConstant(0, DL));
-  SDValue Odd = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, OddWide,
-                            DAG.getVectorIdxConstant(0, DL));
+  // We want to operate on all lanes, so get the mask and VL and mask for it
+  auto [Mask, VL] = getDefaultScalableVLOps(ConcatVT, DL, DAG, Subtarget);
+  SDValue Passthru = DAG.getUNDEF(ConcatVT);
+
+  // For the indices, use the same SEW to avoid an extra vsetvli
----------------
mshockwave wrote:

I've changed the logic here to use a unit-stride store plus a segmented load instead.

https://github.com/llvm/llvm-project/pull/124825


More information about the llvm-commits mailing list