[llvm] [IR][RISCV] Add llvm.vector.(de)interleave3/5/7 (PR #124825)
Min-Yih Hsu via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 4 18:00:18 PST 2025
================
@@ -10972,115 +10978,285 @@ SDValue RISCVTargetLowering::lowerVECTOR_DEINTERLEAVE(SDValue Op,
SDLoc DL(Op);
MVT VecVT = Op.getSimpleValueType();
- assert(VecVT.isScalableVector() &&
- "vector_interleave on non-scalable vector!");
+ const unsigned Factor = Op->getNumValues();
// 1 bit element vectors need to be widened to e8
if (VecVT.getVectorElementType() == MVT::i1)
return widenVectorOpsToi8(Op, DL, DAG);
- // If the VT is LMUL=8, we need to split and reassemble.
- if (VecVT.getSizeInBits().getKnownMinValue() ==
+ // Convert to scalable vectors first.
+ if (VecVT.isFixedLengthVector()) {
+ MVT ContainerVT = getContainerForFixedLengthVector(VecVT);
+ SmallVector<SDValue, 8> Ops(Factor);
+ for (unsigned i = 0U; i < Factor; ++i)
+ Ops[i] = convertToScalableVector(ContainerVT, Op.getOperand(i), DAG,
+ Subtarget);
+
+ SmallVector<EVT, 8> VTs(Factor, ContainerVT);
+ SDValue NewDeinterleave =
+ DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL, VTs, Ops);
+
+ SmallVector<SDValue, 8> Res(Factor);
+ for (unsigned i = 0U; i < Factor; ++i)
+ Res[i] = convertFromScalableVector(VecVT, NewDeinterleave.getValue(i),
+ DAG, Subtarget);
+ return DAG.getMergeValues(Res, DL);
+ }
+
+ // If concatenating would exceed LMUL=8, we need to split.
+ if ((VecVT.getSizeInBits().getKnownMinValue() * Factor) >
(8 * RISCV::RVVBitsPerBlock)) {
- auto [Op0Lo, Op0Hi] = DAG.SplitVectorOperand(Op.getNode(), 0);
- auto [Op1Lo, Op1Hi] = DAG.SplitVectorOperand(Op.getNode(), 1);
- EVT SplitVT = Op0Lo.getValueType();
-
- SDValue ResLo = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL,
- DAG.getVTList(SplitVT, SplitVT), Op0Lo, Op0Hi);
- SDValue ResHi = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL,
- DAG.getVTList(SplitVT, SplitVT), Op1Lo, Op1Hi);
-
- SDValue Even = DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT,
- ResLo.getValue(0), ResHi.getValue(0));
- SDValue Odd = DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT, ResLo.getValue(1),
- ResHi.getValue(1));
- return DAG.getMergeValues({Even, Odd}, DL);
+ SmallVector<SDValue, 8> Ops(Factor * 2);
+ for (unsigned i = 0; i != Factor; ++i) {
+ auto [OpLo, OpHi] = DAG.SplitVectorOperand(Op.getNode(), i);
+ Ops[i * 2] = OpLo;
+ Ops[i * 2 + 1] = OpHi;
+ }
+
+ SmallVector<EVT, 8> VTs(Factor, Ops[0].getValueType());
+
+ SDValue Lo = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL, VTs,
+ ArrayRef(Ops).slice(0, Factor));
+ SDValue Hi = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL, VTs,
+ ArrayRef(Ops).slice(Factor, Factor));
+
+ SmallVector<SDValue, 8> Res(Factor);
+ for (unsigned i = 0; i != Factor; ++i)
+ Res[i] = DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT, Lo.getValue(i),
+ Hi.getValue(i));
+
+ return DAG.getMergeValues(Res, DL);
}
- // Concatenate the two vectors as one vector to deinterleave
+ SmallVector<SDValue, 8> Ops(Op->op_values());
+
+ // Concatenate the vectors as one vector to deinterleave
MVT ConcatVT =
MVT::getVectorVT(VecVT.getVectorElementType(),
- VecVT.getVectorElementCount().multiplyCoefficientBy(2));
- SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT,
- Op.getOperand(0), Op.getOperand(1));
+ VecVT.getVectorElementCount().multiplyCoefficientBy(
+ PowerOf2Ceil(Factor)));
+ if (Ops.size() < PowerOf2Ceil(Factor))
+ Ops.append(PowerOf2Ceil(Factor) - Factor, DAG.getUNDEF(VecVT));
+ SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT, Ops);
+
+ if (Factor == 2) {
+ // We can deinterleave through vnsrl.wi if the element type is smaller than
+ // ELEN
+ if (VecVT.getScalarSizeInBits() < Subtarget.getELen()) {
+ SDValue Even = getDeinterleaveShiftAndTrunc(DL, VecVT, Concat, 2, 0, DAG);
+ SDValue Odd = getDeinterleaveShiftAndTrunc(DL, VecVT, Concat, 2, 1, DAG);
+ return DAG.getMergeValues({Even, Odd}, DL);
+ }
+
+ // For the indices, use the vmv.v.x of an i8 constant to fill the largest
+ // possibly mask vector, then extract the required subvector. Doing this
+ // (instead of a vid, vmsne sequence) reduces LMUL, and allows the mask
+ // creation to be rematerialized during register allocation to reduce
+ // register pressure if needed.
+
+ MVT MaskVT = ConcatVT.changeVectorElementType(MVT::i1);
+
+ SDValue EvenSplat = DAG.getConstant(0b01010101, DL, MVT::nxv8i8);
+ EvenSplat = DAG.getBitcast(MVT::nxv64i1, EvenSplat);
+ SDValue EvenMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MaskVT,
+ EvenSplat, DAG.getVectorIdxConstant(0, DL));
+
+ SDValue OddSplat = DAG.getConstant(0b10101010, DL, MVT::nxv8i8);
+ OddSplat = DAG.getBitcast(MVT::nxv64i1, OddSplat);
+ SDValue OddMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MaskVT, OddSplat,
+ DAG.getVectorIdxConstant(0, DL));
+
+ // vcompress the even and odd elements into two separate vectors
+ SDValue EvenWide = DAG.getNode(ISD::VECTOR_COMPRESS, DL, ConcatVT, Concat,
+ EvenMask, DAG.getUNDEF(ConcatVT));
+ SDValue OddWide = DAG.getNode(ISD::VECTOR_COMPRESS, DL, ConcatVT, Concat,
+ OddMask, DAG.getUNDEF(ConcatVT));
+
+ // Extract the result half of the gather for even and odd
+ SDValue Even = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, EvenWide,
+ DAG.getVectorIdxConstant(0, DL));
+ SDValue Odd = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, OddWide,
+ DAG.getVectorIdxConstant(0, DL));
- // We can deinterleave through vnsrl.wi if the element type is smaller than
- // ELEN
- if (VecVT.getScalarSizeInBits() < Subtarget.getELen()) {
- SDValue Even = getDeinterleaveShiftAndTrunc(DL, VecVT, Concat, 2, 0, DAG);
- SDValue Odd = getDeinterleaveShiftAndTrunc(DL, VecVT, Concat, 2, 1, DAG);
return DAG.getMergeValues({Even, Odd}, DL);
}
- // For the indices, use the vmv.v.x of an i8 constant to fill the largest
- // possibly mask vector, then extract the required subvector. Doing this
- // (instead of a vid, vmsne sequence) reduces LMUL, and allows the mask
- // creation to be rematerialized during register allocation to reduce
- // register pressure if needed.
+ // Store with unit-stride store and load it back with segmented load.
+ MVT XLenVT = Subtarget.getXLenVT();
+ SDValue VL = getDefaultScalableVLOps(ConcatVT, DL, DAG, Subtarget).second;
+ SDValue Passthru = DAG.getUNDEF(ConcatVT);
- MVT MaskVT = ConcatVT.changeVectorElementType(MVT::i1);
+ // Allocate a stack slot.
+ Align Alignment = DAG.getReducedAlign(VecVT, /*UseABI=*/false);
+ SDValue StackPtr =
+ DAG.CreateStackTemporary(ConcatVT.getStoreSize(), Alignment);
+ auto &MF = DAG.getMachineFunction();
+ auto FrameIndex = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
+ auto PtrInfo = MachinePointerInfo::getFixedStack(MF, FrameIndex);
- SDValue EvenSplat = DAG.getConstant(0b01010101, DL, MVT::nxv8i8);
- EvenSplat = DAG.getBitcast(MVT::nxv64i1, EvenSplat);
- SDValue EvenMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MaskVT, EvenSplat,
- DAG.getVectorIdxConstant(0, DL));
+ SDValue StoreOps[] = {DAG.getEntryNode(),
+ DAG.getTargetConstant(Intrinsic::riscv_vse, DL, XLenVT),
+ Concat, StackPtr, VL};
- SDValue OddSplat = DAG.getConstant(0b10101010, DL, MVT::nxv8i8);
- OddSplat = DAG.getBitcast(MVT::nxv64i1, OddSplat);
- SDValue OddMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MaskVT, OddSplat,
- DAG.getVectorIdxConstant(0, DL));
+ SDValue Chain = DAG.getMemIntrinsicNode(
+ ISD::INTRINSIC_VOID, DL, DAG.getVTList(MVT::Other), StoreOps,
+ ConcatVT.getVectorElementType(), PtrInfo, Alignment,
+ MachineMemOperand::MOStore, MemoryLocation::UnknownSize);
- // vcompress the even and odd elements into two separate vectors
- SDValue EvenWide = DAG.getNode(ISD::VECTOR_COMPRESS, DL, ConcatVT, Concat,
- EvenMask, DAG.getUNDEF(ConcatVT));
- SDValue OddWide = DAG.getNode(ISD::VECTOR_COMPRESS, DL, ConcatVT, Concat,
- OddMask, DAG.getUNDEF(ConcatVT));
+ static const Intrinsic::ID VlsegIntrinsicsIds[] = {
+ Intrinsic::riscv_vlseg2, Intrinsic::riscv_vlseg3, Intrinsic::riscv_vlseg4,
+ Intrinsic::riscv_vlseg5, Intrinsic::riscv_vlseg6, Intrinsic::riscv_vlseg7,
+ Intrinsic::riscv_vlseg8};
- // Extract the result half of the gather for even and odd
- SDValue Even = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, EvenWide,
- DAG.getVectorIdxConstant(0, DL));
- SDValue Odd = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, OddWide,
- DAG.getVectorIdxConstant(0, DL));
+ SDValue LoadOps[] = {
----------------
mshockwave wrote:
Done
https://github.com/llvm/llvm-project/pull/124825
More information about the llvm-commits
mailing list