[llvm] 711dff4 - [X86] Add matchTruncateWithPACK helper for matching signbits/knownbits for PACKSS/PACKUS
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 7 06:12:01 PDT 2023
Author: Simon Pilgrim
Date: 2023-08-07T14:11:42+01:00
New Revision: 711dff45772e04bbaff284dfb16501cbc96bef1c
URL: https://github.com/llvm/llvm-project/commit/711dff45772e04bbaff284dfb16501cbc96bef1c
DIFF: https://github.com/llvm/llvm-project/commit/711dff45772e04bbaff284dfb16501cbc96bef1c.diff
LOG: [X86] Add matchTruncateWithPACK helper for matching signbits/knownbits for PACKSS/PACKUS
Begin consolidating the similar matching code we have. Each copy has semi-similar constraints that still need to be merged so that we get consistent codegen regardless of when the truncate is lowered.
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/vector-trunc-packus.ll
llvm/test/CodeGen/X86/vector-trunc-ssat.ll
llvm/test/CodeGen/X86/vector-trunc-usat.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 9ca383e4742299..38230ae7231bfd 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -20105,6 +20105,89 @@ static SDValue truncateVectorWithPACKSS(EVT DstVT, SDValue In, const SDLoc &DL,
return truncateVectorWithPACK(X86ISD::PACKSS, DstVT, In, DL, DAG, Subtarget);
}
+/// Helper to determine if \p In truncated to \p DstVT has the necessary
+/// signbits / leading zero bits to be truncated with PACKSS / PACKUS,
+/// possibly by converting a SRL node to SRA for sign extension.
+static SDValue matchTruncateWithPACK(unsigned &PackOpcode, EVT DstVT,
+ SDValue In, const SDLoc &DL,
+ SelectionDAG &DAG,
+ const X86Subtarget &Subtarget) {
+ // Requires SSE2.
+ if (!Subtarget.hasSSE2())
+ return SDValue();
+
+ EVT SrcVT = In.getValueType();
+ EVT DstSVT = DstVT.getVectorElementType();
+ EVT SrcSVT = SrcVT.getVectorElementType();
+
+ // Check we have a truncation suited for PACKSS/PACKUS.
+ if (!((SrcSVT == MVT::i16 || SrcSVT == MVT::i32 || SrcSVT == MVT::i64) &&
+ (DstSVT == MVT::i8 || DstSVT == MVT::i16 || DstSVT == MVT::i32)))
+ return SDValue();
+
+ assert(SrcSVT.getSizeInBits() > DstSVT.getSizeInBits() && "Bad truncation");
+ unsigned NumStages = Log2_32(SrcSVT.getSizeInBits() / DstSVT.getSizeInBits());
+
+ // Truncation to sub-128bit vXi32 can be better handled with shuffles.
+ if (DstSVT == MVT::i32 && SrcVT.getSizeInBits() <= 128)
+ return SDValue();
+
+ // Truncation from v2i64 to v2i8 can be better handled with PSHUFB.
+ if (DstVT == MVT::v2i8 && SrcVT == MVT::v2i64 && Subtarget.hasSSSE3())
+ return SDValue();
+
+ // Don't truncate on AVX512 targets if multiple PACK node stages are needed.
+ if (Subtarget.hasAVX512() && NumStages > 1)
+ return SDValue();
+
+ unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits();
+ unsigned NumPackedSignBits = std::min<unsigned>(DstSVT.getSizeInBits(), 16);
+ unsigned NumPackedZeroBits = Subtarget.hasSSE41() ? NumPackedSignBits : 8;
+
+ // Truncate with PACKUS if we are truncating a vector with leading zero
+ // bits that extend all the way to the packed/truncated value.
+ // e.g. Masks, zext_in_reg, etc.
+ // Pre-SSE41 we can only use PACKUSWB.
+ KnownBits Known = DAG.computeKnownBits(In);
+ if ((NumSrcEltBits - NumPackedZeroBits) <= Known.countMinLeadingZeros()) {
+ PackOpcode = X86ISD::PACKUS;
+ return In;
+ }
+
+ // Truncate with PACKSS if we are truncating a vector with sign-bits
+ // that extend all the way to the packed/truncated value.
+ // e.g. Comparison result, sext_in_reg, etc.
+ unsigned NumSignBits = DAG.ComputeNumSignBits(In);
+
+ // Don't use PACKSS for vXi64 -> vXi32 truncations unless we're dealing with
+ // a sign splat (or AVX512 VPSRAQ support). ComputeNumSignBits struggles to
+ // see through BITCASTs later on and combines/simplifications can't then use
+ // it.
+ if (DstSVT == MVT::i32 && NumSignBits != SrcSVT.getSizeInBits() &&
+ !Subtarget.hasAVX512())
+ return SDValue();
+
+ unsigned MinSignBits = NumSrcEltBits - NumPackedSignBits;
+ if (MinSignBits < NumSignBits) {
+ PackOpcode = X86ISD::PACKSS;
+ return In;
+ }
+
+ // If we have a srl that only generates signbits that we will discard in
+ // the truncation then we can use PACKSS by converting the srl to a sra.
+ // SimplifyDemandedBits often relaxes sra to srl so we need to reverse it.
+ if (In.getOpcode() == ISD::SRL && In->hasOneUse())
+ if (const APInt *ShAmt = DAG.getValidShiftAmountConstant(
+ In, APInt::getAllOnes(SrcVT.getVectorNumElements()))) {
+ if (*ShAmt == MinSignBits) {
+ PackOpcode = X86ISD::PACKSS;
+ return DAG.getNode(ISD::SRA, DL, SrcVT, In->ops());
+ }
+ }
+
+ return SDValue();
+}
+
/// This function lowers a vector truncation of 'extended sign-bits' or
/// 'extended zero-bits' values.
/// vXi16/vXi32/vXi64 to vXi8/vXi16/vXi32 into X86ISD::PACKSS/PACKUS operations.
@@ -20119,11 +20202,6 @@ static SDValue LowerTruncateVecPackWithSignBits(MVT DstVT, SDValue In,
(DstSVT == MVT::i8 || DstSVT == MVT::i16 || DstSVT == MVT::i32)))
return SDValue();
- // Don't lower with PACK nodes on AVX512 targets if we'd need more than one.
- if (Subtarget.hasAVX512() &&
- SrcSVT.getSizeInBits() > (DstSVT.getSizeInBits() * 2))
- return SDValue();
-
// Prefer to lower v4i64 -> v4i32 as a shuffle unless we can cheaply
// split this for packing.
if (SrcVT == MVT::v4i64 && DstVT == MVT::v4i32 &&
@@ -20144,25 +20222,10 @@ static SDValue LowerTruncateVecPackWithSignBits(MVT DstVT, SDValue In,
}
}
- unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits();
- unsigned NumPackedSignBits = std::min<unsigned>(DstSVT.getSizeInBits(), 16);
- unsigned NumPackedZeroBits = Subtarget.hasSSE41() ? NumPackedSignBits : 8;
-
- // Truncate with PACKUS if we are truncating a vector with leading zero
- // bits that extend all the way to the packed/truncated value. Pre-SSE41
- // we can only use PACKUSWB.
- KnownBits Known = DAG.computeKnownBits(In);
- if ((NumSrcEltBits - NumPackedZeroBits) <= Known.countMinLeadingZeros())
- if (SDValue V = truncateVectorWithPACK(X86ISD::PACKUS, DstVT, In, DL, DAG,
- Subtarget))
- return V;
-
- // Truncate with PACKSS if we are truncating a vector with sign-bits
- // that extend all the way to the packed/truncated value.
- if ((NumSrcEltBits - NumPackedSignBits) < DAG.ComputeNumSignBits(In))
- if (SDValue V = truncateVectorWithPACK(X86ISD::PACKSS, DstVT, In, DL, DAG,
- Subtarget))
- return V;
+ unsigned PackOpcode;
+ if (SDValue Src =
+ matchTruncateWithPACK(PackOpcode, DstVT, In, DL, DAG, Subtarget))
+ return truncateVectorWithPACK(PackOpcode, DstVT, Src, DL, DAG, Subtarget);
return SDValue();
}
@@ -31988,39 +32051,17 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
unsigned InBits = InVT.getSizeInBits();
if (128 % InBits == 0) {
- // See if we there are sufficient leading bits to perform a PACKUS/PACKSS.
- // Skip for AVX512 unless this will be a single stage truncation.
- if ((InEltVT == MVT::i16 || InEltVT == MVT::i32) &&
- (EltVT == MVT::i8 || EltVT == MVT::i16) &&
- (!Subtarget.hasAVX512() || InBits == (2 * VT.getSizeInBits()))) {
- unsigned NumPackedSignBits =
- std::min<unsigned>(EltVT.getSizeInBits(), 16);
- unsigned NumPackedZeroBits =
- Subtarget.hasSSE41() ? NumPackedSignBits : 8;
-
- // Use PACKUS if the input has zero-bits that extend all the way to the
- // packed/truncated value. e.g. masks, zext_in_reg, etc.
- KnownBits Known = DAG.computeKnownBits(In);
- unsigned NumLeadingZeroBits = Known.countMinLeadingZeros();
- bool UsePACKUS =
- NumLeadingZeroBits >= (InEltVT.getSizeInBits() - NumPackedZeroBits);
-
- // Use PACKSS if the input has sign-bits that extend all the way to the
- // packed/truncated value. e.g. Comparison result, sext_in_reg, etc.
- unsigned NumSignBits = DAG.ComputeNumSignBits(In);
- bool UsePACKSS =
- NumSignBits > (InEltVT.getSizeInBits() - NumPackedSignBits);
-
- if (UsePACKUS || UsePACKSS) {
- SDValue WidenIn =
- widenSubVector(In, false, Subtarget, DAG, dl,
- InEltVT.getSizeInBits() * WidenNumElts);
- if (SDValue Res = truncateVectorWithPACK(
- UsePACKUS ? X86ISD::PACKUS : X86ISD::PACKSS, WidenVT, WidenIn,
- dl, DAG, Subtarget)) {
- Results.push_back(Res);
- return;
- }
+ // See if there are sufficient leading bits to perform a PACKUS/PACKSS.
+ unsigned PackOpcode;
+ if (SDValue Src =
+ matchTruncateWithPACK(PackOpcode, VT, In, dl, DAG, Subtarget)) {
+ SDValue WidenSrc =
+ widenSubVector(Src, false, Subtarget, DAG, dl,
+ InEltVT.getSizeInBits() * WidenNumElts);
+ if (SDValue Res = truncateVectorWithPACK(PackOpcode, WidenVT, WidenSrc,
+ dl, DAG, Subtarget)) {
+ Results.push_back(Res);
+ return;
}
}
@@ -50816,26 +50857,7 @@ static SDValue combineVectorSignBitsTruncation(SDNode *N, const SDLoc &DL,
return SDValue();
MVT VT = N->getValueType(0).getSimpleVT();
- MVT SVT = VT.getScalarType();
-
MVT InVT = In.getValueType().getSimpleVT();
- MVT InSVT = InVT.getScalarType();
-
- // Check we have a truncation suited for PACKSS/PACKUS.
- if (!isPowerOf2_32(VT.getVectorNumElements()))
- return SDValue();
- if (SVT != MVT::i8 && SVT != MVT::i16 && SVT != MVT::i32)
- return SDValue();
- if (InSVT != MVT::i16 && InSVT != MVT::i32 && InSVT != MVT::i64)
- return SDValue();
-
- // Truncation to sub-128bit vXi32 can be better handled with shuffles.
- if (SVT == MVT::i32 && VT.getSizeInBits() < 128)
- return SDValue();
-
- // Truncation from sub-128bit to vXi8 can be better handled with PSHUFB.
- if (SVT == MVT::i8 && InVT.getSizeInBits() <= 128 && Subtarget.hasSSSE3())
- return SDValue();
// AVX512 has fast truncate, but if the input is already going to be split,
// there's no harm in trying pack.
@@ -50848,42 +50870,10 @@ static SDValue combineVectorSignBitsTruncation(SDNode *N, const SDLoc &DL,
return SDValue();
}
- unsigned NumPackedSignBits = std::min<unsigned>(SVT.getSizeInBits(), 16);
- unsigned NumPackedZeroBits = Subtarget.hasSSE41() ? NumPackedSignBits : 8;
-
- // Use PACKUS if the input has zero-bits that extend all the way to the
- // packed/truncated value. e.g. masks, zext_in_reg, etc.
- KnownBits Known = DAG.computeKnownBits(In);
- unsigned NumLeadingZeroBits = Known.countMinLeadingZeros();
- if (NumLeadingZeroBits >= (InSVT.getSizeInBits() - NumPackedZeroBits))
- return truncateVectorWithPACK(X86ISD::PACKUS, VT, In, DL, DAG, Subtarget);
-
- // Use PACKSS if the input has sign-bits that extend all the way to the
- // packed/truncated value. e.g. Comparison result, sext_in_reg, etc.
- unsigned NumSignBits = DAG.ComputeNumSignBits(In);
-
- // Don't use PACKSS for vXi64 -> vXi32 truncations unless we're dealing with
- // a sign splat. ComputeNumSignBits struggles to see through BITCASTs later
- // on and combines/simplifications can't then use it.
- if (SVT == MVT::i32 && NumSignBits != InSVT.getSizeInBits())
- return SDValue();
-
- unsigned MinSignBits = InSVT.getSizeInBits() - NumPackedSignBits;
- if (NumSignBits > MinSignBits)
- return truncateVectorWithPACK(X86ISD::PACKSS, VT, In, DL, DAG, Subtarget);
-
- // If we have a srl that only generates signbits that we will discard in
- // the truncation then we can use PACKSS by converting the srl to a sra.
- // SimplifyDemandedBits often relaxes sra to srl so we need to reverse it.
- if (In.getOpcode() == ISD::SRL && N->isOnlyUserOf(In.getNode()))
- if (const APInt *ShAmt = DAG.getValidShiftAmountConstant(
- In, APInt::getAllOnes(VT.getVectorNumElements()))) {
- if (*ShAmt == MinSignBits) {
- SDValue NewIn = DAG.getNode(ISD::SRA, DL, InVT, In->ops());
- return truncateVectorWithPACK(X86ISD::PACKSS, VT, NewIn, DL, DAG,
- Subtarget);
- }
- }
+ unsigned PackOpcode;
+ if (SDValue Src =
+ matchTruncateWithPACK(PackOpcode, VT, In, DL, DAG, Subtarget))
+ return truncateVectorWithPACK(PackOpcode, VT, Src, DL, DAG, Subtarget);
return SDValue();
}
diff --git a/llvm/test/CodeGen/X86/vector-trunc-packus.ll b/llvm/test/CodeGen/X86/vector-trunc-packus.ll
index 93fa09e26121cc..52e533b5154806 100644
--- a/llvm/test/CodeGen/X86/vector-trunc-packus.ll
+++ b/llvm/test/CodeGen/X86/vector-trunc-packus.ll
@@ -4066,15 +4066,24 @@ define <4 x i8> @trunc_packus_v4i32_v4i8(<4 x i32> %a0) "min-legal-vector-width"
; AVX1-NEXT: vpackuswb %xmm0, %xmm0, %xmm0
; AVX1-NEXT: retq
;
-; AVX2-LABEL: trunc_packus_v4i32_v4i8:
-; AVX2: # %bb.0:
-; AVX2-NEXT: vpbroadcastd {{.*#+}} xmm1 = [255,255,255,255]
-; AVX2-NEXT: vpminsd %xmm1, %xmm0, %xmm0
-; AVX2-NEXT: vpxor %xmm1, %xmm1, %xmm1
-; AVX2-NEXT: vpmaxsd %xmm1, %xmm0, %xmm0
-; AVX2-NEXT: vpackusdw %xmm0, %xmm0, %xmm0
-; AVX2-NEXT: vpackuswb %xmm0, %xmm0, %xmm0
-; AVX2-NEXT: retq
+; AVX2-SLOW-LABEL: trunc_packus_v4i32_v4i8:
+; AVX2-SLOW: # %bb.0:
+; AVX2-SLOW-NEXT: vpbroadcastd {{.*#+}} xmm1 = [255,255,255,255]
+; AVX2-SLOW-NEXT: vpminsd %xmm1, %xmm0, %xmm0
+; AVX2-SLOW-NEXT: vpxor %xmm1, %xmm1, %xmm1
+; AVX2-SLOW-NEXT: vpmaxsd %xmm1, %xmm0, %xmm0
+; AVX2-SLOW-NEXT: vpackusdw %xmm0, %xmm0, %xmm0
+; AVX2-SLOW-NEXT: vpackuswb %xmm0, %xmm0, %xmm0
+; AVX2-SLOW-NEXT: retq
+;
+; AVX2-FAST-LABEL: trunc_packus_v4i32_v4i8:
+; AVX2-FAST: # %bb.0:
+; AVX2-FAST-NEXT: vpbroadcastd {{.*#+}} xmm1 = [255,255,255,255]
+; AVX2-FAST-NEXT: vpminsd %xmm1, %xmm0, %xmm0
+; AVX2-FAST-NEXT: vpxor %xmm1, %xmm1, %xmm1
+; AVX2-FAST-NEXT: vpmaxsd %xmm1, %xmm0, %xmm0
+; AVX2-FAST-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[0,4,8,12,u,u,u,u,u,u,u,u,u,u,u,u]
+; AVX2-FAST-NEXT: retq
;
; AVX512F-LABEL: trunc_packus_v4i32_v4i8:
; AVX512F: # %bb.0:
diff --git a/llvm/test/CodeGen/X86/vector-trunc-ssat.ll b/llvm/test/CodeGen/X86/vector-trunc-ssat.ll
index ffcd593a2c67ad..6fa548ae93cff0 100644
--- a/llvm/test/CodeGen/X86/vector-trunc-ssat.ll
+++ b/llvm/test/CodeGen/X86/vector-trunc-ssat.ll
@@ -3789,15 +3789,24 @@ define <4 x i8> @trunc_ssat_v4i32_v4i8(<4 x i32> %a0) {
; AVX1-NEXT: vpacksswb %xmm0, %xmm0, %xmm0
; AVX1-NEXT: retq
;
-; AVX2-LABEL: trunc_ssat_v4i32_v4i8:
-; AVX2: # %bb.0:
-; AVX2-NEXT: vpbroadcastd {{.*#+}} xmm1 = [127,127,127,127]
-; AVX2-NEXT: vpminsd %xmm1, %xmm0, %xmm0
-; AVX2-NEXT: vpbroadcastd {{.*#+}} xmm1 = [4294967168,4294967168,4294967168,4294967168]
-; AVX2-NEXT: vpmaxsd %xmm1, %xmm0, %xmm0
-; AVX2-NEXT: vpackssdw %xmm0, %xmm0, %xmm0
-; AVX2-NEXT: vpacksswb %xmm0, %xmm0, %xmm0
-; AVX2-NEXT: retq
+; AVX2-SLOW-LABEL: trunc_ssat_v4i32_v4i8:
+; AVX2-SLOW: # %bb.0:
+; AVX2-SLOW-NEXT: vpbroadcastd {{.*#+}} xmm1 = [127,127,127,127]
+; AVX2-SLOW-NEXT: vpminsd %xmm1, %xmm0, %xmm0
+; AVX2-SLOW-NEXT: vpbroadcastd {{.*#+}} xmm1 = [4294967168,4294967168,4294967168,4294967168]
+; AVX2-SLOW-NEXT: vpmaxsd %xmm1, %xmm0, %xmm0
+; AVX2-SLOW-NEXT: vpackssdw %xmm0, %xmm0, %xmm0
+; AVX2-SLOW-NEXT: vpacksswb %xmm0, %xmm0, %xmm0
+; AVX2-SLOW-NEXT: retq
+;
+; AVX2-FAST-LABEL: trunc_ssat_v4i32_v4i8:
+; AVX2-FAST: # %bb.0:
+; AVX2-FAST-NEXT: vpbroadcastd {{.*#+}} xmm1 = [127,127,127,127]
+; AVX2-FAST-NEXT: vpminsd %xmm1, %xmm0, %xmm0
+; AVX2-FAST-NEXT: vpbroadcastd {{.*#+}} xmm1 = [4294967168,4294967168,4294967168,4294967168]
+; AVX2-FAST-NEXT: vpmaxsd %xmm1, %xmm0, %xmm0
+; AVX2-FAST-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[0,4,8,12,u,u,u,u,u,u,u,u,u,u,u,u]
+; AVX2-FAST-NEXT: retq
;
; AVX512F-LABEL: trunc_ssat_v4i32_v4i8:
; AVX512F: # %bb.0:
diff --git a/llvm/test/CodeGen/X86/vector-trunc-usat.ll b/llvm/test/CodeGen/X86/vector-trunc-usat.ll
index 7030372390598e..fded69f955ccd2 100644
--- a/llvm/test/CodeGen/X86/vector-trunc-usat.ll
+++ b/llvm/test/CodeGen/X86/vector-trunc-usat.ll
@@ -2906,13 +2906,20 @@ define <4 x i8> @trunc_usat_v4i32_v4i8(<4 x i32> %a0) {
; AVX1-NEXT: vpackuswb %xmm0, %xmm0, %xmm0
; AVX1-NEXT: retq
;
-; AVX2-LABEL: trunc_usat_v4i32_v4i8:
-; AVX2: # %bb.0:
-; AVX2-NEXT: vpbroadcastd {{.*#+}} xmm1 = [255,255,255,255]
-; AVX2-NEXT: vpminud %xmm1, %xmm0, %xmm0
-; AVX2-NEXT: vpackusdw %xmm0, %xmm0, %xmm0
-; AVX2-NEXT: vpackuswb %xmm0, %xmm0, %xmm0
-; AVX2-NEXT: retq
+; AVX2-SLOW-LABEL: trunc_usat_v4i32_v4i8:
+; AVX2-SLOW: # %bb.0:
+; AVX2-SLOW-NEXT: vpbroadcastd {{.*#+}} xmm1 = [255,255,255,255]
+; AVX2-SLOW-NEXT: vpminud %xmm1, %xmm0, %xmm0
+; AVX2-SLOW-NEXT: vpackusdw %xmm0, %xmm0, %xmm0
+; AVX2-SLOW-NEXT: vpackuswb %xmm0, %xmm0, %xmm0
+; AVX2-SLOW-NEXT: retq
+;
+; AVX2-FAST-LABEL: trunc_usat_v4i32_v4i8:
+; AVX2-FAST: # %bb.0:
+; AVX2-FAST-NEXT: vpbroadcastd {{.*#+}} xmm1 = [255,255,255,255]
+; AVX2-FAST-NEXT: vpminud %xmm1, %xmm0, %xmm0
+; AVX2-FAST-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[0,4,8,12,u,u,u,u,u,u,u,u,u,u,u,u]
+; AVX2-FAST-NEXT: retq
;
; AVX512F-LABEL: trunc_usat_v4i32_v4i8:
; AVX512F: # %bb.0:
More information about the llvm-commits
mailing list