[llvm] [IR][RISCV] Add llvm.vector.(de)interleave3/5/7 (PR #124825)
Min-Yih Hsu via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 30 10:19:40 PST 2025
================
@@ -10975,75 +10975,116 @@ SDValue RISCVTargetLowering::lowerVECTOR_DEINTERLEAVE(SDValue Op,
assert(VecVT.isScalableVector() &&
"vector_interleave on non-scalable vector!");
+ const unsigned Factor = Op->getNumValues();
+
// 1 bit element vectors need to be widened to e8
if (VecVT.getVectorElementType() == MVT::i1)
return widenVectorOpsToi8(Op, DL, DAG);
- // If the VT is LMUL=8, we need to split and reassemble.
- if (VecVT.getSizeInBits().getKnownMinValue() ==
+ // If concatenating would exceed LMUL=8, we need to split.
+ if ((VecVT.getSizeInBits().getKnownMinValue() * Factor) >
(8 * RISCV::RVVBitsPerBlock)) {
- auto [Op0Lo, Op0Hi] = DAG.SplitVectorOperand(Op.getNode(), 0);
- auto [Op1Lo, Op1Hi] = DAG.SplitVectorOperand(Op.getNode(), 1);
- EVT SplitVT = Op0Lo.getValueType();
-
- SDValue ResLo = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL,
- DAG.getVTList(SplitVT, SplitVT), Op0Lo, Op0Hi);
- SDValue ResHi = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL,
- DAG.getVTList(SplitVT, SplitVT), Op1Lo, Op1Hi);
-
- SDValue Even = DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT,
- ResLo.getValue(0), ResHi.getValue(0));
- SDValue Odd = DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT, ResLo.getValue(1),
- ResHi.getValue(1));
- return DAG.getMergeValues({Even, Odd}, DL);
- }
+ SmallVector<SDValue, 8> Ops(Factor * 2);
+ for (unsigned i = 0; i != Factor; ++i) {
+ auto [OpLo, OpHi] = DAG.SplitVectorOperand(Op.getNode(), i);
+ Ops[i * 2] = OpLo;
+ Ops[i * 2 + 1] = OpHi;
+ }
- // Concatenate the two vectors as one vector to deinterleave
- MVT ConcatVT =
- MVT::getVectorVT(VecVT.getVectorElementType(),
- VecVT.getVectorElementCount().multiplyCoefficientBy(2));
- SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT,
- Op.getOperand(0), Op.getOperand(1));
+ SmallVector<EVT, 8> VTs(Factor, Ops[0].getValueType());
- // We can deinterleave through vnsrl.wi if the element type is smaller than
- // ELEN
- if (VecVT.getScalarSizeInBits() < Subtarget.getELen()) {
- SDValue Even = getDeinterleaveShiftAndTrunc(DL, VecVT, Concat, 2, 0, DAG);
- SDValue Odd = getDeinterleaveShiftAndTrunc(DL, VecVT, Concat, 2, 1, DAG);
- return DAG.getMergeValues({Even, Odd}, DL);
- }
+ SDValue Lo = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL, VTs,
+ ArrayRef(Ops).slice(0, Factor));
+ SDValue Hi = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL, VTs,
+ ArrayRef(Ops).slice(Factor, Factor));
- // For the indices, use the vmv.v.x of an i8 constant to fill the largest
- // possibly mask vector, then extract the required subvector. Doing this
- // (instead of a vid, vmsne sequence) reduces LMUL, and allows the mask
- // creation to be rematerialized during register allocation to reduce
- // register pressure if needed.
+ SmallVector<SDValue, 8> Res(Factor);
+ for (unsigned i = 0; i != Factor; ++i)
+ Res[i] = DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT, Lo.getValue(i),
+ Hi.getValue(i));
- MVT MaskVT = ConcatVT.changeVectorElementType(MVT::i1);
+ return DAG.getMergeValues(Res, DL);
+ }
- SDValue EvenSplat = DAG.getConstant(0b01010101, DL, MVT::nxv8i8);
- EvenSplat = DAG.getBitcast(MVT::nxv64i1, EvenSplat);
- SDValue EvenMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MaskVT, EvenSplat,
- DAG.getVectorIdxConstant(0, DL));
+ SmallVector<SDValue, 8> Ops(Op->op_values());
- SDValue OddSplat = DAG.getConstant(0b10101010, DL, MVT::nxv8i8);
- OddSplat = DAG.getBitcast(MVT::nxv64i1, OddSplat);
- SDValue OddMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MaskVT, OddSplat,
- DAG.getVectorIdxConstant(0, DL));
+ // Concatenate the vectors as one vector to deinterleave
+ MVT ConcatVT =
+ MVT::getVectorVT(VecVT.getVectorElementType(),
+ VecVT.getVectorElementCount().multiplyCoefficientBy(
+ PowerOf2Ceil(Factor)));
+ if (Ops.size() < PowerOf2Ceil(Factor))
+ Ops.append(PowerOf2Ceil(Factor) - Factor, DAG.getUNDEF(VecVT));
+ SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT, Ops);
+
+ if (Factor == 2) {
+ // We can deinterleave through vnsrl.wi if the element type is smaller than
+ // ELEN
+ if (VecVT.getScalarSizeInBits() < Subtarget.getELen()) {
+ SDValue Even = getDeinterleaveShiftAndTrunc(DL, VecVT, Concat, 2, 0, DAG);
+ SDValue Odd = getDeinterleaveShiftAndTrunc(DL, VecVT, Concat, 2, 1, DAG);
+ return DAG.getMergeValues({Even, Odd}, DL);
+ }
+
+ // For the indices, use the vmv.v.x of an i8 constant to fill the largest
+ // possibly mask vector, then extract the required subvector. Doing this
+ // (instead of a vid, vmsne sequence) reduces LMUL, and allows the mask
+ // creation to be rematerialized during register allocation to reduce
+ // register pressure if needed.
+
+ MVT MaskVT = ConcatVT.changeVectorElementType(MVT::i1);
+
+ SDValue EvenSplat = DAG.getConstant(0b01010101, DL, MVT::nxv8i8);
+ EvenSplat = DAG.getBitcast(MVT::nxv64i1, EvenSplat);
+ SDValue EvenMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MaskVT,
+ EvenSplat, DAG.getVectorIdxConstant(0, DL));
+
+ SDValue OddSplat = DAG.getConstant(0b10101010, DL, MVT::nxv8i8);
+ OddSplat = DAG.getBitcast(MVT::nxv64i1, OddSplat);
+ SDValue OddMask = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MaskVT, OddSplat,
+ DAG.getVectorIdxConstant(0, DL));
+
+ // vcompress the even and odd elements into two separate vectors
+ SDValue EvenWide = DAG.getNode(ISD::VECTOR_COMPRESS, DL, ConcatVT, Concat,
+ EvenMask, DAG.getUNDEF(ConcatVT));
+ SDValue OddWide = DAG.getNode(ISD::VECTOR_COMPRESS, DL, ConcatVT, Concat,
+ OddMask, DAG.getUNDEF(ConcatVT));
+
+ // Extract the result half of the gather for even and odd
+ SDValue Even = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, EvenWide,
+ DAG.getVectorIdxConstant(0, DL));
+ SDValue Odd = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, OddWide,
+ DAG.getVectorIdxConstant(0, DL));
- // vcompress the even and odd elements into two separate vectors
- SDValue EvenWide = DAG.getNode(ISD::VECTOR_COMPRESS, DL, ConcatVT, Concat,
- EvenMask, DAG.getUNDEF(ConcatVT));
- SDValue OddWide = DAG.getNode(ISD::VECTOR_COMPRESS, DL, ConcatVT, Concat,
- OddMask, DAG.getUNDEF(ConcatVT));
+ return DAG.getMergeValues({Even, Odd}, DL);
+ }
- // Extract the result half of the gather for even and odd
- SDValue Even = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, EvenWide,
- DAG.getVectorIdxConstant(0, DL));
- SDValue Odd = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, OddWide,
- DAG.getVectorIdxConstant(0, DL));
+ // We want to operate on all lanes, so get the mask and VL and mask for it
+ auto [Mask, VL] = getDefaultScalableVLOps(ConcatVT, DL, DAG, Subtarget);
+ SDValue Passthru = DAG.getUNDEF(ConcatVT);
+
+ // For the indices, use the same SEW to avoid an extra vsetvli
----------------
mshockwave wrote:
I've changed the logic here to use a unit-stride store + segmented load instead.
https://github.com/llvm/llvm-project/pull/124825
More information about the llvm-commits
mailing list