[llvm] d8fc9f8 - [X86][SSE] combineMulToPMADDWD - replace sext(v8i16) -> zext(v8i16)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 24 08:42:38 PDT 2021


Author: Simon Pilgrim
Date: 2021-09-24T16:42:01+01:00
New Revision: d8fc9f87270146e271eddd551ea98580bef15e82

URL: https://github.com/llvm/llvm-project/commit/d8fc9f87270146e271eddd551ea98580bef15e82
DIFF: https://github.com/llvm/llvm-project/commit/d8fc9f87270146e271eddd551ea98580bef15e82.diff

LOG: [X86][SSE] combineMulToPMADDWD - replace sext(v8i16) -> zext(v8i16)

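As suggested on D108522, if we're sign-extending a v4i16 source before multiplying as a v4i32, then we can replace that with a zero-extension and rely on the implicit sign-extension of PMADDWD. PMADDWD multiplies each pair of 16-bit halves as signed values and adds the two products. With zero-extended inputs, each upper half is zero and contributes nothing, while each lower half is still multiplied as a signed i16, so the result matches the original sext-based v4i32 multiply.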

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/madd.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c1e8531207c6..1e79620e1904 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44258,10 +44258,29 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
   if (DAG.ComputeNumSignBits(N1) < 17 || DAG.ComputeNumSignBits(N0) < 17)
     return SDValue();
 
-  // At least one of the elements must be zero in the upper 17 bits.
-  APInt Mask17 = APInt::getHighBitsSet(32, 17);
-  if (!DAG.MaskedValueIsZero(N1, Mask17) && !DAG.MaskedValueIsZero(N0, Mask17))
+  // At least one of the elements must be zero in the upper 17 bits, or can be
+  // safely made zero without altering the final result.
+  auto GetZeroableOp = [&](SDValue Op) {
+    APInt Mask17 = APInt::getHighBitsSet(32, 17);
+    if (DAG.MaskedValueIsZero(Op, Mask17))
+      return Op;
+    // Convert sext(vXi16) to zext(vXi16).
+    // TODO: Enable pre-SSE41 once we can prefer MULHU/MULHS first.
+    // TODO: Handle sext from smaller types as well?
+    if (Op.getOpcode() == ISD::SIGN_EXTEND && VT.is128BitVector() &&
+        Subtarget.hasSSE41() && N->isOnlyUserOf(Op.getNode())) {
+      SDValue Src = Op.getOperand(0);
+      if (Src.getScalarValueSizeInBits() == 16)
+        return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, Src);
+    }
+    return SDValue();
+  };
+  SDValue ZeroN0 = GetZeroableOp(N0);
+  SDValue ZeroN1 = GetZeroableOp(N1);
+  if (!ZeroN0 && !ZeroN1)
     return SDValue();
+  N0 = ZeroN0 ? ZeroN0 : N0;
+  N1 = ZeroN1 ? ZeroN1 : N1;
 
   // Use SplitOpsAndApply to handle AVX splitting.
   auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,

diff  --git a/llvm/test/CodeGen/X86/madd.ll b/llvm/test/CodeGen/X86/madd.ll
index 6b83d2fc5e3b..330ca9061867 100644
--- a/llvm/test/CodeGen/X86/madd.ll
+++ b/llvm/test/CodeGen/X86/madd.ll
@@ -40,9 +40,9 @@ define i32 @_Z10test_shortPsS_i_128(i16* nocapture readonly, i16* nocapture read
 ; AVX-NEXT:    .p2align 4, 0x90
 ; AVX-NEXT:  .LBB0_1: # %vector.body
 ; AVX-NEXT:    # =>This Inner Loop Header: Depth=1
-; AVX-NEXT:    vpmovsxwd (%rdi,%rcx,2), %xmm1
-; AVX-NEXT:    vpmovsxwd (%rsi,%rcx,2), %xmm2
-; AVX-NEXT:    vpmulld %xmm1, %xmm2, %xmm1
+; AVX-NEXT:    vpmovzxwd {{.*#+}} xmm1 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero
+; AVX-NEXT:    vpmovzxwd {{.*#+}} xmm2 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero
+; AVX-NEXT:    vpmaddwd %xmm1, %xmm2, %xmm1
 ; AVX-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
 ; AVX-NEXT:    addq $8, %rcx
 ; AVX-NEXT:    cmpq %rcx, %rax
@@ -2603,16 +2603,13 @@ define <4 x i32> @pmaddwd_bad_indices(<8 x i16>* %Aptr, <8 x i16>* %Bptr) {
 ; AVX:       # %bb.0:
 ; AVX-NEXT:    vmovdqa (%rdi), %xmm0
 ; AVX-NEXT:    vmovdqa (%rsi), %xmm1
-; AVX-NEXT:    vpshufb {{.*#+}} xmm2 = xmm0[2,3,4,5,10,11,12,13,u,u,u,u,u,u,u,u]
-; AVX-NEXT:    vpshufb {{.*#+}} xmm0 = xmm0[0,1,6,7,8,9,14,15,u,u,u,u,u,u,u,u]
-; AVX-NEXT:    vpshufb {{.*#+}} xmm3 = xmm1[0,1,4,5,8,9,12,13,u,u,u,u,u,u,u,u]
-; AVX-NEXT:    vpshufb {{.*#+}} xmm1 = xmm1[2,3,6,7,10,11,14,15,u,u,u,u,u,u,u,u]
-; AVX-NEXT:    vpmovsxwd %xmm2, %xmm2
-; AVX-NEXT:    vpmovsxwd %xmm3, %xmm3
-; AVX-NEXT:    vpmulld %xmm3, %xmm2, %xmm2
-; AVX-NEXT:    vpmovsxwd %xmm0, %xmm0
-; AVX-NEXT:    vpmovsxwd %xmm1, %xmm1
-; AVX-NEXT:    vpmulld %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX-NEXT:    vpblendw {{.*#+}} xmm2 = xmm1[0],xmm2[1],xmm1[2],xmm2[3],xmm1[4],xmm2[5],xmm1[6],xmm2[7]
+; AVX-NEXT:    vpshufb {{.*#+}} xmm3 = xmm0[2,3],zero,zero,xmm0[4,5],zero,zero,xmm0[10,11],zero,zero,xmm0[12,13],zero,zero
+; AVX-NEXT:    vpmaddwd %xmm2, %xmm3, %xmm2
+; AVX-NEXT:    vpsrld $16, %xmm1, %xmm1
+; AVX-NEXT:    vpshufb {{.*#+}} xmm0 = xmm0[0,1],zero,zero,xmm0[6,7],zero,zero,xmm0[8,9],zero,zero,xmm0[14,15],zero,zero
+; AVX-NEXT:    vpmaddwd %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
 ; AVX-NEXT:    retq
   %A = load <8 x i16>, <8 x i16>* %Aptr
@@ -3105,38 +3102,28 @@ define <4 x i32> @output_size_mismatch_high_subvector(<16 x i16> %x, <16 x i16>
 ;
 ; AVX1-LABEL: output_size_mismatch_high_subvector:
 ; AVX1:       # %bb.0:
-; AVX1-NEXT:    vmovdqa {{.*#+}} xmm2 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
 ; AVX1-NEXT:    vextractf128 $1, %ymm0, %xmm0
-; AVX1-NEXT:    vpshufb %xmm2, %xmm0, %xmm3
-; AVX1-NEXT:    vmovdqa {{.*#+}} xmm4 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
-; AVX1-NEXT:    vpshufb %xmm4, %xmm0, %xmm0
-; AVX1-NEXT:    vpshufb %xmm2, %xmm1, %xmm2
-; AVX1-NEXT:    vpshufb %xmm4, %xmm1, %xmm1
-; AVX1-NEXT:    vpmovsxwd %xmm3, %xmm3
-; AVX1-NEXT:    vpmovsxwd %xmm0, %xmm0
-; AVX1-NEXT:    vpmovsxwd %xmm2, %xmm2
-; AVX1-NEXT:    vpmulld %xmm2, %xmm3, %xmm2
-; AVX1-NEXT:    vpmovsxwd %xmm1, %xmm1
-; AVX1-NEXT:    vpmulld %xmm1, %xmm0, %xmm0
+; AVX1-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX1-NEXT:    vpblendw {{.*#+}} xmm3 = xmm1[0],xmm2[1],xmm1[2],xmm2[3],xmm1[4],xmm2[5],xmm1[6],xmm2[7]
+; AVX1-NEXT:    vpblendw {{.*#+}} xmm2 = xmm0[0],xmm2[1],xmm0[2],xmm2[3],xmm0[4],xmm2[5],xmm0[6],xmm2[7]
+; AVX1-NEXT:    vpmaddwd %xmm3, %xmm2, %xmm2
+; AVX1-NEXT:    vpsrld $16, %xmm1, %xmm1
+; AVX1-NEXT:    vpsrld $16, %xmm0, %xmm0
+; AVX1-NEXT:    vpmaddwd %xmm1, %xmm0, %xmm0
 ; AVX1-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
 ; AVX1-NEXT:    vzeroupper
 ; AVX1-NEXT:    retq
 ;
 ; AVX256-LABEL: output_size_mismatch_high_subvector:
 ; AVX256:       # %bb.0:
-; AVX256-NEXT:    vmovdqa {{.*#+}} xmm2 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
 ; AVX256-NEXT:    vextracti128 $1, %ymm0, %xmm0
-; AVX256-NEXT:    vpshufb %xmm2, %xmm0, %xmm3
-; AVX256-NEXT:    vmovdqa {{.*#+}} xmm4 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
-; AVX256-NEXT:    vpshufb %xmm4, %xmm0, %xmm0
-; AVX256-NEXT:    vpshufb %xmm2, %xmm1, %xmm2
-; AVX256-NEXT:    vpshufb %xmm4, %xmm1, %xmm1
-; AVX256-NEXT:    vpmovsxwd %xmm3, %xmm3
-; AVX256-NEXT:    vpmovsxwd %xmm0, %xmm0
-; AVX256-NEXT:    vpmovsxwd %xmm2, %xmm2
-; AVX256-NEXT:    vpmulld %xmm2, %xmm3, %xmm2
-; AVX256-NEXT:    vpmovsxwd %xmm1, %xmm1
-; AVX256-NEXT:    vpmulld %xmm1, %xmm0, %xmm0
+; AVX256-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX256-NEXT:    vpblendw {{.*#+}} xmm3 = xmm1[0],xmm2[1],xmm1[2],xmm2[3],xmm1[4],xmm2[5],xmm1[6],xmm2[7]
+; AVX256-NEXT:    vpblendw {{.*#+}} xmm2 = xmm0[0],xmm2[1],xmm0[2],xmm2[3],xmm0[4],xmm2[5],xmm0[6],xmm2[7]
+; AVX256-NEXT:    vpmaddwd %xmm3, %xmm2, %xmm2
+; AVX256-NEXT:    vpsrld $16, %xmm1, %xmm1
+; AVX256-NEXT:    vpsrld $16, %xmm0, %xmm0
+; AVX256-NEXT:    vpmaddwd %xmm1, %xmm0, %xmm0
 ; AVX256-NEXT:    vpaddd %xmm0, %xmm2, %xmm0
 ; AVX256-NEXT:    vzeroupper
 ; AVX256-NEXT:    retq


        


More information about the llvm-commits mailing list