[llvm] 87d5bb6 - [X86][SSE] Improve PMADDWD SimplifyDemandedVectorElts handling

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 4 06:01:15 PDT 2021


Author: Simon Pilgrim
Date: 2021-11-04T12:56:31Z
New Revision: 87d5bb66eb84752e4c4400ee8b503169ef456d89

URL: https://github.com/llvm/llvm-project/commit/87d5bb66eb84752e4c4400ee8b503169ef456d89
DIFF: https://github.com/llvm/llvm-project/commit/87d5bb66eb84752e4c4400ee8b503169ef456d89.diff

LOG: [X86][SSE] Improve PMADDWD SimplifyDemandedVectorElts handling

Check both operands for known-zero elements so that the matching elements of the other operand are no longer demanded.

This should help reduce some minor regressions noticed in D110995.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/madd.ll
    llvm/test/CodeGen/X86/shrink_vmul.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index f54a2ae2b4e3d..59ed43dd680d5 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -39742,17 +39742,24 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
     SDValue RHS = Op.getOperand(1);
     APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, 2 * NumElts);
 
-    APInt DemandedRHSElts = DemandedSrcElts;
-    if (SimplifyDemandedVectorElts(RHS, DemandedRHSElts, RHSUndef, RHSZero, TLO,
+    if (SimplifyDemandedVectorElts(LHS, DemandedSrcElts, LHSUndef, LHSZero, TLO,
+                                   Depth + 1))
+      return true;
+    if (SimplifyDemandedVectorElts(RHS, DemandedSrcElts, RHSUndef, RHSZero, TLO,
                                    Depth + 1))
       return true;
 
-    // If RHS elements are known zero then we don't need the LHS equivalent.
+    // TODO: Multiply by zero.
+
+    // If RHS/LHS elements are known zero then we don't need the LHS/RHS equivalent.
     APInt DemandedLHSElts = DemandedSrcElts & ~RHSZero;
     if (SimplifyDemandedVectorElts(LHS, DemandedLHSElts, LHSUndef, LHSZero, TLO,
                                    Depth + 1))
       return true;
-    // TODO: Multiply by zero.
+    APInt DemandedRHSElts = DemandedSrcElts & ~LHSZero;
+    if (SimplifyDemandedVectorElts(RHS, DemandedRHSElts, RHSUndef, RHSZero, TLO,
+                                   Depth + 1))
+      return true;
     break;
   }
   case X86ISD::PSADBW: {
@@ -52448,6 +52455,7 @@ static SDValue combinePMULDQ(SDNode *N, SelectionDAG &DAG,
 // Simplify VPMADDUBSW/VPMADDWD operations.
 static SDValue combineVPMADD(SDNode *N, SelectionDAG &DAG,
                              TargetLowering::DAGCombinerInfo &DCI) {
+  EVT VT = N->getValueType(0);
   SDValue LHS = N->getOperand(0);
   SDValue RHS = N->getOperand(1);
 
@@ -52455,7 +52463,14 @@ static SDValue combineVPMADD(SDNode *N, SelectionDAG &DAG,
   // Don't return LHS/RHS as it may contain UNDEFs.
   if (ISD::isBuildVectorAllZeros(LHS.getNode()) ||
       ISD::isBuildVectorAllZeros(RHS.getNode()))
-    return DAG.getConstant(0, SDLoc(N), N->getValueType(0));
+    return DAG.getConstant(0, SDLoc(N), VT);
+
+  APInt KnownUndef, KnownZero;
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  APInt DemandedElts = APInt::getAllOnes(VT.getVectorNumElements());
+  if (TLI.SimplifyDemandedVectorElts(SDValue(N, 0), DemandedElts, KnownUndef,
+                                     KnownZero, DCI))
+    return SDValue(N, 0);
 
   return SDValue();
 }

diff  --git a/llvm/test/CodeGen/X86/madd.ll b/llvm/test/CodeGen/X86/madd.ll
index a0a3346d7e9fc..23dec421a3c14 100644
--- a/llvm/test/CodeGen/X86/madd.ll
+++ b/llvm/test/CodeGen/X86/madd.ll
@@ -17,10 +17,10 @@ define i32 @_Z10test_shortPsS_i_128(i16* nocapture readonly, i16* nocapture read
 ; SSE2-NEXT:    # =>This Inner Loop Header: Depth=1
 ; SSE2-NEXT:    movq {{.*#+}} xmm2 = mem[0],zero
 ; SSE2-NEXT:    movq {{.*#+}} xmm3 = mem[0],zero
+; SSE2-NEXT:    punpcklwd {{.*#+}} xmm3 = xmm3[0,0,1,1,2,2,3,3]
 ; SSE2-NEXT:    punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm0[0],xmm2[1],xmm0[1],xmm2[2],xmm0[2],xmm2[3],xmm0[3]
-; SSE2-NEXT:    punpcklwd {{.*#+}} xmm3 = xmm3[0],xmm0[0],xmm3[1],xmm0[1],xmm3[2],xmm0[2],xmm3[3],xmm0[3]
-; SSE2-NEXT:    pmaddwd %xmm2, %xmm3
-; SSE2-NEXT:    paddd %xmm3, %xmm1
+; SSE2-NEXT:    pmaddwd %xmm3, %xmm2
+; SSE2-NEXT:    paddd %xmm2, %xmm1
 ; SSE2-NEXT:    addq $8, %rcx
 ; SSE2-NEXT:    cmpq %rcx, %rax
 ; SSE2-NEXT:    jne .LBB0_1
@@ -1859,6 +1859,7 @@ define <4 x i32> @pmaddwd_8_swapped(<8 x i16> %A, <8 x i16> %B) {
    ret <4 x i32> %ret
 }
 
+; FIXME: SSE fails to match PMADDWD
 define <4 x i32> @larger_mul(<16 x i16> %A, <16 x i16> %B) {
 ; SSE2-LABEL: larger_mul:
 ; SSE2:       # %bb.0:

diff  --git a/llvm/test/CodeGen/X86/shrink_vmul.ll b/llvm/test/CodeGen/X86/shrink_vmul.ll
index d8e7f3358b1fc..7557b3fc28440 100644
--- a/llvm/test/CodeGen/X86/shrink_vmul.ll
+++ b/llvm/test/CodeGen/X86/shrink_vmul.ll
@@ -1079,10 +1079,10 @@ define void @mul_2xi16_sext(i8* nocapture readonly %a, i8* nocapture readonly %b
 ; X86-SSE-NEXT:    movd {{.*#+}} xmm0 = mem[0],zero,zero,zero
 ; X86-SSE-NEXT:    movd {{.*#+}} xmm1 = mem[0],zero,zero,zero
 ; X86-SSE-NEXT:    pxor %xmm2, %xmm2
-; X86-SSE-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3]
-; X86-SSE-NEXT:    pshuflw {{.*#+}} xmm1 = xmm1[0,1,1,3,4,5,6,7]
-; X86-SSE-NEXT:    pmaddwd %xmm0, %xmm1
-; X86-SSE-NEXT:    movq %xmm1, (%esi,%ecx,4)
+; X86-SSE-NEXT:    punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm2[0],xmm1[1],xmm2[1],xmm1[2],xmm2[2],xmm1[3],xmm2[3]
+; X86-SSE-NEXT:    pshuflw {{.*#+}} xmm0 = xmm0[0,1,1,3,4,5,6,7]
+; X86-SSE-NEXT:    pmaddwd %xmm1, %xmm0
+; X86-SSE-NEXT:    movq %xmm0, (%esi,%ecx,4)
 ; X86-SSE-NEXT:    popl %esi
 ; X86-SSE-NEXT:    retl
 ;


        


More information about the llvm-commits mailing list