[llvm] 81dc54e - [X86] Add widenMaskVector helper function to remove duplicated code for widening mask vectors for KSHIFT etc.
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 31 07:44:57 PDT 2023
Author: Simon Pilgrim
Date: 2023-08-31T15:44:22+01:00
New Revision: 81dc54e823a8746cdd35e2e0c07da476cf312dc0
URL: https://github.com/llvm/llvm-project/commit/81dc54e823a8746cdd35e2e0c07da476cf312dc0
DIFF: https://github.com/llvm/llvm-project/commit/81dc54e823a8746cdd35e2e0c07da476cf312dc0.diff
LOG: [X86] Add widenMaskVector helper function to remove duplicated code for widening mask vectors for KSHIFT etc.
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 0b7a4c1fe5b0d3..3828dcf5cbc56f 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -3811,6 +3811,25 @@ static SDValue widenSubVector(SDValue Vec, bool ZeroNewElements,
return widenSubVector(VT, Vec, ZeroNewElements, Subtarget, DAG, dl);
}
+/// Widen a vXi1 type with < 8 elements (or 8 without DQI) to v8i1 with DQI, or
+/// v16i1 without, enabling KSHIFT and integer bitcasts; else VT is unchanged.
+static MVT widenMaskVectorType(MVT VT, const X86Subtarget &Subtarget) {
+ assert(VT.getVectorElementType() == MVT::i1 && "Expected bool vector");
+ unsigned NumElts = VT.getVectorNumElements();
+ if ((!Subtarget.hasDQI() && NumElts == 8) || NumElts < 8)
+ return Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
+ return VT;
+}
+
+/// Widen mask vector Vec to widenMaskVectorType() (via widenSubVector) so it
+/// works with KSHIFT and integer bitcasts. ZeroNewElements is passed through.
+static SDValue widenMaskVector(SDValue Vec, bool ZeroNewElements,
+ const X86Subtarget &Subtarget, SelectionDAG &DAG,
+ const SDLoc &dl) {
+ MVT VT = widenMaskVectorType(Vec.getSimpleValueType(), Subtarget);
+ return widenSubVector(VT, Vec, ZeroNewElements, Subtarget, DAG, dl);
+}
+
// Helper function to collect subvector ops that are concatenated together,
// either by ISD::CONCAT_VECTORS or a ISD::INSERT_SUBVECTOR series.
// The subvectors in Ops are guaranteed to be the same type.
@@ -4100,9 +4119,7 @@ static SDValue insert1BitVector(SDValue Op, SelectionDAG &DAG,
SDValue ZeroIdx = DAG.getIntPtrConstant(0, dl);
// Extend to natively supported kshift.
- MVT WideOpVT = OpVT;
- if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8)
- WideOpVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
+ MVT WideOpVT = widenMaskVectorType(OpVT, Subtarget);
// Inserting into the lsbs of a zero vector is legal. ISel will insert shifts
// if necessary.
@@ -9008,16 +9025,12 @@ static SDValue LowerCONCAT_VECTORSvXi1(SDValue Op,
// insert_subvector will give us two kshifts.
if (isPowerOf2_64(NonZeros) && Zeros != 0 && NonZeros > Zeros &&
Log2_64(NonZeros) != NumOperands - 1) {
- MVT ShiftVT = ResVT;
- if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8)
- ShiftVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
unsigned Idx = Log2_64(NonZeros);
SDValue SubVec = Op.getOperand(Idx);
unsigned SubVecNumElts = SubVec.getSimpleValueType().getVectorNumElements();
- SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ShiftVT,
- DAG.getUNDEF(ShiftVT), SubVec,
- DAG.getIntPtrConstant(0, dl));
- Op = DAG.getNode(X86ISD::KSHIFTL, dl, ShiftVT, SubVec,
+ MVT ShiftVT = widenMaskVectorType(ResVT, Subtarget);
+ Op = widenSubVector(ShiftVT, SubVec, false, Subtarget, DAG, dl);
+ Op = DAG.getNode(X86ISD::KSHIFTL, dl, ShiftVT, Op,
DAG.getTargetConstant(Idx * SubVecNumElts, dl, MVT::i8));
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, ResVT, Op,
DAG.getIntPtrConstant(0, dl));
@@ -17004,13 +17017,8 @@ static SDValue lower1BitShuffleAsKSHIFTR(const SDLoc &DL, ArrayRef<int> Mask,
assert(ShiftAmt >= 0 && "All undef?");
// Great we found a shift right.
- MVT WideVT = VT;
- if ((!Subtarget.hasDQI() && NumElts == 8) || NumElts < 8)
- WideVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
- SDValue Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, WideVT,
- DAG.getUNDEF(WideVT), V1,
- DAG.getIntPtrConstant(0, DL));
- Res = DAG.getNode(X86ISD::KSHIFTR, DL, WideVT, Res,
+ SDValue Res = widenMaskVector(V1, false, Subtarget, DAG, DL);
+ Res = DAG.getNode(X86ISD::KSHIFTR, DL, Res.getValueType(), Res,
DAG.getTargetConstant(ShiftAmt, DL, MVT::i8));
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
DAG.getIntPtrConstant(0, DL));
@@ -17107,12 +17115,8 @@ static SDValue lower1BitShuffle(const SDLoc &DL, ArrayRef<int> Mask,
unsigned Opcode;
int ShiftAmt = match1BitShuffleAsKSHIFT(Opcode, Mask, Offset, Zeroable);
if (ShiftAmt >= 0) {
- MVT WideVT = VT;
- if ((!Subtarget.hasDQI() && NumElts == 8) || NumElts < 8)
- WideVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
- SDValue Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, WideVT,
- DAG.getUNDEF(WideVT), V,
- DAG.getIntPtrConstant(0, DL));
+ SDValue Res = widenMaskVector(V, false, Subtarget, DAG, DL);
+ MVT WideVT = Res.getSimpleValueType();
// Widened right shifts need two shifts to ensure we shift in zeroes.
if (Opcode == X86ISD::KSHIFTR && WideVT != VT) {
int WideElts = WideVT.getVectorNumElements();
@@ -17650,17 +17654,9 @@ static SDValue ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG,
// Extending v8i1/v16i1 to 512-bit get better performance on KNL
// than extending to 128/256bit.
if (NumElts == 1) {
- if (Subtarget.hasDQI()) {
- Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8i1,
- DAG.getUNDEF(MVT::v8i1), Vec,
- DAG.getIntPtrConstant(0, dl));
- return DAG.getBitcast(MVT::i8, Vec);
- }
- Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v16i1,
- DAG.getUNDEF(MVT::v16i1), Vec,
- DAG.getIntPtrConstant(0, dl));
- return DAG.getNode(ISD::TRUNCATE, dl, MVT::i8,
- DAG.getBitcast(MVT::i16, Vec));
+ Vec = widenMaskVector(Vec, false, Subtarget, DAG, dl);
+ MVT IntVT = MVT::getIntegerVT(Vec.getValueType().getVectorNumElements());
+ return DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, DAG.getBitcast(IntVT, Vec));
}
MVT ExtEltVT = (NumElts <= 8) ? MVT::getIntegerVT(128 / NumElts) : MVT::i8;
MVT ExtVecVT = MVT::getVectorVT(ExtEltVT, NumElts);
@@ -17674,17 +17670,10 @@ static SDValue ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG,
return Op;
// Extend to natively supported kshift.
- unsigned NumElems = VecVT.getVectorNumElements();
- MVT WideVecVT = VecVT;
- if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8) {
- WideVecVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
- Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideVecVT,
- DAG.getUNDEF(WideVecVT), Vec,
- DAG.getIntPtrConstant(0, dl));
- }
+ Vec = widenMaskVector(Vec, false, Subtarget, DAG, dl);
// Use kshiftr instruction to move to the lower element.
- Vec = DAG.getNode(X86ISD::KSHIFTR, dl, WideVecVT, Vec,
+ Vec = DAG.getNode(X86ISD::KSHIFTR, dl, Vec.getSimpleValueType(), Vec,
DAG.getTargetConstant(IdxVal, dl, MVT::i8));
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op.getValueType(), Vec,
@@ -18176,20 +18165,11 @@ static SDValue LowerEXTRACT_SUBVECTOR(SDValue Op, const X86Subtarget &Subtarget,
if (IdxVal == 0) // the operation is legal
return Op;
- MVT VecVT = Vec.getSimpleValueType();
- unsigned NumElems = VecVT.getVectorNumElements();
-
// Extend to natively supported kshift.
- MVT WideVecVT = VecVT;
- if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8) {
- WideVecVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
- Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideVecVT,
- DAG.getUNDEF(WideVecVT), Vec,
- DAG.getIntPtrConstant(0, dl));
- }
+ Vec = widenMaskVector(Vec, false, Subtarget, DAG, dl);
// Shift to the LSB.
- Vec = DAG.getNode(X86ISD::KSHIFTR, dl, WideVecVT, Vec,
+ Vec = DAG.getNode(X86ISD::KSHIFTR, dl, Vec.getSimpleValueType(), Vec,
DAG.getTargetConstant(IdxVal, dl, MVT::i8));
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op.getValueType(), Vec,
More information about the llvm-commits
mailing list