[llvm] 711dff4 - [X86] Add matchTruncateWithPACK helper for matching signbits/knownbits for PACKSS/PACKUS

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 7 06:12:01 PDT 2023


Author: Simon Pilgrim
Date: 2023-08-07T14:11:42+01:00
New Revision: 711dff45772e04bbaff284dfb16501cbc96bef1c

URL: https://github.com/llvm/llvm-project/commit/711dff45772e04bbaff284dfb16501cbc96bef1c
DIFF: https://github.com/llvm/llvm-project/commit/711dff45772e04bbaff284dfb16501cbc96bef1c.diff

LOG: [X86] Add matchTruncateWithPACK helper for matching signbits/knownbits for PACKSS/PACKUS

Begin to consolidate the similar PACKSS/PACKUS matching code we have. The existing matchers all have semi-similar constraints that still need merging together to ensure we get consistent codegen regardless of when the truncate is lowered.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/vector-trunc-packus.ll
    llvm/test/CodeGen/X86/vector-trunc-ssat.ll
    llvm/test/CodeGen/X86/vector-trunc-usat.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 9ca383e4742299..38230ae7231bfd 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -20105,6 +20105,89 @@ static SDValue truncateVectorWithPACKSS(EVT DstVT, SDValue In, const SDLoc &DL,
   return truncateVectorWithPACK(X86ISD::PACKSS, DstVT, In, DL, DAG, Subtarget);
 }
 
+/// Check if \p In truncated to \p DstVT has enough signbits / leading zeros
+/// for PACKSS / PACKUS. On success set \p PackOpcode and return the source
+/// (an SRL may be rewritten as SRA); otherwise return an empty SDValue.
+static SDValue matchTruncateWithPACK(unsigned &PackOpcode, EVT DstVT,
+                                     SDValue In, const SDLoc &DL,
+                                     SelectionDAG &DAG,
+                                     const X86Subtarget &Subtarget) {
+  // Requires SSE2.
+  if (!Subtarget.hasSSE2())
+    return SDValue();
+
+  EVT SrcVT = In.getValueType();
+  EVT DstSVT = DstVT.getVectorElementType();
+  EVT SrcSVT = SrcVT.getVectorElementType();
+
+  // Check we have a truncation suited for PACKSS/PACKUS.
+  if (!((SrcSVT == MVT::i16 || SrcSVT == MVT::i32 || SrcSVT == MVT::i64) &&
+        (DstSVT == MVT::i8 || DstSVT == MVT::i16 || DstSVT == MVT::i32)))
+    return SDValue();
+
+  assert(SrcSVT.getSizeInBits() > DstSVT.getSizeInBits() && "Bad truncation");
+  unsigned NumStages = Log2_32(SrcSVT.getSizeInBits() / DstSVT.getSizeInBits());
+
+  // Truncation to sub-128bit vXi32 can be better handled with shuffles.
+  if (DstSVT == MVT::i32 && SrcVT.getSizeInBits() <= 128)
+    return SDValue();
+
+  // Truncation from v2i64 to v2i8 can be better handled with PSHUFB.
+  if (DstVT == MVT::v2i8 && SrcVT == MVT::v2i64 && Subtarget.hasSSSE3())
+    return SDValue();
+
+  // On AVX512 targets, only use PACK when a single stage is needed.
+  if (Subtarget.hasAVX512() && NumStages > 1)
+    return SDValue();
+
+  unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits();
+  unsigned NumPackedSignBits = std::min<unsigned>(DstSVT.getSizeInBits(), 16);
+  unsigned NumPackedZeroBits = Subtarget.hasSSE41() ? NumPackedSignBits : 8;
+
+  // Truncate with PACKUS if we are truncating a vector with leading zero
+  // bits that extend all the way to the packed/truncated value.
+  // e.g. Masks, zext_in_reg, etc.
+  // Pre-SSE41 we can only use PACKUSWB.
+  KnownBits Known = DAG.computeKnownBits(In);
+  if ((NumSrcEltBits - NumPackedZeroBits) <= Known.countMinLeadingZeros()) {
+    PackOpcode = X86ISD::PACKUS;
+    return In;
+  }
+
+  // Truncate with PACKSS if we are truncating a vector with sign-bits
+  // that extend all the way to the packed/truncated value.
+  // e.g. Comparison result, sext_in_reg, etc.
+  unsigned NumSignBits = DAG.ComputeNumSignBits(In);
+
+  // Don't use PACKSS for vXi64 -> vXi32 truncations unless we're dealing with
+  // a sign splat (or AVX512 VPSRAQ support). ComputeNumSignBits struggles to
+  // see through BITCASTs later on and combines/simplifications can't then use
+  // it.
+  if (DstSVT == MVT::i32 && NumSignBits != SrcSVT.getSizeInBits() &&
+      !Subtarget.hasAVX512())
+    return SDValue();
+
+  unsigned MinSignBits = NumSrcEltBits - NumPackedSignBits;
+  if (MinSignBits < NumSignBits) {
+    PackOpcode = X86ISD::PACKSS;
+    return In;
+  }
+
+  // If we have a srl that only generates signbits that we will discard in
+  // the truncation then we can use PACKSS by converting the srl to a sra.
+  // SimplifyDemandedBits often relaxes sra to srl so we need to reverse it.
+  if (In.getOpcode() == ISD::SRL && In->hasOneUse())
+    if (const APInt *ShAmt = DAG.getValidShiftAmountConstant(
+            In, APInt::getAllOnes(SrcVT.getVectorNumElements()))) {
+      if (*ShAmt == MinSignBits) {
+        PackOpcode = X86ISD::PACKSS;
+        return DAG.getNode(ISD::SRA, DL, SrcVT, In->ops());
+      }
+    }
+
+  return SDValue();
+}
+
 /// This function lowers a vector truncation of 'extended sign-bits' or
 /// 'extended zero-bits' values.
 /// vXi16/vXi32/vXi64 to vXi8/vXi16/vXi32 into X86ISD::PACKSS/PACKUS operations.
@@ -20119,11 +20202,6 @@ static SDValue LowerTruncateVecPackWithSignBits(MVT DstVT, SDValue In,
         (DstSVT == MVT::i8 || DstSVT == MVT::i16 || DstSVT == MVT::i32)))
     return SDValue();
 
-  // Don't lower with PACK nodes on AVX512 targets if we'd need more than one.
-  if (Subtarget.hasAVX512() &&
-      SrcSVT.getSizeInBits() > (DstSVT.getSizeInBits() * 2))
-    return SDValue();
-
   // Prefer to lower v4i64 -> v4i32 as a shuffle unless we can cheaply
   // split this for packing.
   if (SrcVT == MVT::v4i64 && DstVT == MVT::v4i32 &&
@@ -20144,25 +20222,10 @@ static SDValue LowerTruncateVecPackWithSignBits(MVT DstVT, SDValue In,
     }
   }
 
-  unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits();
-  unsigned NumPackedSignBits = std::min<unsigned>(DstSVT.getSizeInBits(), 16);
-  unsigned NumPackedZeroBits = Subtarget.hasSSE41() ? NumPackedSignBits : 8;
-
-  // Truncate with PACKUS if we are truncating a vector with leading zero
-  // bits that extend all the way to the packed/truncated value. Pre-SSE41
-  // we can only use PACKUSWB.
-  KnownBits Known = DAG.computeKnownBits(In);
-  if ((NumSrcEltBits - NumPackedZeroBits) <= Known.countMinLeadingZeros())
-    if (SDValue V = truncateVectorWithPACK(X86ISD::PACKUS, DstVT, In, DL, DAG,
-                                           Subtarget))
-      return V;
-
-  // Truncate with PACKSS if we are truncating a vector with sign-bits
-  // that extend all the way to the packed/truncated value.
-  if ((NumSrcEltBits - NumPackedSignBits) < DAG.ComputeNumSignBits(In))
-    if (SDValue V = truncateVectorWithPACK(X86ISD::PACKSS, DstVT, In, DL, DAG,
-                                           Subtarget))
-      return V;
+  unsigned PackOpcode;
+  if (SDValue Src =
+          matchTruncateWithPACK(PackOpcode, DstVT, In, DL, DAG, Subtarget))
+    return truncateVectorWithPACK(PackOpcode, DstVT, Src, DL, DAG, Subtarget);
 
   return SDValue();
 }
@@ -31988,39 +32051,17 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
     unsigned InBits = InVT.getSizeInBits();
 
     if (128 % InBits == 0) {
-      // See if we there are sufficient leading bits to perform a PACKUS/PACKSS.
-      // Skip for AVX512 unless this will be a single stage truncation.
-      if ((InEltVT == MVT::i16 || InEltVT == MVT::i32) &&
-          (EltVT == MVT::i8 || EltVT == MVT::i16) &&
-          (!Subtarget.hasAVX512() || InBits == (2 * VT.getSizeInBits()))) {
-        unsigned NumPackedSignBits =
-            std::min<unsigned>(EltVT.getSizeInBits(), 16);
-        unsigned NumPackedZeroBits =
-            Subtarget.hasSSE41() ? NumPackedSignBits : 8;
-
-        // Use PACKUS if the input has zero-bits that extend all the way to the
-        // packed/truncated value. e.g. masks, zext_in_reg, etc.
-        KnownBits Known = DAG.computeKnownBits(In);
-        unsigned NumLeadingZeroBits = Known.countMinLeadingZeros();
-        bool UsePACKUS =
-            NumLeadingZeroBits >= (InEltVT.getSizeInBits() - NumPackedZeroBits);
-
-        // Use PACKSS if the input has sign-bits that extend all the way to the
-        // packed/truncated value. e.g. Comparison result, sext_in_reg, etc.
-        unsigned NumSignBits = DAG.ComputeNumSignBits(In);
-        bool UsePACKSS =
-            NumSignBits > (InEltVT.getSizeInBits() - NumPackedSignBits);
-
-        if (UsePACKUS || UsePACKSS) {
-          SDValue WidenIn =
-              widenSubVector(In, false, Subtarget, DAG, dl,
-                             InEltVT.getSizeInBits() * WidenNumElts);
-          if (SDValue Res = truncateVectorWithPACK(
-                  UsePACKUS ? X86ISD::PACKUS : X86ISD::PACKSS, WidenVT, WidenIn,
-                  dl, DAG, Subtarget)) {
-            Results.push_back(Res);
-            return;
-          }
+      // See if there are sufficient leading bits to perform a PACKUS/PACKSS.
+      unsigned PackOpcode;
+      if (SDValue Src =
+              matchTruncateWithPACK(PackOpcode, VT, In, dl, DAG, Subtarget)) {
+        SDValue WidenSrc =
+            widenSubVector(Src, false, Subtarget, DAG, dl,
+                           InEltVT.getSizeInBits() * WidenNumElts);
+        if (SDValue Res = truncateVectorWithPACK(PackOpcode, WidenVT, WidenSrc,
+                                                 dl, DAG, Subtarget)) {
+          Results.push_back(Res);
+          return;
         }
       }
 
@@ -50816,26 +50857,7 @@ static SDValue combineVectorSignBitsTruncation(SDNode *N, const SDLoc &DL,
     return SDValue();
 
   MVT VT = N->getValueType(0).getSimpleVT();
-  MVT SVT = VT.getScalarType();
-
   MVT InVT = In.getValueType().getSimpleVT();
-  MVT InSVT = InVT.getScalarType();
-
-  // Check we have a truncation suited for PACKSS/PACKUS.
-  if (!isPowerOf2_32(VT.getVectorNumElements()))
-    return SDValue();
-  if (SVT != MVT::i8 && SVT != MVT::i16 && SVT != MVT::i32)
-    return SDValue();
-  if (InSVT != MVT::i16 && InSVT != MVT::i32 && InSVT != MVT::i64)
-    return SDValue();
-
-  // Truncation to sub-128bit vXi32 can be better handled with shuffles.
-  if (SVT == MVT::i32 && VT.getSizeInBits() < 128)
-    return SDValue();
-
-  // Truncation from sub-128bit to vXi8 can be better handled with PSHUFB.
-  if (SVT == MVT::i8 && InVT.getSizeInBits() <= 128 && Subtarget.hasSSSE3())
-    return SDValue();
 
   // AVX512 has fast truncate, but if the input is already going to be split,
   // there's no harm in trying pack.
@@ -50848,42 +50870,10 @@ static SDValue combineVectorSignBitsTruncation(SDNode *N, const SDLoc &DL,
       return SDValue();
   }
 
-  unsigned NumPackedSignBits = std::min<unsigned>(SVT.getSizeInBits(), 16);
-  unsigned NumPackedZeroBits = Subtarget.hasSSE41() ? NumPackedSignBits : 8;
-
-  // Use PACKUS if the input has zero-bits that extend all the way to the
-  // packed/truncated value. e.g. masks, zext_in_reg, etc.
-  KnownBits Known = DAG.computeKnownBits(In);
-  unsigned NumLeadingZeroBits = Known.countMinLeadingZeros();
-  if (NumLeadingZeroBits >= (InSVT.getSizeInBits() - NumPackedZeroBits))
-    return truncateVectorWithPACK(X86ISD::PACKUS, VT, In, DL, DAG, Subtarget);
-
-  // Use PACKSS if the input has sign-bits that extend all the way to the
-  // packed/truncated value. e.g. Comparison result, sext_in_reg, etc.
-  unsigned NumSignBits = DAG.ComputeNumSignBits(In);
-
-  // Don't use PACKSS for vXi64 -> vXi32 truncations unless we're dealing with
-  // a sign splat. ComputeNumSignBits struggles to see through BITCASTs later
-  // on and combines/simplifications can't then use it.
-  if (SVT == MVT::i32 && NumSignBits != InSVT.getSizeInBits())
-    return SDValue();
-
-  unsigned MinSignBits = InSVT.getSizeInBits() - NumPackedSignBits;
-  if (NumSignBits > MinSignBits)
-    return truncateVectorWithPACK(X86ISD::PACKSS, VT, In, DL, DAG, Subtarget);
-
-  // If we have a srl that only generates signbits that we will discard in
-  // the truncation then we can use PACKSS by converting the srl to a sra.
-  // SimplifyDemandedBits often relaxes sra to srl so we need to reverse it.
-  if (In.getOpcode() == ISD::SRL && N->isOnlyUserOf(In.getNode()))
-    if (const APInt *ShAmt = DAG.getValidShiftAmountConstant(
-            In, APInt::getAllOnes(VT.getVectorNumElements()))) {
-      if (*ShAmt == MinSignBits) {
-        SDValue NewIn = DAG.getNode(ISD::SRA, DL, InVT, In->ops());
-        return truncateVectorWithPACK(X86ISD::PACKSS, VT, NewIn, DL, DAG,
-                                      Subtarget);
-      }
-    }
+  unsigned PackOpcode;
+  if (SDValue Src =
+          matchTruncateWithPACK(PackOpcode, VT, In, DL, DAG, Subtarget))
+    return truncateVectorWithPACK(PackOpcode, VT, Src, DL, DAG, Subtarget);
 
   return SDValue();
 }

diff  --git a/llvm/test/CodeGen/X86/vector-trunc-packus.ll b/llvm/test/CodeGen/X86/vector-trunc-packus.ll
index 93fa09e26121cc..52e533b5154806 100644
--- a/llvm/test/CodeGen/X86/vector-trunc-packus.ll
+++ b/llvm/test/CodeGen/X86/vector-trunc-packus.ll
@@ -4066,15 +4066,24 @@ define <4 x i8> @trunc_packus_v4i32_v4i8(<4 x i32> %a0) "min-legal-vector-width"
 ; AVX1-NEXT:    vpackuswb %xmm0, %xmm0, %xmm0
 ; AVX1-NEXT:    retq
 ;
-; AVX2-LABEL: trunc_packus_v4i32_v4i8:
-; AVX2:       # %bb.0:
-; AVX2-NEXT:    vpbroadcastd {{.*#+}} xmm1 = [255,255,255,255]
-; AVX2-NEXT:    vpminsd %xmm1, %xmm0, %xmm0
-; AVX2-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; AVX2-NEXT:    vpmaxsd %xmm1, %xmm0, %xmm0
-; AVX2-NEXT:    vpackusdw %xmm0, %xmm0, %xmm0
-; AVX2-NEXT:    vpackuswb %xmm0, %xmm0, %xmm0
-; AVX2-NEXT:    retq
+; AVX2-SLOW-LABEL: trunc_packus_v4i32_v4i8:
+; AVX2-SLOW:       # %bb.0:
+; AVX2-SLOW-NEXT:    vpbroadcastd {{.*#+}} xmm1 = [255,255,255,255]
+; AVX2-SLOW-NEXT:    vpminsd %xmm1, %xmm0, %xmm0
+; AVX2-SLOW-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX2-SLOW-NEXT:    vpmaxsd %xmm1, %xmm0, %xmm0
+; AVX2-SLOW-NEXT:    vpackusdw %xmm0, %xmm0, %xmm0
+; AVX2-SLOW-NEXT:    vpackuswb %xmm0, %xmm0, %xmm0
+; AVX2-SLOW-NEXT:    retq
+;
+; AVX2-FAST-LABEL: trunc_packus_v4i32_v4i8:
+; AVX2-FAST:       # %bb.0:
+; AVX2-FAST-NEXT:    vpbroadcastd {{.*#+}} xmm1 = [255,255,255,255]
+; AVX2-FAST-NEXT:    vpminsd %xmm1, %xmm0, %xmm0
+; AVX2-FAST-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX2-FAST-NEXT:    vpmaxsd %xmm1, %xmm0, %xmm0
+; AVX2-FAST-NEXT:    vpshufb {{.*#+}} xmm0 = xmm0[0,4,8,12,u,u,u,u,u,u,u,u,u,u,u,u]
+; AVX2-FAST-NEXT:    retq
 ;
 ; AVX512F-LABEL: trunc_packus_v4i32_v4i8:
 ; AVX512F:       # %bb.0:

diff  --git a/llvm/test/CodeGen/X86/vector-trunc-ssat.ll b/llvm/test/CodeGen/X86/vector-trunc-ssat.ll
index ffcd593a2c67ad..6fa548ae93cff0 100644
--- a/llvm/test/CodeGen/X86/vector-trunc-ssat.ll
+++ b/llvm/test/CodeGen/X86/vector-trunc-ssat.ll
@@ -3789,15 +3789,24 @@ define <4 x i8> @trunc_ssat_v4i32_v4i8(<4 x i32> %a0) {
 ; AVX1-NEXT:    vpacksswb %xmm0, %xmm0, %xmm0
 ; AVX1-NEXT:    retq
 ;
-; AVX2-LABEL: trunc_ssat_v4i32_v4i8:
-; AVX2:       # %bb.0:
-; AVX2-NEXT:    vpbroadcastd {{.*#+}} xmm1 = [127,127,127,127]
-; AVX2-NEXT:    vpminsd %xmm1, %xmm0, %xmm0
-; AVX2-NEXT:    vpbroadcastd {{.*#+}} xmm1 = [4294967168,4294967168,4294967168,4294967168]
-; AVX2-NEXT:    vpmaxsd %xmm1, %xmm0, %xmm0
-; AVX2-NEXT:    vpackssdw %xmm0, %xmm0, %xmm0
-; AVX2-NEXT:    vpacksswb %xmm0, %xmm0, %xmm0
-; AVX2-NEXT:    retq
+; AVX2-SLOW-LABEL: trunc_ssat_v4i32_v4i8:
+; AVX2-SLOW:       # %bb.0:
+; AVX2-SLOW-NEXT:    vpbroadcastd {{.*#+}} xmm1 = [127,127,127,127]
+; AVX2-SLOW-NEXT:    vpminsd %xmm1, %xmm0, %xmm0
+; AVX2-SLOW-NEXT:    vpbroadcastd {{.*#+}} xmm1 = [4294967168,4294967168,4294967168,4294967168]
+; AVX2-SLOW-NEXT:    vpmaxsd %xmm1, %xmm0, %xmm0
+; AVX2-SLOW-NEXT:    vpackssdw %xmm0, %xmm0, %xmm0
+; AVX2-SLOW-NEXT:    vpacksswb %xmm0, %xmm0, %xmm0
+; AVX2-SLOW-NEXT:    retq
+;
+; AVX2-FAST-LABEL: trunc_ssat_v4i32_v4i8:
+; AVX2-FAST:       # %bb.0:
+; AVX2-FAST-NEXT:    vpbroadcastd {{.*#+}} xmm1 = [127,127,127,127]
+; AVX2-FAST-NEXT:    vpminsd %xmm1, %xmm0, %xmm0
+; AVX2-FAST-NEXT:    vpbroadcastd {{.*#+}} xmm1 = [4294967168,4294967168,4294967168,4294967168]
+; AVX2-FAST-NEXT:    vpmaxsd %xmm1, %xmm0, %xmm0
+; AVX2-FAST-NEXT:    vpshufb {{.*#+}} xmm0 = xmm0[0,4,8,12,u,u,u,u,u,u,u,u,u,u,u,u]
+; AVX2-FAST-NEXT:    retq
 ;
 ; AVX512F-LABEL: trunc_ssat_v4i32_v4i8:
 ; AVX512F:       # %bb.0:

diff  --git a/llvm/test/CodeGen/X86/vector-trunc-usat.ll b/llvm/test/CodeGen/X86/vector-trunc-usat.ll
index 7030372390598e..fded69f955ccd2 100644
--- a/llvm/test/CodeGen/X86/vector-trunc-usat.ll
+++ b/llvm/test/CodeGen/X86/vector-trunc-usat.ll
@@ -2906,13 +2906,20 @@ define <4 x i8> @trunc_usat_v4i32_v4i8(<4 x i32> %a0) {
 ; AVX1-NEXT:    vpackuswb %xmm0, %xmm0, %xmm0
 ; AVX1-NEXT:    retq
 ;
-; AVX2-LABEL: trunc_usat_v4i32_v4i8:
-; AVX2:       # %bb.0:
-; AVX2-NEXT:    vpbroadcastd {{.*#+}} xmm1 = [255,255,255,255]
-; AVX2-NEXT:    vpminud %xmm1, %xmm0, %xmm0
-; AVX2-NEXT:    vpackusdw %xmm0, %xmm0, %xmm0
-; AVX2-NEXT:    vpackuswb %xmm0, %xmm0, %xmm0
-; AVX2-NEXT:    retq
+; AVX2-SLOW-LABEL: trunc_usat_v4i32_v4i8:
+; AVX2-SLOW:       # %bb.0:
+; AVX2-SLOW-NEXT:    vpbroadcastd {{.*#+}} xmm1 = [255,255,255,255]
+; AVX2-SLOW-NEXT:    vpminud %xmm1, %xmm0, %xmm0
+; AVX2-SLOW-NEXT:    vpackusdw %xmm0, %xmm0, %xmm0
+; AVX2-SLOW-NEXT:    vpackuswb %xmm0, %xmm0, %xmm0
+; AVX2-SLOW-NEXT:    retq
+;
+; AVX2-FAST-LABEL: trunc_usat_v4i32_v4i8:
+; AVX2-FAST:       # %bb.0:
+; AVX2-FAST-NEXT:    vpbroadcastd {{.*#+}} xmm1 = [255,255,255,255]
+; AVX2-FAST-NEXT:    vpminud %xmm1, %xmm0, %xmm0
+; AVX2-FAST-NEXT:    vpshufb {{.*#+}} xmm0 = xmm0[0,4,8,12,u,u,u,u,u,u,u,u,u,u,u,u]
+; AVX2-FAST-NEXT:    retq
 ;
 ; AVX512F-LABEL: trunc_usat_v4i32_v4i8:
 ; AVX512F:       # %bb.0:


        


More information about the llvm-commits mailing list