[llvm] 81dc54e - [X86] Add widenMaskVector helper function to remove duplicated code for widening mask vectors for KSHIFT etc.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 31 07:44:57 PDT 2023


Author: Simon Pilgrim
Date: 2023-08-31T15:44:22+01:00
New Revision: 81dc54e823a8746cdd35e2e0c07da476cf312dc0

URL: https://github.com/llvm/llvm-project/commit/81dc54e823a8746cdd35e2e0c07da476cf312dc0
DIFF: https://github.com/llvm/llvm-project/commit/81dc54e823a8746cdd35e2e0c07da476cf312dc0.diff

LOG: [X86] Add widenMaskVector helper function to remove duplicated code for widening mask vectors for KSHIFT etc.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 0b7a4c1fe5b0d3..3828dcf5cbc56f 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -3811,6 +3811,25 @@ static SDValue widenSubVector(SDValue Vec, bool ZeroNewElements,
   return widenSubVector(VT, Vec, ZeroNewElements, Subtarget, DAG, dl);
 }
 
+/// Pick the mask type for i1 vector \p VT so KSHIFT/integer bitcasts apply:
+/// under 8 elts (or exactly 8 without DQI) widens to v8i1 (DQI) else v16i1.
+static MVT widenMaskVectorType(MVT VT, const X86Subtarget &Subtarget) {
+  assert(VT.getVectorElementType() == MVT::i1 && "Expected bool vector");
+  unsigned NumElts = VT.getVectorNumElements();
+  if ((!Subtarget.hasDQI() && NumElts == 8) || NumElts < 8)
+    return Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
+  return VT; // Otherwise VT is returned unchanged.
+}
+
+/// Widen mask vector \p Vec to the type picked by widenMaskVectorType (for
+/// KSHIFT/integer bitcasts); other arguments are forwarded to widenSubVector.
+static SDValue widenMaskVector(SDValue Vec, bool ZeroNewElements,
+                               const X86Subtarget &Subtarget, SelectionDAG &DAG,
+                               const SDLoc &dl) {
+  MVT VT = widenMaskVectorType(Vec.getSimpleValueType(), Subtarget);
+  return widenSubVector(VT, Vec, ZeroNewElements, Subtarget, DAG, dl);
+}
+
 // Helper function to collect subvector ops that are concatenated together,
 // either by ISD::CONCAT_VECTORS or a ISD::INSERT_SUBVECTOR series.
 // The subvectors in Ops are guaranteed to be the same type.
@@ -4100,9 +4119,7 @@ static SDValue insert1BitVector(SDValue Op, SelectionDAG &DAG,
   SDValue ZeroIdx = DAG.getIntPtrConstant(0, dl);
 
   // Extend to natively supported kshift.
-  MVT WideOpVT = OpVT;
-  if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8)
-    WideOpVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
+  MVT WideOpVT = widenMaskVectorType(OpVT, Subtarget);
 
   // Inserting into the lsbs of a zero vector is legal. ISel will insert shifts
   // if necessary.
@@ -9008,16 +9025,12 @@ static SDValue LowerCONCAT_VECTORSvXi1(SDValue Op,
   // insert_subvector will give us two kshifts.
   if (isPowerOf2_64(NonZeros) && Zeros != 0 && NonZeros > Zeros &&
       Log2_64(NonZeros) != NumOperands - 1) {
-    MVT ShiftVT = ResVT;
-    if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8)
-      ShiftVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
     unsigned Idx = Log2_64(NonZeros);
     SDValue SubVec = Op.getOperand(Idx);
     unsigned SubVecNumElts = SubVec.getSimpleValueType().getVectorNumElements();
-    SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ShiftVT,
-                         DAG.getUNDEF(ShiftVT), SubVec,
-                         DAG.getIntPtrConstant(0, dl));
-    Op = DAG.getNode(X86ISD::KSHIFTL, dl, ShiftVT, SubVec,
+    MVT ShiftVT = widenMaskVectorType(ResVT, Subtarget);
+    Op = widenSubVector(ShiftVT, SubVec, false, Subtarget, DAG, dl);
+    Op = DAG.getNode(X86ISD::KSHIFTL, dl, ShiftVT, Op,
                      DAG.getTargetConstant(Idx * SubVecNumElts, dl, MVT::i8));
     return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, ResVT, Op,
                        DAG.getIntPtrConstant(0, dl));
@@ -17004,13 +17017,8 @@ static SDValue lower1BitShuffleAsKSHIFTR(const SDLoc &DL, ArrayRef<int> Mask,
   assert(ShiftAmt >= 0 && "All undef?");
 
   // Great we found a shift right.
-  MVT WideVT = VT;
-  if ((!Subtarget.hasDQI() && NumElts == 8) || NumElts < 8)
-    WideVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
-  SDValue Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, WideVT,
-                            DAG.getUNDEF(WideVT), V1,
-                            DAG.getIntPtrConstant(0, DL));
-  Res = DAG.getNode(X86ISD::KSHIFTR, DL, WideVT, Res,
+  SDValue Res = widenMaskVector(V1, false, Subtarget, DAG, DL);
+  Res = DAG.getNode(X86ISD::KSHIFTR, DL, Res.getValueType(), Res,
                     DAG.getTargetConstant(ShiftAmt, DL, MVT::i8));
   return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
                      DAG.getIntPtrConstant(0, DL));
@@ -17107,12 +17115,8 @@ static SDValue lower1BitShuffle(const SDLoc &DL, ArrayRef<int> Mask,
     unsigned Opcode;
     int ShiftAmt = match1BitShuffleAsKSHIFT(Opcode, Mask, Offset, Zeroable);
     if (ShiftAmt >= 0) {
-      MVT WideVT = VT;
-      if ((!Subtarget.hasDQI() && NumElts == 8) || NumElts < 8)
-        WideVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
-      SDValue Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, WideVT,
-                                DAG.getUNDEF(WideVT), V,
-                                DAG.getIntPtrConstant(0, DL));
+      SDValue Res = widenMaskVector(V, false, Subtarget, DAG, DL);
+      MVT WideVT = Res.getSimpleValueType();
       // Widened right shifts need two shifts to ensure we shift in zeroes.
       if (Opcode == X86ISD::KSHIFTR && WideVT != VT) {
         int WideElts = WideVT.getVectorNumElements();
@@ -17650,17 +17654,9 @@ static SDValue ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG,
     // Extending v8i1/v16i1 to 512-bit get better performance on KNL
     // than extending to 128/256bit.
     if (NumElts == 1) {
-      if (Subtarget.hasDQI()) {
-        Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8i1,
-                          DAG.getUNDEF(MVT::v8i1), Vec,
-                          DAG.getIntPtrConstant(0, dl));
-        return DAG.getBitcast(MVT::i8, Vec);
-      }
-      Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v16i1,
-                        DAG.getUNDEF(MVT::v16i1), Vec,
-                        DAG.getIntPtrConstant(0, dl));
-      return DAG.getNode(ISD::TRUNCATE, dl, MVT::i8,
-                         DAG.getBitcast(MVT::i16, Vec));
+      Vec = widenMaskVector(Vec, false, Subtarget, DAG, dl);
+      MVT IntVT = MVT::getIntegerVT(Vec.getValueType().getVectorNumElements());
+      return DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, DAG.getBitcast(IntVT, Vec));
     }
     MVT ExtEltVT = (NumElts <= 8) ? MVT::getIntegerVT(128 / NumElts) : MVT::i8;
     MVT ExtVecVT = MVT::getVectorVT(ExtEltVT, NumElts);
@@ -17674,17 +17670,10 @@ static SDValue ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG,
     return Op;
 
   // Extend to natively supported kshift.
-  unsigned NumElems = VecVT.getVectorNumElements();
-  MVT WideVecVT = VecVT;
-  if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8) {
-    WideVecVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
-    Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideVecVT,
-                      DAG.getUNDEF(WideVecVT), Vec,
-                      DAG.getIntPtrConstant(0, dl));
-  }
+  Vec = widenMaskVector(Vec, false, Subtarget, DAG, dl);
 
   // Use kshiftr instruction to move to the lower element.
-  Vec = DAG.getNode(X86ISD::KSHIFTR, dl, WideVecVT, Vec,
+  Vec = DAG.getNode(X86ISD::KSHIFTR, dl, Vec.getSimpleValueType(), Vec,
                     DAG.getTargetConstant(IdxVal, dl, MVT::i8));
 
   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op.getValueType(), Vec,
@@ -18176,20 +18165,11 @@ static SDValue LowerEXTRACT_SUBVECTOR(SDValue Op, const X86Subtarget &Subtarget,
   if (IdxVal == 0) // the operation is legal
     return Op;
 
-  MVT VecVT = Vec.getSimpleValueType();
-  unsigned NumElems = VecVT.getVectorNumElements();
-
   // Extend to natively supported kshift.
-  MVT WideVecVT = VecVT;
-  if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8) {
-    WideVecVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
-    Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideVecVT,
-                      DAG.getUNDEF(WideVecVT), Vec,
-                      DAG.getIntPtrConstant(0, dl));
-  }
+  Vec = widenMaskVector(Vec, false, Subtarget, DAG, dl);
 
   // Shift to the LSB.
-  Vec = DAG.getNode(X86ISD::KSHIFTR, dl, WideVecVT, Vec,
+  Vec = DAG.getNode(X86ISD::KSHIFTR, dl, Vec.getSimpleValueType(), Vec,
                     DAG.getTargetConstant(IdxVal, dl, MVT::i8));
 
   return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op.getValueType(), Vec,


        


More information about the llvm-commits mailing list