[llvm] 17c567b - [X86] combineVPMADD - add constant folding support for PMADDWD/PMADDUBSW instructions

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 28 05:27:38 PDT 2024


Author: Simon Pilgrim
Date: 2024-06-28T13:26:43+01:00
New Revision: 17c567b095ab749b59d311ec9d8cd2bae584ac0b

URL: https://github.com/llvm/llvm-project/commit/17c567b095ab749b59d311ec9d8cd2bae584ac0b
DIFF: https://github.com/llvm/llvm-project/commit/17c567b095ab749b59d311ec9d8cd2bae584ac0b.diff

LOG: [X86] combineVPMADD - add constant folding support for PMADDWD/PMADDUBSW instructions
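
For context, PMADDWD multiplies corresponding signed 16-bit elements and adds
each adjacent pair of products into a 32-bit lane with a plain wrapping add;
PMADDUBSW multiplies unsigned 8-bit elements of the first operand by signed
8-bit elements of the second and adds each adjacent pair into a signed 16-bit
lane with signed saturation. A minimal standalone C++ sketch of the per-pair
semantics the fold implements (illustrative only -- the helper names below are
not part of the patch, which works on APInt values):

    #include <cstdint>

    // PMADDWD pair: sext(a0)*sext(b0) + sext(a1)*sext(b1), wrapping i32 add.
    int32_t pmaddwd_pair(int16_t A0, int16_t A1, int16_t B0, int16_t B1) {
      // Each product fits in i32; do the add in uint32_t so the wrap the
      // hardware performs is well defined in C++ too (the conversion back
      // to int32_t is two's complement, guaranteed since C++20).
      uint32_t Sum = uint32_t(int32_t(A0) * int32_t(B0)) +
                     uint32_t(int32_t(A1) * int32_t(B1));
      return int32_t(Sum);
    }

    // PMADDUBSW pair: zext(a0)*sext(b0) + zext(a1)*sext(b1), saturating i16 add.
    int16_t pmaddubsw_pair(uint8_t A0, uint8_t A1, int8_t B0, int8_t B1) {
      int32_t Sum = int32_t(A0) * int32_t(B0) + int32_t(A1) * int32_t(B1);
      if (Sum > INT16_MAX) return INT16_MAX;
      if (Sum < INT16_MIN) return INT16_MIN;
      return int16_t(Sum);
    }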

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/combine-pmadd.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 767c58270a4dc..348c2b56e6e3c 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56922,9 +56922,13 @@ static SDValue combinePMULDQ(SDNode *N, SelectionDAG &DAG,
 // Simplify VPMADDUBSW/VPMADDWD operations.
 static SDValue combineVPMADD(SDNode *N, SelectionDAG &DAG,
                              TargetLowering::DAGCombinerInfo &DCI) {
-  EVT VT = N->getValueType(0);
+  MVT VT = N->getSimpleValueType(0);
   SDValue LHS = N->getOperand(0);
   SDValue RHS = N->getOperand(1);
+  unsigned Opc = N->getOpcode();
+  bool IsPMADDWD = Opc == X86ISD::VPMADDWD;
+  assert((Opc == X86ISD::VPMADDWD || Opc == X86ISD::VPMADDUBSW) &&
+         "Unexpected PMADD opcode");
 
   // Multiply by zero.
   // Don't return LHS/RHS as it may contain UNDEFs.
@@ -56932,6 +56936,27 @@ static SDValue combineVPMADD(SDNode *N, SelectionDAG &DAG,
       ISD::isBuildVectorAllZeros(RHS.getNode()))
     return DAG.getConstant(0, SDLoc(N), VT);
 
+  // Constant folding.
+  APInt LHSUndefs, RHSUndefs;
+  SmallVector<APInt> LHSBits, RHSBits;
+  unsigned SrcEltBits = LHS.getScalarValueSizeInBits();
+  unsigned DstEltBits = VT.getScalarSizeInBits();
+  if (getTargetConstantBitsFromNode(LHS, SrcEltBits, LHSUndefs, LHSBits) &&
+      getTargetConstantBitsFromNode(RHS, SrcEltBits, RHSUndefs, RHSBits)) {
+    SmallVector<APInt> Result;
+    for (unsigned I = 0, E = LHSBits.size(); I != E; I += 2) {
+      APInt LHSLo = LHSBits[I + 0], LHSHi = LHSBits[I + 1];
+      APInt RHSLo = RHSBits[I + 0], RHSHi = RHSBits[I + 1];
+      LHSLo = IsPMADDWD ? LHSLo.sext(DstEltBits) : LHSLo.zext(DstEltBits);
+      LHSHi = IsPMADDWD ? LHSHi.sext(DstEltBits) : LHSHi.zext(DstEltBits);
+      APInt Lo = LHSLo * RHSLo.sext(DstEltBits);
+      APInt Hi = LHSHi * RHSHi.sext(DstEltBits);
+      APInt Res = IsPMADDWD ? (Lo + Hi) : Lo.sadd_sat(Hi);
+      Result.push_back(Res);
+    }
+    return getConstVector(Result, VT, DAG, SDLoc(N));
+  }
+
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   APInt DemandedElts = APInt::getAllOnes(VT.getVectorNumElements());
   if (TLI.SimplifyDemandedVectorElts(SDValue(N, 0), DemandedElts, DCI))

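The Lo + Hi vs. Lo.sadd_sat(Hi) distinction above is exactly what the updated
tests pin down. Two worked lanes, taken from the test values below:

    combine_pmaddwd_constant_nsw [0]:
      (-32768*-32768) + (-32768*-32768) = 0x40000000 + 0x40000000 = 0x80000000
      (a wrapping i32 add -- the fold must not assume nsw)

    combine_pmaddubsw_constant_sat [0]:
      (255*-128) + (255*-128) = -65280, which sadd_sat clamps to i16 min -32768
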
diff --git a/llvm/test/CodeGen/X86/combine-pmadd.ll b/llvm/test/CodeGen/X86/combine-pmadd.ll
index c4c3e2870145e..c5f9c8506f291 100644
--- a/llvm/test/CodeGen/X86/combine-pmadd.ll
+++ b/llvm/test/CodeGen/X86/combine-pmadd.ll
@@ -88,43 +88,33 @@ define <4 x i32> @combine_pmaddwd_demandedelts(<8 x i16> %a0, <8 x i16> %a1) {
   ret <4 x i32> %4
 }
 
-; TODO: [2] = (-5*13)+(6*-15) = -155 = 4294967141
+; [2]: (-5*13)+(6*-15) = -155 = 4294967141
 define <4 x i32> @combine_pmaddwd_constant() {
 ; SSE-LABEL: combine_pmaddwd_constant:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    pmovsxbw {{.*#+}} xmm0 = [65535,2,3,65532,65531,6,7,65528]
-; SSE-NEXT:    pmaddwd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [65531,7,65527,65525,13,65521,17,65517]
+; SSE-NEXT:    movaps {{.*#+}} xmm0 = [19,17,4294967141,271]
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_pmaddwd_constant:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vpmovsxbw {{.*#+}} xmm0 = [65535,2,3,65532,65531,6,7,65528]
-; AVX-NEXT:    vpmaddwd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [65531,7,65527,65525,13,65521,17,65517]
+; AVX-NEXT:    vmovaps {{.*#+}} xmm0 = [19,17,4294967141,271]
 ; AVX-NEXT:    retq
   %1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -9, i16 -11, i16 13, i16 -15, i16 17, i16 -19>)
   ret <4 x i32> %1
 }
 
 ; ensure we don't assume pmaddwd performs add nsw
-; TODO: (-32768*-32768)+(-32768*-32768) = 0x80000000 = 2147483648
+; [0]: (-32768*-32768)+(-32768*-32768) = 0x80000000 = 2147483648
 define <4 x i32> @combine_pmaddwd_constant_nsw() {
 ; SSE-LABEL: combine_pmaddwd_constant_nsw:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    movdqa {{.*#+}} xmm0 = [32768,32768,32768,32768,32768,32768,32768,32768]
-; SSE-NEXT:    pmaddwd %xmm0, %xmm0
+; SSE-NEXT:    movaps {{.*#+}} xmm0 = [2147483648,2147483648,2147483648,2147483648]
 ; SSE-NEXT:    retq
 ;
-; AVX1-LABEL: combine_pmaddwd_constant_nsw:
-; AVX1:       # %bb.0:
-; AVX1-NEXT:    vbroadcastss {{.*#+}} xmm0 = [32768,32768,32768,32768,32768,32768,32768,32768]
-; AVX1-NEXT:    vpmaddwd %xmm0, %xmm0, %xmm0
-; AVX1-NEXT:    retq
-;
-; AVX2-LABEL: combine_pmaddwd_constant_nsw:
-; AVX2:       # %bb.0:
-; AVX2-NEXT:    vpbroadcastw {{.*#+}} xmm0 = [32768,32768,32768,32768,32768,32768,32768,32768]
-; AVX2-NEXT:    vpmaddwd %xmm0, %xmm0, %xmm0
-; AVX2-NEXT:    retq
+; AVX-LABEL: combine_pmaddwd_constant_nsw:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vbroadcastss {{.*#+}} xmm0 = [2147483648,2147483648,2147483648,2147483648]
+; AVX-NEXT:    retq
   %1 = insertelement <8 x i16> undef, i16 32768, i32 0
   %2 = shufflevector <8 x i16> %1, <8 x i16> undef, <8 x i32> zeroinitializer
   %3 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %2, <8 x i16> %2)
@@ -213,51 +203,26 @@ define <8 x i16> @combine_pmaddubsw_demandedelts(<16 x i8> %a0, <16 x i8> %a1) {
   ret <8 x i16> %4
 }
 
-; TODO
+; [3]: ((uint8_t)-6*7)+(7*-8) = (250*7)+(7*-8) = 1750-56 = 1694
 define i32 @combine_pmaddubsw_constant() {
-; SSE-LABEL: combine_pmaddubsw_constant:
-; SSE:       # %bb.0:
-; SSE-NEXT:    movdqa {{.*#+}} xmm0 = [0,1,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
-; SSE-NEXT:    pmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [1,2,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
-; SSE-NEXT:    pextrw $3, %xmm0, %eax
-; SSE-NEXT:    cwtl
-; SSE-NEXT:    retq
-;
-; AVX-LABEL: combine_pmaddubsw_constant:
-; AVX:       # %bb.0:
-; AVX-NEXT:    vmovdqa {{.*#+}} xmm0 = [0,1,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
-; AVX-NEXT:    vpmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [1,2,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
-; AVX-NEXT:    vpextrw $3, %xmm0, %eax
-; AVX-NEXT:    cwtl
-; AVX-NEXT:    retq
+; CHECK-LABEL: combine_pmaddubsw_constant:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movl $1694, %eax # imm = 0x69E
+; CHECK-NEXT:    retq
   %1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 0, i8 1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 1, i8 2, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
-  %2 = extractelement <8 x i16> %1, i32 3 ; ((uint16_t)-6*7)+(7*-8) = (250*7)+(7*-8) = 1694
+  %2 = extractelement <8 x i16> %1, i32 3
   %3 = sext i16 %2 to i32
   ret i32 %3
 }
 
-; TODO
+; [0]: add_sat_i16((uint8_t)-1*-128, (uint8_t)-1*-128) = add_sat_i16(255*-128, 255*-128) = sat_i16(-65280) = -32768
 define i32 @combine_pmaddubsw_constant_sat() {
-; SSE-LABEL: combine_pmaddubsw_constant_sat:
-; SSE:       # %bb.0:
-; SSE-NEXT:    movdqa {{.*#+}} xmm0 = [255,255,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
-; SSE-NEXT:    pmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [128,128,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
-; SSE-NEXT:    movd %xmm0, %eax
-; SSE-NEXT:    cwtl
-; SSE-NEXT:    retq
-;
-; AVX-LABEL: combine_pmaddubsw_constant_sat:
-; AVX:       # %bb.0:
-; AVX-NEXT:    vmovdqa {{.*#+}} xmm0 = [255,255,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
-; AVX-NEXT:    vpmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [128,128,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
-; AVX-NEXT:    vmovd %xmm0, %eax
-; AVX-NEXT:    cwtl
-; AVX-NEXT:    retq
+; CHECK-LABEL: combine_pmaddubsw_constant_sat:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movl $-32768, %eax # imm = 0x8000
+; CHECK-NEXT:    retq
   %1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
-  %2 = extractelement <8 x i16> %1, i32 0 ; add_sat_i16(((uint16_t)-1*-128),((uint16_t)-1*-128)_ = add_sat_i16(255*-128),(255*-128)) = sat_i16(-65280) = -32768
+  %2 = extractelement <8 x i16> %1, i32 0
   %3 = sext i16 %2 to i32
   ret i32 %3
 }
-
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; CHECK: {{.*}}
