[llvm] fa0e9ac - [X86] Remove PMADDWD/PMADDUBSW known bits handling due to performance issues
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 27 11:10:11 PDT 2024
Author: Simon Pilgrim
Date: 2024-06-27T19:09:54+01:00
New Revision: fa0e9acea5e4d363eef6acc484afc1b22ab8e698
URL: https://github.com/llvm/llvm-project/commit/fa0e9acea5e4d363eef6acc484afc1b22ab8e698
DIFF: https://github.com/llvm/llvm-project/commit/fa0e9acea5e4d363eef6acc484afc1b22ab8e698.diff
LOG: [X86] Remove PMADDWD/PMADDUBSW known bits handling due to performance issues
This appears to be causing a slow (infinite?) loop when building the highway open source project - most likely due to the high number of computeKnownBits calls (although improving the early-out doesn't appear to help so far).
I'm reverting support to unstick the highway team and will revisit this shortly.
Reported by @alexfh
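[Editor's note] For context on the cost: each of the removed handlers issues four recursive computeKnownBits calls per node (lo/hi demanded-element masks for both operands), so chains of PMADDWD/PMADDUBSW nodes fan out geometrically. A minimal sketch of the query growth, assuming no result caching (illustration only, not LLVM code):

  // Each PMADDWD/PMADDUBSW node spawns four recursive known-bits queries
  // (LHSLo, LHSHi, RHSLo, RHSHi), so the total query count grows roughly as
  // 4^Depth until the recursion limit is hit.
  unsigned knownBitsQueryCount(unsigned Depth) {
    if (Depth == 0)
      return 1; // bottomed out at the depth limit
    return 1 + 4 * knownBitsQueryCount(Depth - 1);
  }

With SelectionDAG's default MaxRecursionDepth of 6, that is up to ~5461 queries per root node in the worst case, repeated for every combine that asks about the node.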
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/combine-pmadd.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 3bbf009a1defd..767c58270a4dc 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -37139,52 +37139,6 @@ static void computeKnownBitsForPSADBW(SDValue LHS, SDValue RHS,
Known = Known.zext(64);
}
-static void computeKnownBitsForPMADDWD(SDValue LHS, SDValue RHS,
- KnownBits &Known,
- const APInt &DemandedElts,
- const SelectionDAG &DAG,
- unsigned Depth) {
- unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
-
- // Multiply signed i16 elements to create i32 values and add Lo/Hi pairs.
- APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
- APInt DemandedLoElts =
- DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b01));
- APInt DemandedHiElts =
- DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b10));
- KnownBits LHSLo = DAG.computeKnownBits(LHS, DemandedLoElts, Depth + 1);
- KnownBits LHSHi = DAG.computeKnownBits(LHS, DemandedHiElts, Depth + 1);
- KnownBits RHSLo = DAG.computeKnownBits(RHS, DemandedLoElts, Depth + 1);
- KnownBits RHSHi = DAG.computeKnownBits(RHS, DemandedHiElts, Depth + 1);
- KnownBits Lo = KnownBits::mul(LHSLo.sext(32), RHSLo.sext(32));
- KnownBits Hi = KnownBits::mul(LHSHi.sext(32), RHSHi.sext(32));
- Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
- /*NUW=*/false, Lo, Hi);
-}
-
-static void computeKnownBitsForPMADDUBSW(SDValue LHS, SDValue RHS,
- KnownBits &Known,
- const APInt &DemandedElts,
- const SelectionDAG &DAG,
- unsigned Depth) {
- unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
-
- // Multiply unsigned/signed i8 elements to create i16 values and add_sat Lo/Hi
- // pairs.
- APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
- APInt DemandedLoElts =
- DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b01));
- APInt DemandedHiElts =
- DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b10));
- KnownBits LHSLo = DAG.computeKnownBits(LHS, DemandedLoElts, Depth + 1);
- KnownBits LHSHi = DAG.computeKnownBits(LHS, DemandedHiElts, Depth + 1);
- KnownBits RHSLo = DAG.computeKnownBits(RHS, DemandedLoElts, Depth + 1);
- KnownBits RHSHi = DAG.computeKnownBits(RHS, DemandedHiElts, Depth + 1);
- KnownBits Lo = KnownBits::mul(LHSLo.zext(16), RHSLo.sext(16));
- KnownBits Hi = KnownBits::mul(LHSHi.zext(16), RHSHi.sext(16));
- Known = KnownBits::sadd_sat(Lo, Hi);
-}
-
void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
KnownBits &Known,
const APInt &DemandedElts,
@@ -37360,26 +37314,6 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
}
break;
}
- case X86ISD::VPMADDWD: {
- SDValue LHS = Op.getOperand(0);
- SDValue RHS = Op.getOperand(1);
- assert(VT.getVectorElementType() == MVT::i32 &&
- LHS.getValueType() == RHS.getValueType() &&
- LHS.getValueType().getVectorElementType() == MVT::i16 &&
- "Unexpected PMADDWD types");
- computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
- break;
- }
- case X86ISD::VPMADDUBSW: {
- SDValue LHS = Op.getOperand(0);
- SDValue RHS = Op.getOperand(1);
- assert(VT.getVectorElementType() == MVT::i16 &&
- LHS.getValueType() == RHS.getValueType() &&
- LHS.getValueType().getVectorElementType() == MVT::i8 &&
- "Unexpected PMADDUBSW types");
- computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
- break;
- }
case X86ISD::PMULUDQ: {
KnownBits Known2;
Known = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
@@ -37516,30 +37450,6 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
}
case ISD::INTRINSIC_WO_CHAIN: {
switch (Op->getConstantOperandVal(0)) {
- case Intrinsic::x86_sse2_pmadd_wd:
- case Intrinsic::x86_avx2_pmadd_wd:
- case Intrinsic::x86_avx512_pmaddw_d_512: {
- SDValue LHS = Op.getOperand(1);
- SDValue RHS = Op.getOperand(2);
- assert(VT.getScalarType() == MVT::i32 &&
- LHS.getValueType() == RHS.getValueType() &&
- LHS.getValueType().getScalarType() == MVT::i16 &&
- "Unexpected PMADDWD types");
- computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
- break;
- }
- case Intrinsic::x86_ssse3_pmadd_ub_sw_128:
- case Intrinsic::x86_avx2_pmadd_ub_sw:
- case Intrinsic::x86_avx512_pmaddubs_w_512: {
- SDValue LHS = Op.getOperand(1);
- SDValue RHS = Op.getOperand(2);
- assert(VT.getScalarType() == MVT::i16 &&
- LHS.getValueType() == RHS.getValueType() &&
- LHS.getValueType().getScalarType() == MVT::i8 &&
- "Unexpected PMADDUBSW types");
- computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
- break;
- }
case Intrinsic::x86_sse2_psad_bw:
case Intrinsic::x86_avx2_psad_bw:
case Intrinsic::x86_avx512_psad_bw_512: {
diff --git a/llvm/test/CodeGen/X86/combine-pmadd.ll b/llvm/test/CodeGen/X86/combine-pmadd.ll
index 403ee72e9dd98..faa20db5acd4f 100644
--- a/llvm/test/CodeGen/X86/combine-pmadd.ll
+++ b/llvm/test/CodeGen/X86/combine-pmadd.ll
@@ -88,11 +88,21 @@ define <4 x i32> @combine_pmaddwd_demandedelts(<8 x i16> %a0, <8 x i16> %a1) {
ret <4 x i32> %4
}
+; TODO
define i32 @combine_pmaddwd_constant() {
-; CHECK-LABEL: combine_pmaddwd_constant:
-; CHECK: # %bb.0:
-; CHECK-NEXT: movl $-155, %eax
-; CHECK-NEXT: retq
+; SSE-LABEL: combine_pmaddwd_constant:
+; SSE: # %bb.0:
+; SSE-NEXT: pmovsxbw {{.*#+}} xmm0 = [65535,2,3,65532,65531,6,7,65528]
+; SSE-NEXT: pmaddwd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [65531,7,65527,65525,13,65521,17,65517]
+; SSE-NEXT: pextrd $2, %xmm0, %eax
+; SSE-NEXT: retq
+;
+; AVX-LABEL: combine_pmaddwd_constant:
+; AVX: # %bb.0:
+; AVX-NEXT: vpmovsxbw {{.*#+}} xmm0 = [65535,2,3,65532,65531,6,7,65528]
+; AVX-NEXT: vpmaddwd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [65531,7,65527,65525,13,65521,17,65517]
+; AVX-NEXT: vpextrd $2, %xmm0, %eax
+; AVX-NEXT: retq
%1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -9, i16 -11, i16 13, i16 -15, i16 17, i16 -19>)
%2 = extractelement <4 x i32> %1, i32 2 ; (-5*13)+(6*-15) = -155
ret i32 %2
@@ -100,10 +110,26 @@ define i32 @combine_pmaddwd_constant() {
; ensure we don't assume pmaddwd performs add nsw
define i32 @combine_pmaddwd_constant_nsw() {
-; CHECK-LABEL: combine_pmaddwd_constant_nsw:
-; CHECK: # %bb.0:
-; CHECK-NEXT: movl $-2147483648, %eax # imm = 0x80000000
-; CHECK-NEXT: retq
+; SSE-LABEL: combine_pmaddwd_constant_nsw:
+; SSE: # %bb.0:
+; SSE-NEXT: movdqa {{.*#+}} xmm0 = [32768,32768,32768,32768,32768,32768,32768,32768]
+; SSE-NEXT: pmaddwd %xmm0, %xmm0
+; SSE-NEXT: movd %xmm0, %eax
+; SSE-NEXT: retq
+;
+; AVX1-LABEL: combine_pmaddwd_constant_nsw:
+; AVX1: # %bb.0:
+; AVX1-NEXT: vbroadcastss {{.*#+}} xmm0 = [32768,32768,32768,32768,32768,32768,32768,32768]
+; AVX1-NEXT: vpmaddwd %xmm0, %xmm0, %xmm0
+; AVX1-NEXT: vmovd %xmm0, %eax
+; AVX1-NEXT: retq
+;
+; AVX2-LABEL: combine_pmaddwd_constant_nsw:
+; AVX2: # %bb.0:
+; AVX2-NEXT: vpbroadcastw {{.*#+}} xmm0 = [32768,32768,32768,32768,32768,32768,32768,32768]
+; AVX2-NEXT: vpmaddwd %xmm0, %xmm0, %xmm0
+; AVX2-NEXT: vmovd %xmm0, %eax
+; AVX2-NEXT: retq
%1 = insertelement <8 x i16> undef, i16 32768, i32 0
%2 = shufflevector <8 x i16> %1, <8 x i16> undef, <8 x i32> zeroinitializer
%3 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %2, <8 x i16> %2)
@@ -193,25 +219,51 @@ define <8 x i16> @combine_pmaddubsw_demandedelts(<16 x i8> %a0, <16 x i8> %a1) {
ret <8 x i16> %4
}
+; TODO
define i32 @combine_pmaddubsw_constant() {
-; CHECK-LABEL: combine_pmaddubsw_constant:
-; CHECK: # %bb.0:
-; CHECK-NEXT: movl $1694, %eax # imm = 0x69E
-; CHECK-NEXT: retq
+; SSE-LABEL: combine_pmaddubsw_constant:
+; SSE: # %bb.0:
+; SSE-NEXT: movdqa {{.*#+}} xmm0 = [0,1,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
+; SSE-NEXT: pmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [1,2,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
+; SSE-NEXT: pextrw $3, %xmm0, %eax
+; SSE-NEXT: cwtl
+; SSE-NEXT: retq
+;
+; AVX-LABEL: combine_pmaddubsw_constant:
+; AVX: # %bb.0:
+; AVX-NEXT: vmovdqa {{.*#+}} xmm0 = [0,1,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
+; AVX-NEXT: vpmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [1,2,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
+; AVX-NEXT: vpextrw $3, %xmm0, %eax
+; AVX-NEXT: cwtl
+; AVX-NEXT: retq
%1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 0, i8 1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 1, i8 2, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
%2 = extractelement <8 x i16> %1, i32 3 ; ((uint16_t)-6*7)+(7*-8) = (250*7)+(7*-8) = 1694
%3 = sext i16 %2 to i32
ret i32 %3
}
+; TODO
define i32 @combine_pmaddubsw_constant_sat() {
-; CHECK-LABEL: combine_pmaddubsw_constant_sat:
-; CHECK: # %bb.0:
-; CHECK-NEXT: movl $-32768, %eax # imm = 0x8000
-; CHECK-NEXT: retq
+; SSE-LABEL: combine_pmaddubsw_constant_sat:
+; SSE: # %bb.0:
+; SSE-NEXT: movdqa {{.*#+}} xmm0 = [255,255,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
+; SSE-NEXT: pmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [128,128,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
+; SSE-NEXT: movd %xmm0, %eax
+; SSE-NEXT: cwtl
+; SSE-NEXT: retq
+;
+; AVX-LABEL: combine_pmaddubsw_constant_sat:
+; AVX: # %bb.0:
+; AVX-NEXT: vmovdqa {{.*#+}} xmm0 = [255,255,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
+; AVX-NEXT: vpmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [128,128,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
+; AVX-NEXT: vmovd %xmm0, %eax
+; AVX-NEXT: cwtl
+; AVX-NEXT: retq
%1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
%2 = extractelement <8 x i16> %1, i32 0 ; add_sat_i16((uint16_t)-1*-128, (uint16_t)-1*-128) = add_sat_i16(255*-128, 255*-128) = sat_i16(-65280) = -32768
%3 = sext i16 %2 to i32
ret i32 %3
}
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; CHECK: {{.*}}
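[Editor's note] For reference when reading the inline arithmetic comments in the tests above, here is a minimal scalar model of one output lane of each instruction (an illustrative sketch; the helper names are made up and are not LLVM or test code):

  #include <algorithm>
  #include <cstdint>

  // PMADDWD: multiply adjacent pairs of signed i16 elements and add the two
  // sign-extended i32 products. The addition may wrap (no nsw), which is
  // exactly what combine_pmaddwd_constant_nsw guards against assuming.
  int32_t pmaddwd_lane(int16_t A0, int16_t B0, int16_t A1, int16_t B1) {
    int64_t Sum = int64_t(A0) * B0 + int64_t(A1) * B1;
    return static_cast<int32_t>(Sum); // truncates modulo 2^32, like the ISA
  }

  // PMADDUBSW: multiply adjacent pairs of unsigned i8 (LHS) by signed i8
  // (RHS) elements and add the two i16 products with signed saturation.
  int16_t pmaddubsw_lane(uint8_t A0, int8_t B0, uint8_t A1, int8_t B1) {
    int32_t Sum = int32_t(A0) * B0 + int32_t(A1) * B1;
    return int16_t(std::clamp(Sum, -32768, 32767));
  }

  // e.g. pmaddwd_lane(-5, 13, 6, -15) == -155        (combine_pmaddwd_constant)
  // e.g. pmaddubsw_lane(255, -128, 255, -128) == -32768, saturated from
  //      -65280                                      (combine_pmaddubsw_constant_sat)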