[llvm] fa0e9ac - [X86] Remove PMADDWD/PMADDUBSW known bits handling due to performance issues

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 27 11:10:11 PDT 2024


Author: Simon Pilgrim
Date: 2024-06-27T19:09:54+01:00
New Revision: fa0e9acea5e4d363eef6acc484afc1b22ab8e698

URL: https://github.com/llvm/llvm-project/commit/fa0e9acea5e4d363eef6acc484afc1b22ab8e698
DIFF: https://github.com/llvm/llvm-project/commit/fa0e9acea5e4d363eef6acc484afc1b22ab8e698.diff

LOG: [X86] Remove PMADDWD/PMADDUBSW known bits handling due to performance issues

This appears to be causing a slow (infinite?) loop when building the highway open source project - most likely due to the high number of computeKnownBits calls (although improving early-out doesn't appear to help so far).

I'm reverting support to unstick the highway team and will revisit this shortly.

Reported by @alexfh

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/combine-pmadd.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 3bbf009a1defd..767c58270a4dc 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -37139,52 +37139,6 @@ static void computeKnownBitsForPSADBW(SDValue LHS, SDValue RHS,
   Known = Known.zext(64);
 }
 
-static void computeKnownBitsForPMADDWD(SDValue LHS, SDValue RHS,
-                                       KnownBits &Known,
-                                       const APInt &DemandedElts,
-                                       const SelectionDAG &DAG,
-                                       unsigned Depth) {
-  unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
-
-  // Multiply signed i16 elements to create i32 values and add Lo/Hi pairs.
-  APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
-  APInt DemandedLoElts =
-      DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b01));
-  APInt DemandedHiElts =
-      DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b10));
-  KnownBits LHSLo = DAG.computeKnownBits(LHS, DemandedLoElts, Depth + 1);
-  KnownBits LHSHi = DAG.computeKnownBits(LHS, DemandedHiElts, Depth + 1);
-  KnownBits RHSLo = DAG.computeKnownBits(RHS, DemandedLoElts, Depth + 1);
-  KnownBits RHSHi = DAG.computeKnownBits(RHS, DemandedHiElts, Depth + 1);
-  KnownBits Lo = KnownBits::mul(LHSLo.sext(32), RHSLo.sext(32));
-  KnownBits Hi = KnownBits::mul(LHSHi.sext(32), RHSHi.sext(32));
-  Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
-                                      /*NUW=*/false, Lo, Hi);
-}
-
-static void computeKnownBitsForPMADDUBSW(SDValue LHS, SDValue RHS,
-                                         KnownBits &Known,
-                                         const APInt &DemandedElts,
-                                         const SelectionDAG &DAG,
-                                         unsigned Depth) {
-  unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
-
-  // Multiply unsigned/signed i8 elements to create i16 values and add_sat Lo/Hi
-  // pairs.
-  APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
-  APInt DemandedLoElts =
-      DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b01));
-  APInt DemandedHiElts =
-      DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b10));
-  KnownBits LHSLo = DAG.computeKnownBits(LHS, DemandedLoElts, Depth + 1);
-  KnownBits LHSHi = DAG.computeKnownBits(LHS, DemandedHiElts, Depth + 1);
-  KnownBits RHSLo = DAG.computeKnownBits(RHS, DemandedLoElts, Depth + 1);
-  KnownBits RHSHi = DAG.computeKnownBits(RHS, DemandedHiElts, Depth + 1);
-  KnownBits Lo = KnownBits::mul(LHSLo.zext(16), RHSLo.sext(16));
-  KnownBits Hi = KnownBits::mul(LHSHi.zext(16), RHSHi.sext(16));
-  Known = KnownBits::sadd_sat(Lo, Hi);
-}
-
 void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
                                                       KnownBits &Known,
                                                       const APInt &DemandedElts,
@@ -37360,26 +37314,6 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
     }
     break;
   }
-  case X86ISD::VPMADDWD: {
-    SDValue LHS = Op.getOperand(0);
-    SDValue RHS = Op.getOperand(1);
-    assert(VT.getVectorElementType() == MVT::i32 &&
-           LHS.getValueType() == RHS.getValueType() &&
-           LHS.getValueType().getVectorElementType() == MVT::i16 &&
-           "Unexpected PMADDWD types");
-    computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
-    break;
-  }
-  case X86ISD::VPMADDUBSW: {
-    SDValue LHS = Op.getOperand(0);
-    SDValue RHS = Op.getOperand(1);
-    assert(VT.getVectorElementType() == MVT::i16 &&
-           LHS.getValueType() == RHS.getValueType() &&
-           LHS.getValueType().getVectorElementType() == MVT::i8 &&
-           "Unexpected PMADDUBSW types");
-    computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
-    break;
-  }
   case X86ISD::PMULUDQ: {
     KnownBits Known2;
     Known = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
@@ -37516,30 +37450,6 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
   }
   case ISD::INTRINSIC_WO_CHAIN: {
     switch (Op->getConstantOperandVal(0)) {
-    case Intrinsic::x86_sse2_pmadd_wd:
-    case Intrinsic::x86_avx2_pmadd_wd:
-    case Intrinsic::x86_avx512_pmaddw_d_512: {
-      SDValue LHS = Op.getOperand(1);
-      SDValue RHS = Op.getOperand(2);
-      assert(VT.getScalarType() == MVT::i32 &&
-             LHS.getValueType() == RHS.getValueType() &&
-             LHS.getValueType().getScalarType() == MVT::i16 &&
-             "Unexpected PMADDWD types");
-      computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
-      break;
-    }
-    case Intrinsic::x86_ssse3_pmadd_ub_sw_128:
-    case Intrinsic::x86_avx2_pmadd_ub_sw:
-    case Intrinsic::x86_avx512_pmaddubs_w_512: {
-      SDValue LHS = Op.getOperand(1);
-      SDValue RHS = Op.getOperand(2);
-      assert(VT.getScalarType() == MVT::i16 &&
-             LHS.getValueType() == RHS.getValueType() &&
-             LHS.getValueType().getScalarType() == MVT::i8 &&
-             "Unexpected PMADDUBSW types");
-      computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
-      break;
-    }
     case Intrinsic::x86_sse2_psad_bw:
     case Intrinsic::x86_avx2_psad_bw:
     case Intrinsic::x86_avx512_psad_bw_512: {

diff  --git a/llvm/test/CodeGen/X86/combine-pmadd.ll b/llvm/test/CodeGen/X86/combine-pmadd.ll
index 403ee72e9dd98..faa20db5acd4f 100644
--- a/llvm/test/CodeGen/X86/combine-pmadd.ll
+++ b/llvm/test/CodeGen/X86/combine-pmadd.ll
@@ -88,11 +88,21 @@ define <4 x i32> @combine_pmaddwd_demandedelts(<8 x i16> %a0, <8 x i16> %a1) {
   ret <4 x i32> %4
 }
 
+; TODO
 define i32 @combine_pmaddwd_constant() {
-; CHECK-LABEL: combine_pmaddwd_constant:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    movl $-155, %eax
-; CHECK-NEXT:    retq
+; SSE-LABEL: combine_pmaddwd_constant:
+; SSE:       # %bb.0:
+; SSE-NEXT:    pmovsxbw {{.*#+}} xmm0 = [65535,2,3,65532,65531,6,7,65528]
+; SSE-NEXT:    pmaddwd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [65531,7,65527,65525,13,65521,17,65517]
+; SSE-NEXT:    pextrd $2, %xmm0, %eax
+; SSE-NEXT:    retq
+;
+; AVX-LABEL: combine_pmaddwd_constant:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpmovsxbw {{.*#+}} xmm0 = [65535,2,3,65532,65531,6,7,65528]
+; AVX-NEXT:    vpmaddwd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [65531,7,65527,65525,13,65521,17,65517]
+; AVX-NEXT:    vpextrd $2, %xmm0, %eax
+; AVX-NEXT:    retq
   %1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -9, i16 -11, i16 13, i16 -15, i16 17, i16 -19>)
   %2 = extractelement <4 x i32> %1, i32 2 ; (-5*13)+(6*-15) = -155
   ret i32 %2
@@ -100,10 +110,26 @@ define i32 @combine_pmaddwd_constant() {
 
 ; ensure we don't assume pmaddwd performs add nsw
 define i32 @combine_pmaddwd_constant_nsw() {
-; CHECK-LABEL: combine_pmaddwd_constant_nsw:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    movl $-2147483648, %eax # imm = 0x80000000
-; CHECK-NEXT:    retq
+; SSE-LABEL: combine_pmaddwd_constant_nsw:
+; SSE:       # %bb.0:
+; SSE-NEXT:    movdqa {{.*#+}} xmm0 = [32768,32768,32768,32768,32768,32768,32768,32768]
+; SSE-NEXT:    pmaddwd %xmm0, %xmm0
+; SSE-NEXT:    movd %xmm0, %eax
+; SSE-NEXT:    retq
+;
+; AVX1-LABEL: combine_pmaddwd_constant_nsw:
+; AVX1:       # %bb.0:
+; AVX1-NEXT:    vbroadcastss {{.*#+}} xmm0 = [32768,32768,32768,32768,32768,32768,32768,32768]
+; AVX1-NEXT:    vpmaddwd %xmm0, %xmm0, %xmm0
+; AVX1-NEXT:    vmovd %xmm0, %eax
+; AVX1-NEXT:    retq
+;
+; AVX2-LABEL: combine_pmaddwd_constant_nsw:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpbroadcastw {{.*#+}} xmm0 = [32768,32768,32768,32768,32768,32768,32768,32768]
+; AVX2-NEXT:    vpmaddwd %xmm0, %xmm0, %xmm0
+; AVX2-NEXT:    vmovd %xmm0, %eax
+; AVX2-NEXT:    retq
   %1 = insertelement <8 x i16> undef, i16 32768, i32 0
   %2 = shufflevector <8 x i16> %1, <8 x i16> undef, <8 x i32> zeroinitializer
   %3 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %2, <8 x i16> %2)
@@ -193,25 +219,51 @@ define <8 x i16> @combine_pmaddubsw_demandedelts(<16 x i8> %a0, <16 x i8> %a1) {
   ret <8 x i16> %4
 }
 
+; TODO
 define i32 @combine_pmaddubsw_constant() {
-; CHECK-LABEL: combine_pmaddubsw_constant:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    movl $1694, %eax # imm = 0x69E
-; CHECK-NEXT:    retq
+; SSE-LABEL: combine_pmaddubsw_constant:
+; SSE:       # %bb.0:
+; SSE-NEXT:    movdqa {{.*#+}} xmm0 = [0,1,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
+; SSE-NEXT:    pmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [1,2,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
+; SSE-NEXT:    pextrw $3, %xmm0, %eax
+; SSE-NEXT:    cwtl
+; SSE-NEXT:    retq
+;
+; AVX-LABEL: combine_pmaddubsw_constant:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vmovdqa {{.*#+}} xmm0 = [0,1,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
+; AVX-NEXT:    vpmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [1,2,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
+; AVX-NEXT:    vpextrw $3, %xmm0, %eax
+; AVX-NEXT:    cwtl
+; AVX-NEXT:    retq
   %1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 0, i8 1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 1, i8 2, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
   %2 = extractelement <8 x i16> %1, i32 3 ; ((uint16_t)-6*7)+(7*-8) = (250*7)+(7*-8) = 1694
   %3 = sext i16 %2 to i32
   ret i32 %3
 }
 
+; TODO
 define i32 @combine_pmaddubsw_constant_sat() {
-; CHECK-LABEL: combine_pmaddubsw_constant_sat:
-; CHECK:       # %bb.0:
-; CHECK-NEXT:    movl $-32768, %eax # imm = 0x8000
-; CHECK-NEXT:    retq
+; SSE-LABEL: combine_pmaddubsw_constant_sat:
+; SSE:       # %bb.0:
+; SSE-NEXT:    movdqa {{.*#+}} xmm0 = [255,255,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
+; SSE-NEXT:    pmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [128,128,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
+; SSE-NEXT:    movd %xmm0, %eax
+; SSE-NEXT:    cwtl
+; SSE-NEXT:    retq
+;
+; AVX-LABEL: combine_pmaddubsw_constant_sat:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vmovdqa {{.*#+}} xmm0 = [255,255,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
+; AVX-NEXT:    vpmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [128,128,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
+; AVX-NEXT:    vmovd %xmm0, %eax
+; AVX-NEXT:    cwtl
+; AVX-NEXT:    retq
   %1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
   %2 = extractelement <8 x i16> %1, i32 0 ; add_sat_i16(((uint16_t)-1*-128),((uint16_t)-1*-128)_ = add_sat_i16(255*-128),(255*-128)) = sat_i16(-65280) = -32768
   %3 = sext i16 %2 to i32
   ret i32 %3
 }
 
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; CHECK: {{.*}}


        


More information about the llvm-commits mailing list