[llvm] d8fc9f8 - [X86][SSE] combineMulToPMADDWD - replace sext(v8i16) -> zext(v8i16)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Fri Sep 24 08:42:38 PDT 2021
Author: Simon Pilgrim
Date: 2021-09-24T16:42:01+01:00
New Revision: d8fc9f87270146e271eddd551ea98580bef15e82
URL: https://github.com/llvm/llvm-project/commit/d8fc9f87270146e271eddd551ea98580bef15e82
DIFF: https://github.com/llvm/llvm-project/commit/d8fc9f87270146e271eddd551ea98580bef15e82.diff
LOG: [X86][SSE] combineMulToPMADDWD - replace sext(v8i16) -> zext(v8i16)
As suggested on D108522, if we're sign extending a v4i16 source before multiplying as a v4i32, then we can replace that with a zero extension and rely on the implicit sign-extension of PMADDWD.
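For illustration, here is a minimal IR sketch (the function name is mine, not taken from the commit's tests) of the kind of pattern this enables on SSE4.1+ targets: both operands of a 128-bit vector multiply are sign-extended from i16 and have no other users, so GetZeroableOp can rewrite the sexts as zexts, and PMADDWD's implicit signed 16-bit multiply still yields the full sext(a)*sext(b) product (the zeroed upper halves contribute nothing to each horizontal add).

define <4 x i32> @mul_sext_v4i16(<4 x i16> %a, <4 x i16> %b) {
  ; Each sext guarantees >= 17 sign bits, and the mul is the sexts' only
  ; user, so both operands qualify for the sext -> zext rewrite.
  %sa = sext <4 x i16> %a to <4 x i32>
  %sb = sext <4 x i16> %b to <4 x i32>
  %m = mul <4 x i32> %sa, %sb
  ret <4 x i32> %m
}

As the updated madd.ll checks below show, the equivalent load-fed pattern now selects vpmovzxwd + vpmaddwd rather than vpmovsxwd + vpmulld.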
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/madd.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c1e8531207c6..1e79620e1904 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44258,10 +44258,29 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
   if (DAG.ComputeNumSignBits(N1) < 17 || DAG.ComputeNumSignBits(N0) < 17)
     return SDValue();
 
-  // At least one of the elements must be zero in the upper 17 bits.
-  APInt Mask17 = APInt::getHighBitsSet(32, 17);
-  if (!DAG.MaskedValueIsZero(N1, Mask17) && !DAG.MaskedValueIsZero(N0, Mask17))
+  // At least one of the elements must be zero in the upper 17 bits, or can be
+  // safely made zero without altering the final result.
+  auto GetZeroableOp = [&](SDValue Op) {
+    APInt Mask17 = APInt::getHighBitsSet(32, 17);
+    if (DAG.MaskedValueIsZero(Op, Mask17))
+      return Op;
+    // Convert sext(vXi16) to zext(vXi16).
+    // TODO: Enable pre-SSE41 once we can prefer MULHU/MULHS first.
+    // TODO: Handle sext from smaller types as well?
+    if (Op.getOpcode() == ISD::SIGN_EXTEND && VT.is128BitVector() &&
+        Subtarget.hasSSE41() && N->isOnlyUserOf(Op.getNode())) {
+      SDValue Src = Op.getOperand(0);
+      if (Src.getScalarValueSizeInBits() == 16)
+        return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), VT, Src);
+    }
+    return SDValue();
+  };
+  SDValue ZeroN0 = GetZeroableOp(N0);
+  SDValue ZeroN1 = GetZeroableOp(N1);
+  if (!ZeroN0 && !ZeroN1)
     return SDValue();
+  N0 = ZeroN0 ? ZeroN0 : N0;
+  N1 = ZeroN1 ? ZeroN1 : N1;
 
   // Use SplitOpsAndApply to handle AVX splitting.
   auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
diff --git a/llvm/test/CodeGen/X86/madd.ll b/llvm/test/CodeGen/X86/madd.ll
index 6b83d2fc5e3b..330ca9061867 100644
--- a/llvm/test/CodeGen/X86/madd.ll
+++ b/llvm/test/CodeGen/X86/madd.ll
@@ -40,9 +40,9 @@ define i32 @_Z10test_shortPsS_i_128(i16* nocapture readonly, i16* nocapture read
; AVX-NEXT: .p2align 4, 0x90
; AVX-NEXT: .LBB0_1: # %vector.body
; AVX-NEXT: # =>This Inner Loop Header: Depth=1
-; AVX-NEXT: vpmovsxwd (%rdi,%rcx,2), %xmm1
-; AVX-NEXT: vpmovsxwd (%rsi,%rcx,2), %xmm2
-; AVX-NEXT: vpmulld %xmm1, %xmm2, %xmm1
+; AVX-NEXT: vpmovzxwd {{.*#+}} xmm1 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero
+; AVX-NEXT: vpmovzxwd {{.*#+}} xmm2 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero
+; AVX-NEXT: vpmaddwd %xmm1, %xmm2, %xmm1
; AVX-NEXT: vpaddd %xmm0, %xmm1, %xmm0
; AVX-NEXT: addq $8, %rcx
; AVX-NEXT: cmpq %rcx, %rax
@@ -2603,16 +2603,13 @@ define <4 x i32> @pmaddwd_bad_indices(<8 x i16>* %Aptr, <8 x i16>* %Bptr) {
; AVX: # %bb.0:
; AVX-NEXT: vmovdqa (%rdi), %xmm0
; AVX-NEXT: vmovdqa (%rsi), %xmm1
-; AVX-NEXT: vpshufb {{.*#+}} xmm2 = xmm0[2,3,4,5,10,11,12,13,u,u,u,u,u,u,u,u]
-; AVX-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[0,1,6,7,8,9,14,15,u,u,u,u,u,u,u,u]
-; AVX-NEXT: vpshufb {{.*#+}} xmm3 = xmm1[0,1,4,5,8,9,12,13,u,u,u,u,u,u,u,u]
-; AVX-NEXT: vpshufb {{.*#+}} xmm1 = xmm1[2,3,6,7,10,11,14,15,u,u,u,u,u,u,u,u]
-; AVX-NEXT: vpmovsxwd %xmm2, %xmm2
-; AVX-NEXT: vpmovsxwd %xmm3, %xmm3
-; AVX-NEXT: vpmulld %xmm3, %xmm2, %xmm2
-; AVX-NEXT: vpmovsxwd %xmm0, %xmm0
-; AVX-NEXT: vpmovsxwd %xmm1, %xmm1
-; AVX-NEXT: vpmulld %xmm1, %xmm0, %xmm0
+; AVX-NEXT: vpxor %xmm2, %xmm2, %xmm2
+; AVX-NEXT: vpblendw {{.*#+}} xmm2 = xmm1[0],xmm2[1],xmm1[2],xmm2[3],xmm1[4],xmm2[5],xmm1[6],xmm2[7]
+; AVX-NEXT: vpshufb {{.*#+}} xmm3 = xmm0[2,3],zero,zero,xmm0[4,5],zero,zero,xmm0[10,11],zero,zero,xmm0[12,13],zero,zero
+; AVX-NEXT: vpmaddwd %xmm2, %xmm3, %xmm2
+; AVX-NEXT: vpsrld $16, %xmm1, %xmm1
+; AVX-NEXT: vpshufb {{.*#+}} xmm0 = xmm0[0,1],zero,zero,xmm0[6,7],zero,zero,xmm0[8,9],zero,zero,xmm0[14,15],zero,zero
+; AVX-NEXT: vpmaddwd %xmm1, %xmm0, %xmm0
; AVX-NEXT: vpaddd %xmm0, %xmm2, %xmm0
; AVX-NEXT: retq
%A = load <8 x i16>, <8 x i16>* %Aptr
@@ -3105,38 +3102,28 @@ define <4 x i32> @output_size_mismatch_high_subvector(<16 x i16> %x, <16 x i16>
;
; AVX1-LABEL: output_size_mismatch_high_subvector:
; AVX1: # %bb.0:
-; AVX1-NEXT: vmovdqa {{.*#+}} xmm2 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
; AVX1-NEXT: vextractf128 $1, %ymm0, %xmm0
-; AVX1-NEXT: vpshufb %xmm2, %xmm0, %xmm3
-; AVX1-NEXT: vmovdqa {{.*#+}} xmm4 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
-; AVX1-NEXT: vpshufb %xmm4, %xmm0, %xmm0
-; AVX1-NEXT: vpshufb %xmm2, %xmm1, %xmm2
-; AVX1-NEXT: vpshufb %xmm4, %xmm1, %xmm1
-; AVX1-NEXT: vpmovsxwd %xmm3, %xmm3
-; AVX1-NEXT: vpmovsxwd %xmm0, %xmm0
-; AVX1-NEXT: vpmovsxwd %xmm2, %xmm2
-; AVX1-NEXT: vpmulld %xmm2, %xmm3, %xmm2
-; AVX1-NEXT: vpmovsxwd %xmm1, %xmm1
-; AVX1-NEXT: vpmulld %xmm1, %xmm0, %xmm0
+; AVX1-NEXT: vpxor %xmm2, %xmm2, %xmm2
+; AVX1-NEXT: vpblendw {{.*#+}} xmm3 = xmm1[0],xmm2[1],xmm1[2],xmm2[3],xmm1[4],xmm2[5],xmm1[6],xmm2[7]
+; AVX1-NEXT: vpblendw {{.*#+}} xmm2 = xmm0[0],xmm2[1],xmm0[2],xmm2[3],xmm0[4],xmm2[5],xmm0[6],xmm2[7]
+; AVX1-NEXT: vpmaddwd %xmm3, %xmm2, %xmm2
+; AVX1-NEXT: vpsrld $16, %xmm1, %xmm1
+; AVX1-NEXT: vpsrld $16, %xmm0, %xmm0
+; AVX1-NEXT: vpmaddwd %xmm1, %xmm0, %xmm0
; AVX1-NEXT: vpaddd %xmm0, %xmm2, %xmm0
; AVX1-NEXT: vzeroupper
; AVX1-NEXT: retq
;
; AVX256-LABEL: output_size_mismatch_high_subvector:
; AVX256: # %bb.0:
-; AVX256-NEXT: vmovdqa {{.*#+}} xmm2 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
; AVX256-NEXT: vextracti128 $1, %ymm0, %xmm0
-; AVX256-NEXT: vpshufb %xmm2, %xmm0, %xmm3
-; AVX256-NEXT: vmovdqa {{.*#+}} xmm4 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
-; AVX256-NEXT: vpshufb %xmm4, %xmm0, %xmm0
-; AVX256-NEXT: vpshufb %xmm2, %xmm1, %xmm2
-; AVX256-NEXT: vpshufb %xmm4, %xmm1, %xmm1
-; AVX256-NEXT: vpmovsxwd %xmm3, %xmm3
-; AVX256-NEXT: vpmovsxwd %xmm0, %xmm0
-; AVX256-NEXT: vpmovsxwd %xmm2, %xmm2
-; AVX256-NEXT: vpmulld %xmm2, %xmm3, %xmm2
-; AVX256-NEXT: vpmovsxwd %xmm1, %xmm1
-; AVX256-NEXT: vpmulld %xmm1, %xmm0, %xmm0
+; AVX256-NEXT: vpxor %xmm2, %xmm2, %xmm2
+; AVX256-NEXT: vpblendw {{.*#+}} xmm3 = xmm1[0],xmm2[1],xmm1[2],xmm2[3],xmm1[4],xmm2[5],xmm1[6],xmm2[7]
+; AVX256-NEXT: vpblendw {{.*#+}} xmm2 = xmm0[0],xmm2[1],xmm0[2],xmm2[3],xmm0[4],xmm2[5],xmm0[6],xmm2[7]
+; AVX256-NEXT: vpmaddwd %xmm3, %xmm2, %xmm2
+; AVX256-NEXT: vpsrld $16, %xmm1, %xmm1
+; AVX256-NEXT: vpsrld $16, %xmm0, %xmm0
+; AVX256-NEXT: vpmaddwd %xmm1, %xmm0, %xmm0
; AVX256-NEXT: vpaddd %xmm0, %xmm2, %xmm0
; AVX256-NEXT: vzeroupper
; AVX256-NEXT: retq