[llvm] a916e81 - [X86] Various improvements to our vector splitting helpers for lowering. NFC

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 15 11:00:29 PDT 2020


Author: Craig Topper
Date: 2020-04-15T10:57:53-07:00
New Revision: a916e819275922ab9a350283a12647da6f4ad4b1

URL: https://github.com/llvm/llvm-project/commit/a916e819275922ab9a350283a12647da6f4ad4b1
DIFF: https://github.com/llvm/llvm-project/commit/a916e819275922ab9a350283a12647da6f4ad4b1.diff

LOG: [X86] Various improvements to our vector splitting helpers for lowering. NFC

-Consistently name the functions split*.
-Add a helper for doing the two extractSubVector calls and determining the size of the split.
-Use GetSplitDestVTs to get the result type for the split node.
-Move the binary and unary helpers to one place in the file near the extractSubVector functions. Leave the VSETCC one near LowerVSETCC since that's its only caller.
-Remove the 256/512 wrappers that just had asserts. I don't think they provided a lot of value, and now that the routines are named split* it is more obvious at the call sites what they do.
-Make the unary routine support different source and dest types, as needed by D76212.
-Add some weaker asserts to the helpers to make up for losing the very specific asserts from the 256/512 wrappers.

Differential Revision: https://reviews.llvm.org/D78176

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 47ef38ab3f4c..81bb829cd3ce 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -5802,6 +5802,76 @@ static bool collectConcatOps(SDNode *N, SmallVectorImpl<SDValue> &Ops) {
   return false;
 }
 
+/// Split Op into two half-width subvectors, returned as {Lo, Hi}. Lo begins
+/// at element 0 and Hi begins at element NumElems / 2.
+static std::pair<SDValue, SDValue> splitVector(SDValue Op, SelectionDAG &DAG,
+                                               const SDLoc &dl) {
+  MVT VT = Op.getSimpleValueType();
+  unsigned NumElems = VT.getVectorNumElements();
+  unsigned SizeInBits = VT.getSizeInBits();
+
+  SDValue Lo = extractSubVector(Op, 0, DAG, dl, SizeInBits / 2);
+  SDValue Hi = extractSubVector(Op, NumElems / 2, DAG, dl, SizeInBits / 2);
+
+  return std::make_pair(Lo, Hi);
+}
+
+/// Split a unary integer op into 2 half sized ops and then concatenate the
+/// result back. The source and result types may differ, but they must have
+/// the same number of elements.
+static SDValue splitVectorIntUnary(SDValue Op, SelectionDAG &DAG) {
+  EVT VT = Op.getValueType();
+
+  // Make sure we only try to split 256/512-bit types to avoid creating
+  // narrow vectors.
+  assert((Op.getOperand(0).getValueType().is256BitVector() ||
+          Op.getOperand(0).getValueType().is512BitVector()) &&
+         (VT.is256BitVector() || VT.is512BitVector()) && "Unsupported VT!");
+  assert(Op.getOperand(0).getValueType().getVectorNumElements() ==
+             VT.getVectorNumElements() &&
+         "Unexpected VTs!");
+
+  SDLoc dl(Op);
+
+  // Extract the Lo/Hi halves of the source vector
+  SDValue Lo, Hi;
+  std::tie(Lo, Hi) = splitVector(Op.getOperand(0), DAG, dl);
+
+  // The result type may differ from the source type, so split it separately.
+  EVT LoVT, HiVT;
+  std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
+  return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
+                     DAG.getNode(Op.getOpcode(), dl, LoVT, Lo),
+                     DAG.getNode(Op.getOpcode(), dl, HiVT, Hi));
+}
+
+/// Break a binary integer operation into 2 half sized ops and then
+/// concatenate the result back.
+static SDValue splitVectorIntBinary(SDValue Op, SelectionDAG &DAG) {
+  EVT VT = Op.getValueType();
+
+  // Sanity check that all the types match.
+  assert(Op.getOperand(0).getValueType() == VT &&
+         Op.getOperand(1).getValueType() == VT && "Unexpected VTs!");
+  assert((VT.is256BitVector() || VT.is512BitVector()) && "Unsupported VT!");
+
+  SDLoc dl(Op);
+
+  // Extract the LHS Lo/Hi vectors
+  SDValue LHS1, LHS2;
+  std::tie(LHS1, LHS2) = splitVector(Op.getOperand(0), DAG, dl);
+
+  // Extract the RHS Lo/Hi vectors
+  SDValue RHS1, RHS2;
+  std::tie(RHS1, RHS2) = splitVector(Op.getOperand(1), DAG, dl);
+
+  EVT LoVT, HiVT;
+  std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
+  return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
+                     DAG.getNode(Op.getOpcode(), dl, LoVT, LHS1, RHS1),
+                     DAG.getNode(Op.getOpcode(), dl, HiVT, LHS2, RHS2));
+}
+
 // Helper for splitting operands of an operation to legal target size and
 // apply a function on each part.
 // Useful for operations that are available on SSE2 in 128-bit, on AVX2 in
@@ -21820,32 +21885,30 @@ static unsigned translateX86FSETCC(ISD::CondCode SetCCOpcode, SDValue &Op0,
 
-/// Break a VSETCC 256-bit integer VSETCC into two new 128 ones and then
-/// concatenate the result back.
-static SDValue Lower256IntVSETCC(SDValue Op, SelectionDAG &DAG) {
-  MVT VT = Op.getSimpleValueType();
+/// Break an integer VSETCC into two half sized VSETCCs and then concatenate
+/// the result back.
+static SDValue splitIntVSETCC(SDValue Op, SelectionDAG &DAG) {
+  EVT VT = Op.getValueType();
 
-  assert(VT.is256BitVector() && Op.getOpcode() == ISD::SETCC &&
-         "Unsupported value type for operation");
+  assert(Op.getOpcode() == ISD::SETCC && "Unsupported operation");
+  assert(Op.getOperand(0).getValueType().isInteger() &&
+         VT == Op.getOperand(0).getValueType() && "Unsupported VTs!");
 
-  unsigned NumElems = VT.getVectorNumElements();
   SDLoc dl(Op);
   SDValue CC = Op.getOperand(2);
 
-  // Extract the LHS vectors
-  SDValue LHS = Op.getOperand(0);
-  SDValue LHS1 = extract128BitVector(LHS, 0, DAG, dl);
-  SDValue LHS2 = extract128BitVector(LHS, NumElems / 2, DAG, dl);
+  // Extract the LHS Lo/Hi vectors
+  SDValue LHS1, LHS2;
+  std::tie(LHS1, LHS2) = splitVector(Op.getOperand(0), DAG, dl);
 
-  // Extract the RHS vectors
-  SDValue RHS = Op.getOperand(1);
-  SDValue RHS1 = extract128BitVector(RHS, 0, DAG, dl);
-  SDValue RHS2 = extract128BitVector(RHS, NumElems / 2, DAG, dl);
+  // Extract the RHS Lo/Hi vectors
+  SDValue RHS1, RHS2;
+  std::tie(RHS1, RHS2) = splitVector(Op.getOperand(1), DAG, dl);
 
   // Issue the operation on the smaller types and concatenate the result back
-  MVT EltVT = VT.getVectorElementType();
-  MVT NewVT = MVT::getVectorVT(EltVT, NumElems/2);
+  EVT LoVT, HiVT;
+  std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
   return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
-                     DAG.getNode(Op.getOpcode(), dl, NewVT, LHS1, RHS1, CC),
-                     DAG.getNode(Op.getOpcode(), dl, NewVT, LHS2, RHS2, CC));
+                     DAG.getNode(ISD::SETCC, dl, LoVT, LHS1, RHS1, CC),
+                     DAG.getNode(ISD::SETCC, dl, HiVT, LHS2, RHS2, CC));
 }
 
 static SDValue LowerIntVSETCC_AVX512(SDValue Op, SelectionDAG &DAG) {
@@ -22187,7 +22250,7 @@ static SDValue LowerVSETCC(SDValue Op, const X86Subtarget &Subtarget,
 
   // Break 256-bit integer vector compare into smaller ones.
   if (VT.is256BitVector() && !Subtarget.hasInt256())
-    return Lower256IntVSETCC(Op, DAG);
+    return splitIntVSETCC(Op, DAG);
 
   // If this is a SETNE against the signed minimum value, change it to SETGT.
   // If this is a SETNE against the signed maximum value, change it to SETLT.
@@ -25922,43 +25985,6 @@ SDValue X86TargetLowering::LowerFLT_ROUNDS_(SDValue Op,
   return DAG.getMergeValues({RetVal, Chain}, DL);
 }
 
-// Split an unary integer op into 2 half sized ops.
-static SDValue LowerVectorIntUnary(SDValue Op, SelectionDAG &DAG) {
-  MVT VT = Op.getSimpleValueType();
-  unsigned NumElems = VT.getVectorNumElements();
-  unsigned SizeInBits = VT.getSizeInBits();
-  MVT EltVT = VT.getVectorElementType();
-  SDValue Src = Op.getOperand(0);
-  assert(EltVT == Src.getSimpleValueType().getVectorElementType() &&
-         "Src and Op should have the same element type!");
-
-  // Extract the Lo/Hi vectors
-  SDLoc dl(Op);
-  SDValue Lo = extractSubVector(Src, 0, DAG, dl, SizeInBits / 2);
-  SDValue Hi = extractSubVector(Src, NumElems / 2, DAG, dl, SizeInBits / 2);
-
-  MVT NewVT = MVT::getVectorVT(EltVT, NumElems / 2);
-  return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
-                     DAG.getNode(Op.getOpcode(), dl, NewVT, Lo),
-                     DAG.getNode(Op.getOpcode(), dl, NewVT, Hi));
-}
-
-// Decompose 256-bit ops into smaller 128-bit ops.
-static SDValue Lower256IntUnary(SDValue Op, SelectionDAG &DAG) {
-  assert(Op.getSimpleValueType().is256BitVector() &&
-         Op.getSimpleValueType().isInteger() &&
-         "Only handle AVX 256-bit vector integer operation");
-  return LowerVectorIntUnary(Op, DAG);
-}
-
-// Decompose 512-bit ops into smaller 256-bit ops.
-static SDValue Lower512IntUnary(SDValue Op, SelectionDAG &DAG) {
-  assert(Op.getSimpleValueType().is512BitVector() &&
-         Op.getSimpleValueType().isInteger() &&
-         "Only handle AVX 512-bit vector integer operation");
-  return LowerVectorIntUnary(Op, DAG);
-}
-
 /// Lower a vector CTLZ using native supported vector CTLZ instruction.
 //
 // i8/i16 vector implemented using dword LZCNT vector instruction
@@ -25979,7 +26005,7 @@ static SDValue LowerVectorCTLZ_AVX512CDI(SDValue Op, SelectionDAG &DAG,
   // Split vector, it's Lo and Hi parts will be handled in next iteration.
   if (NumElems > 16 ||
       (NumElems == 16 && !Subtarget.canExtendTo512DQ()))
-    return LowerVectorIntUnary(Op, DAG);
+    return splitVectorIntUnary(Op, DAG);
 
   MVT NewVT = MVT::getVectorVT(MVT::i32, NumElems);
   assert((NewVT.is256BitVector() || NewVT.is512BitVector()) &&
@@ -26089,11 +26115,11 @@ static SDValue LowerVectorCTLZ(SDValue Op, const SDLoc &DL,
 
   // Decompose 256-bit ops into smaller 128-bit ops.
   if (VT.is256BitVector() && !Subtarget.hasInt256())
-    return Lower256IntUnary(Op, DAG);
+    return splitVectorIntUnary(Op, DAG);
 
   // Decompose 512-bit ops into smaller 256-bit ops.
   if (VT.is512BitVector() && !Subtarget.hasBWI())
-    return Lower512IntUnary(Op, DAG);
+    return splitVectorIntUnary(Op, DAG);
 
   assert(Subtarget.hasSSSE3() && "Expected SSSE3 support for PSHUFB");
   return LowerVectorCTLZInRegLUT(Op, DL, Subtarget, DAG);
@@ -26159,48 +26185,6 @@ static SDValue LowerCTTZ(SDValue Op, const X86Subtarget &Subtarget,
   return DAG.getNode(X86ISD::CMOV, dl, VT, Ops);
 }
 
-/// Break a binary integer operation into 2 half sized ops and then
-/// concatenate the result back.
-static SDValue splitVectorIntBinary(SDValue Op, SelectionDAG &DAG) {
-  MVT VT = Op.getSimpleValueType();
-  unsigned NumElems = VT.getVectorNumElements();
-  unsigned SizeInBits = VT.getSizeInBits();
-  SDLoc dl(Op);
-
-  // Extract the LHS Lo/Hi vectors
-  SDValue LHS = Op.getOperand(0);
-  SDValue LHS1 = extractSubVector(LHS, 0, DAG, dl, SizeInBits / 2);
-  SDValue LHS2 = extractSubVector(LHS, NumElems / 2, DAG, dl, SizeInBits / 2);
-
-  // Extract the RHS Lo/Hi vectors
-  SDValue RHS = Op.getOperand(1);
-  SDValue RHS1 = extractSubVector(RHS, 0, DAG, dl, SizeInBits / 2);
-  SDValue RHS2 = extractSubVector(RHS, NumElems / 2, DAG, dl, SizeInBits / 2);
-
-  MVT NewVT = MVT::getVectorVT(VT.getVectorElementType(), NumElems / 2);
-  return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
-                     DAG.getNode(Op.getOpcode(), dl, NewVT, LHS1, RHS1),
-                     DAG.getNode(Op.getOpcode(), dl, NewVT, LHS2, RHS2));
-}
-
-/// Break a 256-bit integer operation into two new 128-bit ones and then
-/// concatenate the result back.
-static SDValue split256IntArith(SDValue Op, SelectionDAG &DAG) {
-  assert(Op.getSimpleValueType().is256BitVector() &&
-         Op.getSimpleValueType().isInteger() &&
-         "Unsupported value type for operation");
-  return splitVectorIntBinary(Op, DAG);
-}
-
-/// Break a 512-bit integer operation into two new 256-bit ones and then
-/// concatenate the result back.
-static SDValue split512IntArith(SDValue Op, SelectionDAG &DAG) {
-  assert(Op.getSimpleValueType().is512BitVector() &&
-         Op.getSimpleValueType().isInteger() &&
-         "Unsupported value type for operation");
-  return splitVectorIntBinary(Op, DAG);
-}
-
 static SDValue lowerAddSub(SDValue Op, SelectionDAG &DAG,
                            const X86Subtarget &Subtarget) {
   MVT VT = Op.getSimpleValueType();
@@ -26214,7 +26198,7 @@ static SDValue lowerAddSub(SDValue Op, SelectionDAG &DAG,
   assert(Op.getSimpleValueType().is256BitVector() &&
          Op.getSimpleValueType().isInteger() &&
          "Only handle AVX 256-bit vector integer operation");
-  return split256IntArith(Op, DAG);
+  return splitVectorIntBinary(Op, DAG);
 }
 
 static SDValue LowerADDSAT_SUBSAT(SDValue Op, SelectionDAG &DAG,
@@ -26262,7 +26246,7 @@ static SDValue LowerADDSAT_SUBSAT(SDValue Op, SelectionDAG &DAG,
   assert(Op.getSimpleValueType().is256BitVector() &&
          Op.getSimpleValueType().isInteger() &&
          "Only handle AVX 256-bit vector integer operation");
-  return split256IntArith(Op, DAG);
+  return splitVectorIntBinary(Op, DAG);
 }
 
 static SDValue LowerABS(SDValue Op, const X86Subtarget &Subtarget,
@@ -26292,7 +26276,7 @@ static SDValue LowerABS(SDValue Op, const X86Subtarget &Subtarget,
   if (VT.is256BitVector() && !Subtarget.hasInt256()) {
     assert(VT.isInteger() &&
            "Only handle AVX 256-bit vector integer operation");
-    return Lower256IntUnary(Op, DAG);
+    return splitVectorIntUnary(Op, DAG);
   }
 
   // Default to expand.
@@ -26304,7 +26288,7 @@ static SDValue LowerMINMAX(SDValue Op, SelectionDAG &DAG) {
 
   // For AVX1 cases, split to use legal ops (everything but v4i64).
   if (VT.getScalarType() != MVT::i64 && VT.is256BitVector())
-    return split256IntArith(Op, DAG);
+    return splitVectorIntBinary(Op, DAG);
 
   SDLoc DL(Op);
   unsigned Opcode = Op.getOpcode();
@@ -26348,7 +26332,7 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget,
 
   // Decompose 256-bit ops into 128-bit ops.
   if (VT.is256BitVector() && !Subtarget.hasInt256())
-    return split256IntArith(Op, DAG);
+    return splitVectorIntBinary(Op, DAG);
 
   SDValue A = Op.getOperand(0);
   SDValue B = Op.getOperand(1);
@@ -26494,7 +26478,7 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget,
 
   // Decompose 256-bit ops into 128-bit ops.
   if (VT.is256BitVector() && !Subtarget.hasInt256())
-    return split256IntArith(Op, DAG);
+    return splitVectorIntBinary(Op, DAG);
 
   if (VT == MVT::v4i32 || VT == MVT::v8i32 || VT == MVT::v16i32) {
     assert((VT == MVT::v4i32 && Subtarget.hasSSE2()) ||
@@ -26586,7 +26570,7 @@ static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget,
   // For signed 512-bit vectors, split into 256-bit vectors to allow the
   // sign-extension to occur.
   if (VT == MVT::v64i8 && IsSigned)
-    return split512IntArith(Op, DAG);
+    return splitVectorIntBinary(Op, DAG);
 
   // Signed AVX2 implementation - extend xmm subvectors to ymm.
   if (VT == MVT::v32i8 && IsSigned) {
@@ -27560,7 +27544,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
 
   // Decompose 256-bit shifts into 128-bit shifts.
   if (VT.is256BitVector())
-    return split256IntArith(Op, DAG);
+    return splitVectorIntBinary(Op, DAG);
 
   return SDValue();
 }
@@ -27606,7 +27590,7 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
   // XOP implicitly uses modulo rotation amounts.
   if (Subtarget.hasXOP()) {
     if (VT.is256BitVector())
-      return split256IntArith(Op, DAG);
+      return splitVectorIntBinary(Op, DAG);
     assert(VT.is128BitVector() && "Only rotate 128-bit vectors!");
 
     // Attempt to rotate by immediate.
@@ -27622,7 +27606,7 @@ static SDValue LowerRotate(SDValue Op, const X86Subtarget &Subtarget,
 
   // Split 256-bit integers on pre-AVX2 targets.
   if (VT.is256BitVector() && !Subtarget.hasAVX2())
-    return split256IntArith(Op, DAG);
+    return splitVectorIntBinary(Op, DAG);
 
   assert((VT == MVT::v4i32 || VT == MVT::v8i16 || VT == MVT::v16i8 ||
           ((VT == MVT::v8i32 || VT == MVT::v16i16 || VT == MVT::v32i8) &&
@@ -28287,11 +28271,11 @@ static SDValue LowerVectorCTPOP(SDValue Op, const X86Subtarget &Subtarget,
 
   // Decompose 256-bit ops into smaller 128-bit ops.
   if (VT.is256BitVector() && !Subtarget.hasInt256())
-    return Lower256IntUnary(Op, DAG);
+    return splitVectorIntUnary(Op, DAG);
 
   // Decompose 512-bit ops into smaller 256-bit ops.
   if (VT.is512BitVector() && !Subtarget.hasBWI())
-    return Lower512IntUnary(Op, DAG);
+    return splitVectorIntUnary(Op, DAG);
 
   // For element types greater than i8, do vXi8 pop counts and a bytesum.
   if (VT.getScalarType() != MVT::i8) {
@@ -28335,7 +28319,7 @@ static SDValue LowerBITREVERSE_XOP(SDValue Op, SelectionDAG &DAG) {
 
   // Decompose 256-bit ops into smaller 128-bit ops.
   if (VT.is256BitVector())
-    return Lower256IntUnary(Op, DAG);
+    return splitVectorIntUnary(Op, DAG);
 
   assert(VT.is128BitVector() &&
          "Only 128-bit vector bitreverse lowering supported.");
@@ -28376,7 +28360,7 @@ static SDValue LowerBITREVERSE(SDValue Op, const X86Subtarget &Subtarget,
   // lowering.
   if (VT == MVT::v8i64 || VT == MVT::v16i32) {
     assert(!Subtarget.hasBWI() && "BWI should Expand BITREVERSE");
-    return Lower512IntUnary(Op, DAG);
+    return splitVectorIntUnary(Op, DAG);
   }
 
   unsigned NumElts = VT.getVectorNumElements();
@@ -28385,7 +28369,7 @@ static SDValue LowerBITREVERSE(SDValue Op, const X86Subtarget &Subtarget,
 
   // Decompose 256-bit ops into smaller 128-bit ops on pre-AVX2.
   if (VT.is256BitVector() && !Subtarget.hasInt256())
-    return Lower256IntUnary(Op, DAG);
+    return splitVectorIntUnary(Op, DAG);
 
   // Perform BITREVERSE using PSHUFB lookups. Each byte is split into
   // two nibbles and a PSHUFB lookup to find the bitreverse of each
@@ -47137,7 +47121,7 @@ static SDValue combineExtractSubvector(SDNode *N, SelectionDAG &DAG,
     if (isConcatenatedNot(InVecBC.getOperand(0)) ||
         isConcatenatedNot(InVecBC.getOperand(1))) {
       // extract (and v4i64 X, (not (concat Y1, Y2))), n -> andnp v2i64 X(n), Y1
-      SDValue Concat = split256IntArith(InVecBC, DAG);
+      SDValue Concat = splitVectorIntBinary(InVecBC, DAG);
       return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), VT,
                          DAG.getBitcast(InVecVT, Concat), N->getOperand(1));
     }


        


More information about the llvm-commits mailing list