[llvm] b0acd6c - [X86] Fold PMADD(x,0) or PMADD(0,x) -> 0

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 2 02:51:12 PDT 2021


Author: Simon Pilgrim
Date: 2021-09-02T10:48:50+01:00
New Revision: b0acd6c36974f83799f2d278862081dfd4a0b2e0

URL: https://github.com/llvm/llvm-project/commit/b0acd6c36974f83799f2d278862081dfd4a0b2e0
DIFF: https://github.com/llvm/llvm-project/commit/b0acd6c36974f83799f2d278862081dfd4a0b2e0.diff

LOG: [X86] Fold PMADD(x,0) or PMADD(0,x) -> 0

Pulled out of D108522 - handle zero-operand cases for PMADDWD/VPMADDUBSW ops

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/combine-pmadd.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 17fd32e3898df..f308163168f53 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -51824,6 +51824,21 @@ static SDValue combinePMULDQ(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+// Fold X86ISD::VPMADDUBSW/VPMADDWD nodes; returns SDValue() on no match.
+static SDValue combineVPMADD(SDNode *N, SelectionDAG &DAG,
+                             TargetLowering::DAGCombinerInfo &DCI) {
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+
+  // Multiply by zero: PMADD(x,0) or PMADD(0,x) -> 0.
+  // Emit a fresh zero; don't return LHS/RHS as it may contain UNDEFs.
+  if (ISD::isBuildVectorAllZeros(LHS.getNode()) ||
+      ISD::isBuildVectorAllZeros(RHS.getNode()))
+    return DAG.getConstant(0, SDLoc(N), N->getValueType(0));
+
+  return SDValue();
+}
+
 static SDValue combineEXTEND_VECTOR_INREG(SDNode *N, SelectionDAG &DAG,
                                           TargetLowering::DAGCombinerInfo &DCI,
                                           const X86Subtarget &Subtarget) {
@@ -52274,6 +52289,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case X86ISD::PCMPGT:      return combineVectorCompare(N, DAG, Subtarget);
   case X86ISD::PMULDQ:
   case X86ISD::PMULUDQ:     return combinePMULDQ(N, DAG, DCI, Subtarget);
+  case X86ISD::VPMADDUBSW:
+  case X86ISD::VPMADDWD:    return combineVPMADD(N, DAG, DCI);
   case X86ISD::KSHIFTL:
   case X86ISD::KSHIFTR:     return combineKSHIFT(N, DAG, DCI);
   case ISD::FP16_TO_FP:     return combineFP16_TO_FP(N, DAG, Subtarget);

diff  --git a/llvm/test/CodeGen/X86/combine-pmadd.ll b/llvm/test/CodeGen/X86/combine-pmadd.ll
index f84e2e8c192d0..e8a61eeac6a11 100644
--- a/llvm/test/CodeGen/X86/combine-pmadd.ll
+++ b/llvm/test/CodeGen/X86/combine-pmadd.ll
@@ -9,14 +9,12 @@ declare <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8>, <16 x i8>) nounwind
 define <4 x i32> @combine_pmaddwd_zero(<8 x i16> %a0, <8 x i16> %a1) {
 ; SSE-LABEL: combine_pmaddwd_zero:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    pxor %xmm1, %xmm1
-; SSE-NEXT:    pmaddwd %xmm1, %xmm0
+; SSE-NEXT:    xorps %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_pmaddwd_zero:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; AVX-NEXT:    vpmaddwd %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a0, <8 x i16> zeroinitializer)
   ret <4 x i32> %1
@@ -25,14 +23,12 @@ define <4 x i32> @combine_pmaddwd_zero(<8 x i16> %a0, <8 x i16> %a1) {
 define <4 x i32> @combine_pmaddwd_zero_commute(<8 x i16> %a0, <8 x i16> %a1) {
 ; SSE-LABEL: combine_pmaddwd_zero_commute:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    pxor %xmm1, %xmm1
-; SSE-NEXT:    pmaddwd %xmm1, %xmm0
+; SSE-NEXT:    xorps %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_pmaddwd_zero_commute:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; AVX-NEXT:    vpmaddwd %xmm0, %xmm1, %xmm0
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> zeroinitializer, <8 x i16> %a0)
   ret <4 x i32> %1
@@ -41,14 +37,12 @@ define <4 x i32> @combine_pmaddwd_zero_commute(<8 x i16> %a0, <8 x i16> %a1) {
 define <8 x i16> @combine_pmaddubsw_zero(<16 x i8> %a0, <16 x i8> %a1) {
 ; SSE-LABEL: combine_pmaddubsw_zero:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    pxor %xmm1, %xmm1
-; SSE-NEXT:    pmaddubsw %xmm1, %xmm0
+; SSE-NEXT:    xorps %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_pmaddubsw_zero:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; AVX-NEXT:    vpmaddubsw %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> %a0, <16 x i8> zeroinitializer)
   ret <8 x i16> %1
@@ -57,15 +51,12 @@ define <8 x i16> @combine_pmaddubsw_zero(<16 x i8> %a0, <16 x i8> %a1) {
 define <8 x i16> @combine_pmaddubsw_zero_commute(<16 x i8> %a0, <16 x i8> %a1) {
 ; SSE-LABEL: combine_pmaddubsw_zero_commute:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    pxor %xmm1, %xmm1
-; SSE-NEXT:    pmaddubsw %xmm0, %xmm1
-; SSE-NEXT:    movdqa %xmm1, %xmm0
+; SSE-NEXT:    xorps %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_pmaddubsw_zero_commute:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; AVX-NEXT:    vpmaddubsw %xmm0, %xmm1, %xmm0
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> zeroinitializer, <16 x i8> %a0)
   ret <8 x i16> %1


        


More information about the llvm-commits mailing list