[llvm] 9d2351a - [X86] matchPMADDWD - add matching for (add X, (pmaddwd Y, Z)) reassociation patterns.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 6 03:34:14 PST 2024


Author: Simon Pilgrim
Date: 2024-12-06T11:34:01Z
New Revision: 9d2351ab9aff3741e3f4e10ab7ebabc77a6079d6

URL: https://github.com/llvm/llvm-project/commit/9d2351ab9aff3741e3f4e10ab7ebabc77a6079d6
DIFF: https://github.com/llvm/llvm-project/commit/9d2351ab9aff3741e3f4e10ab7ebabc77a6079d6.diff

LOG: [X86] matchPMADDWD - add matching for (add X, (pmaddwd Y, Z)) reassociation patterns.

Allows us to match pmaddwd accumulation patterns and fold them to vpdpwssd instructions on VNNI targets.

Fixes #118433
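
For reference, IR of roughly the following shape now lowers to a single vpdpwssd on VNNI targets, as the updated CHECK lines below show. This is a minimal sketch modeled on the vpdpwssd.ll accumulate tests; the value names and the exact even/odd shuffle ordering are illustrative rather than copied from the test file:

define <4 x i32> @pmaddwd_accumulate(<8 x i16> %a0, <8 x i16> %a1, <4 x i32> %acc) {
  %x0 = sext <8 x i16> %a0 to <8 x i32>
  %x1 = sext <8 x i16> %a1 to <8 x i32>
  %mul = mul <8 x i32> %x0, %x1
  ; even/odd pairs of the widened products - summed together they form one PMADDWD result per i32 lane
  %even = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
  %odd  = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
  ; the accumulator add sits between the two partial-sum adds, i.e. the
  ; reassociated shape that the new match is intended to catch
  %acc0 = add <4 x i32> %even, %acc
  %sum  = add <4 x i32> %acc0, %odd
  ret <4 x i32> %sum
}

Roughly speaking, this corresponds to the new (add X, (add Accum, Y)) sd_match alternative: the pmaddwd halves are matched as before, and the accumulator is re-added afterwards via the extra ISD::ADD emitted at the end of matchPMADDWD.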

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/vpdpwssd.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index f713f2ed209e1c..ff21aa975033cf 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56471,9 +56471,12 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDNode *N,
       !isPowerOf2_32(VT.getVectorNumElements()))
     return SDValue();
 
-  SDValue Op0, Op1;
+  SDValue Op0, Op1, Accum;
   if (!sd_match(N, m_Add(m_AllOf(m_Opc(ISD::BUILD_VECTOR), m_Value(Op0)),
-                         m_AllOf(m_Opc(ISD::BUILD_VECTOR), m_Value(Op1)))))
+                         m_AllOf(m_Opc(ISD::BUILD_VECTOR), m_Value(Op1)))) &&
+      !sd_match(N, m_Add(m_AllOf(m_Opc(ISD::BUILD_VECTOR), m_Value(Op0)),
+                         m_Add(m_Value(Accum), m_AllOf(m_Opc(ISD::BUILD_VECTOR),
+                                                       m_Value(Op1))))))
     return SDValue();
 
   // Check if one of Op0,Op1 is of the form:
@@ -56549,7 +56552,10 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDNode *N,
                                  InVT.getVectorNumElements() / 2);
     return DAG.getNode(X86ISD::VPMADDWD, DL, ResVT, Ops[0], Ops[1]);
   };
-  return SplitOpsAndApply(DAG, Subtarget, DL, VT, { N0, N1 }, PMADDBuilder);
+  SDValue R = SplitOpsAndApply(DAG, Subtarget, DL, VT, {N0, N1}, PMADDBuilder);
+  if (Accum)
+    R = DAG.getNode(ISD::ADD, DL, VT, R, Accum);
+  return R;
 }
 
 // Attempt to turn this pattern into PMADDWD.

diff --git a/llvm/test/CodeGen/X86/vpdpwssd.ll b/llvm/test/CodeGen/X86/vpdpwssd.ll
index c2c59e6be87977..f7cd6f8f1b8961 100644
--- a/llvm/test/CodeGen/X86/vpdpwssd.ll
+++ b/llvm/test/CodeGen/X86/vpdpwssd.ll
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mcpu=znver4 | FileCheck %s --check-prefixes=CHECK,ZNVER
-; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mcpu=znver5 | FileCheck %s --check-prefixes=CHECK,ZNVER
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mcpu=znver4 | FileCheck %s --check-prefixes=CHECK,ZNVER,AVX512BW-VNNI
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mcpu=znver5 | FileCheck %s --check-prefixes=CHECK,ZNVER,AVX-VNNI
 ; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512vnni,+fast-dpwssd | FileCheck %s --check-prefixes=CHECK,AVX512-VNNI
 ; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512vnni,+avx512vl,+fast-dpwssd | FileCheck %s --check-prefixes=CHECK,AVX512VL-VNNI
 
@@ -16,56 +16,28 @@ define <16 x i32> @vpdpwssd_test(<16 x i32> %0, <16 x i32> %1, <16 x i32> %2) {
 define <16 x i32> @vpdpwssd_v16i32_accumulate(<32 x i16> %a0, <32 x i16> %a1, <16 x i32> %a2) {
 ; ZNVER-LABEL: vpdpwssd_v16i32_accumulate:
 ; ZNVER:       # %bb.0:
-; ZNVER-NEXT:    vpmovsxwd %ymm0, %zmm3
-; ZNVER-NEXT:    vpmovsxwd %ymm1, %zmm4
-; ZNVER-NEXT:    vextracti64x4 $1, %zmm0, %ymm0
-; ZNVER-NEXT:    vextracti64x4 $1, %zmm1, %ymm1
-; ZNVER-NEXT:    vpmovsxbd {{.*#+}} zmm5 = [0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30]
-; ZNVER-NEXT:    vpmovsxwd %ymm0, %zmm0
-; ZNVER-NEXT:    vpmovsxwd %ymm1, %zmm1
-; ZNVER-NEXT:    vpmulld %zmm4, %zmm3, %zmm3
-; ZNVER-NEXT:    vpmovsxbd {{.*#+}} zmm4 = [1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31]
-; ZNVER-NEXT:    vpmulld %zmm1, %zmm0, %zmm0
-; ZNVER-NEXT:    vpermi2d %zmm0, %zmm3, %zmm5
-; ZNVER-NEXT:    vpermi2d %zmm0, %zmm3, %zmm4
-; ZNVER-NEXT:    vpaddd %zmm2, %zmm5, %zmm0
-; ZNVER-NEXT:    vpaddd %zmm4, %zmm0, %zmm0
+; ZNVER-NEXT:    vpdpwssd %zmm1, %zmm0, %zmm2
+; ZNVER-NEXT:    vmovdqa64 %zmm2, %zmm0
 ; ZNVER-NEXT:    retq
 ;
 ; AVX512-VNNI-LABEL: vpdpwssd_v16i32_accumulate:
 ; AVX512-VNNI:       # %bb.0:
-; AVX512-VNNI-NEXT:    vpmovsxwd %ymm0, %zmm3
-; AVX512-VNNI-NEXT:    vextracti64x4 $1, %zmm0, %ymm0
-; AVX512-VNNI-NEXT:    vpmovsxwd %ymm0, %zmm0
-; AVX512-VNNI-NEXT:    vpmovsxwd %ymm1, %zmm4
-; AVX512-VNNI-NEXT:    vpmulld %zmm4, %zmm3, %zmm3
-; AVX512-VNNI-NEXT:    vextracti64x4 $1, %zmm1, %ymm1
-; AVX512-VNNI-NEXT:    vpmovsxwd %ymm1, %zmm1
-; AVX512-VNNI-NEXT:    vpmulld %zmm1, %zmm0, %zmm0
-; AVX512-VNNI-NEXT:    vpmovsxbd {{.*#+}} zmm1 = [0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30]
-; AVX512-VNNI-NEXT:    vpermi2d %zmm0, %zmm3, %zmm1
-; AVX512-VNNI-NEXT:    vpmovsxbd {{.*#+}} zmm4 = [1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31]
-; AVX512-VNNI-NEXT:    vpermi2d %zmm0, %zmm3, %zmm4
-; AVX512-VNNI-NEXT:    vpaddd %zmm2, %zmm1, %zmm0
-; AVX512-VNNI-NEXT:    vpaddd %zmm4, %zmm0, %zmm0
+; AVX512-VNNI-NEXT:    vextracti64x4 $1, %zmm1, %ymm3
+; AVX512-VNNI-NEXT:    vextracti64x4 $1, %zmm0, %ymm4
+; AVX512-VNNI-NEXT:    vpmaddwd %ymm3, %ymm4, %ymm3
+; AVX512-VNNI-NEXT:    vpmaddwd %ymm1, %ymm0, %ymm0
+; AVX512-VNNI-NEXT:    vinserti64x4 $1, %ymm3, %zmm0, %zmm0
+; AVX512-VNNI-NEXT:    vpaddd %zmm2, %zmm0, %zmm0
 ; AVX512-VNNI-NEXT:    retq
 ;
 ; AVX512VL-VNNI-LABEL: vpdpwssd_v16i32_accumulate:
 ; AVX512VL-VNNI:       # %bb.0:
-; AVX512VL-VNNI-NEXT:    vpmovsxwd %ymm0, %zmm3
-; AVX512VL-VNNI-NEXT:    vextracti64x4 $1, %zmm0, %ymm0
-; AVX512VL-VNNI-NEXT:    vpmovsxwd %ymm0, %zmm0
-; AVX512VL-VNNI-NEXT:    vpmovsxwd %ymm1, %zmm4
-; AVX512VL-VNNI-NEXT:    vpmulld %zmm4, %zmm3, %zmm3
-; AVX512VL-VNNI-NEXT:    vextracti64x4 $1, %zmm1, %ymm1
-; AVX512VL-VNNI-NEXT:    vpmovsxwd %ymm1, %zmm1
-; AVX512VL-VNNI-NEXT:    vpmulld %zmm1, %zmm0, %zmm0
-; AVX512VL-VNNI-NEXT:    vpmovsxbd {{.*#+}} zmm1 = [0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30]
-; AVX512VL-VNNI-NEXT:    vpermi2d %zmm0, %zmm3, %zmm1
-; AVX512VL-VNNI-NEXT:    vpmovsxbd {{.*#+}} zmm4 = [1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31]
-; AVX512VL-VNNI-NEXT:    vpermi2d %zmm0, %zmm3, %zmm4
-; AVX512VL-VNNI-NEXT:    vpaddd %zmm2, %zmm1, %zmm0
-; AVX512VL-VNNI-NEXT:    vpaddd %zmm4, %zmm0, %zmm0
+; AVX512VL-VNNI-NEXT:    vextracti64x4 $1, %zmm1, %ymm3
+; AVX512VL-VNNI-NEXT:    vextracti64x4 $1, %zmm0, %ymm4
+; AVX512VL-VNNI-NEXT:    vpmaddwd %ymm3, %ymm4, %ymm3
+; AVX512VL-VNNI-NEXT:    vpmaddwd %ymm1, %ymm0, %ymm0
+; AVX512VL-VNNI-NEXT:    vinserti64x4 $1, %ymm3, %zmm0, %zmm0
+; AVX512VL-VNNI-NEXT:    vpaddd %zmm2, %zmm0, %zmm0
 ; AVX512VL-VNNI-NEXT:    retq
   %x0 = sext <32 x i16> %a0 to <32 x i32>
   %x1 = sext <32 x i16> %a1 to <32 x i32>
@@ -78,43 +50,28 @@ define <16 x i32> @vpdpwssd_v16i32_accumulate(<32 x i16> %a0, <32 x i16> %a1, <1
 }
 
 define <8 x i32> @vpdpwssd_v8i32_accumulate(<16 x i16> %a0, <16 x i16> %a1, <8 x i32> %a2) {
-; ZNVER-LABEL: vpdpwssd_v8i32_accumulate:
-; ZNVER:       # %bb.0:
-; ZNVER-NEXT:    vpmovsxwd %ymm0, %zmm0
-; ZNVER-NEXT:    vpmovsxwd %ymm1, %zmm1
-; ZNVER-NEXT:    vpmulld %zmm1, %zmm0, %zmm0
-; ZNVER-NEXT:    vpmovqd %zmm0, %ymm1
-; ZNVER-NEXT:    vextracti64x4 $1, %zmm0, %ymm3
-; ZNVER-NEXT:    vshufps {{.*#+}} ymm0 = ymm0[1,3],ymm3[1,3],ymm0[5,7],ymm3[5,7]
-; ZNVER-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[0,2,1,3]
-; ZNVER-NEXT:    vpaddd %ymm2, %ymm0, %ymm0
-; ZNVER-NEXT:    vpaddd %ymm0, %ymm1, %ymm0
-; ZNVER-NEXT:    retq
+; AVX512BW-VNNI-LABEL: vpdpwssd_v8i32_accumulate:
+; AVX512BW-VNNI:       # %bb.0:
+; AVX512BW-VNNI-NEXT:    vpdpwssd %ymm1, %ymm0, %ymm2
+; AVX512BW-VNNI-NEXT:    vmovdqa %ymm2, %ymm0
+; AVX512BW-VNNI-NEXT:    retq
+;
+; AVX-VNNI-LABEL: vpdpwssd_v8i32_accumulate:
+; AVX-VNNI:       # %bb.0:
+; AVX-VNNI-NEXT:    {vex} vpdpwssd %ymm1, %ymm0, %ymm2
+; AVX-VNNI-NEXT:    vmovdqa %ymm2, %ymm0
+; AVX-VNNI-NEXT:    retq
 ;
 ; AVX512-VNNI-LABEL: vpdpwssd_v8i32_accumulate:
 ; AVX512-VNNI:       # %bb.0:
-; AVX512-VNNI-NEXT:    vpmovsxwd %ymm0, %zmm0
-; AVX512-VNNI-NEXT:    vpmovsxwd %ymm1, %zmm1
-; AVX512-VNNI-NEXT:    vpmulld %zmm1, %zmm0, %zmm0
-; AVX512-VNNI-NEXT:    vpmovqd %zmm0, %ymm1
-; AVX512-VNNI-NEXT:    vextracti64x4 $1, %zmm0, %ymm3
-; AVX512-VNNI-NEXT:    vshufps {{.*#+}} ymm0 = ymm0[1,3],ymm3[1,3],ymm0[5,7],ymm3[5,7]
-; AVX512-VNNI-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[0,2,1,3]
-; AVX512-VNNI-NEXT:    vpaddd %ymm2, %ymm1, %ymm1
-; AVX512-VNNI-NEXT:    vpaddd %ymm0, %ymm1, %ymm0
+; AVX512-VNNI-NEXT:    vpmaddwd %ymm1, %ymm0, %ymm0
+; AVX512-VNNI-NEXT:    vpaddd %ymm2, %ymm0, %ymm0
 ; AVX512-VNNI-NEXT:    retq
 ;
 ; AVX512VL-VNNI-LABEL: vpdpwssd_v8i32_accumulate:
 ; AVX512VL-VNNI:       # %bb.0:
-; AVX512VL-VNNI-NEXT:    vpmovsxwd %ymm0, %zmm0
-; AVX512VL-VNNI-NEXT:    vpmovsxwd %ymm1, %zmm1
-; AVX512VL-VNNI-NEXT:    vpmulld %zmm1, %zmm0, %zmm0
-; AVX512VL-VNNI-NEXT:    vpmovqd %zmm0, %ymm1
-; AVX512VL-VNNI-NEXT:    vextracti64x4 $1, %zmm0, %ymm3
-; AVX512VL-VNNI-NEXT:    vshufps {{.*#+}} ymm0 = ymm0[1,3],ymm3[1,3],ymm0[5,7],ymm3[5,7]
-; AVX512VL-VNNI-NEXT:    vpermpd {{.*#+}} ymm0 = ymm0[0,2,1,3]
-; AVX512VL-VNNI-NEXT:    vpaddd %ymm2, %ymm1, %ymm1
-; AVX512VL-VNNI-NEXT:    vpaddd %ymm0, %ymm1, %ymm0
+; AVX512VL-VNNI-NEXT:    vpdpwssd %ymm1, %ymm0, %ymm2
+; AVX512VL-VNNI-NEXT:    vmovdqa %ymm2, %ymm0
 ; AVX512VL-VNNI-NEXT:    retq
   %x0 = sext <16 x i16> %a0 to <16 x i32>
   %x1 = sext <16 x i16> %a1 to <16 x i32>
@@ -127,43 +84,28 @@ define <8 x i32> @vpdpwssd_v8i32_accumulate(<16 x i16> %a0, <16 x i16> %a1, <8 x
 }
 
 define <4 x i32> @vpdpwssd_v4i32_accumulate(<8 x i16> %a0, <8 x i16> %a1, <4 x i32> %a2) {
-; ZNVER-LABEL: vpdpwssd_v4i32_accumulate:
-; ZNVER:       # %bb.0:
-; ZNVER-NEXT:    vpmovsxwd %xmm0, %ymm0
-; ZNVER-NEXT:    vpmovsxwd %xmm1, %ymm1
-; ZNVER-NEXT:    vpmulld %ymm1, %ymm0, %ymm0
-; ZNVER-NEXT:    vpmovqd %ymm0, %xmm1
-; ZNVER-NEXT:    vextracti128 $1, %ymm0, %xmm3
-; ZNVER-NEXT:    vshufps {{.*#+}} xmm0 = xmm0[1,3],xmm3[1,3]
-; ZNVER-NEXT:    vpaddd %xmm2, %xmm1, %xmm1
-; ZNVER-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; ZNVER-NEXT:    vzeroupper
-; ZNVER-NEXT:    retq
+; AVX512BW-VNNI-LABEL: vpdpwssd_v4i32_accumulate:
+; AVX512BW-VNNI:       # %bb.0:
+; AVX512BW-VNNI-NEXT:    vpdpwssd %xmm1, %xmm0, %xmm2
+; AVX512BW-VNNI-NEXT:    vmovdqa %xmm2, %xmm0
+; AVX512BW-VNNI-NEXT:    retq
+;
+; AVX-VNNI-LABEL: vpdpwssd_v4i32_accumulate:
+; AVX-VNNI:       # %bb.0:
+; AVX-VNNI-NEXT:    {vex} vpdpwssd %xmm1, %xmm0, %xmm2
+; AVX-VNNI-NEXT:    vmovdqa %xmm2, %xmm0
+; AVX-VNNI-NEXT:    retq
 ;
 ; AVX512-VNNI-LABEL: vpdpwssd_v4i32_accumulate:
 ; AVX512-VNNI:       # %bb.0:
-; AVX512-VNNI-NEXT:    vpmovsxwd %xmm0, %ymm0
-; AVX512-VNNI-NEXT:    vpmovsxwd %xmm1, %ymm1
-; AVX512-VNNI-NEXT:    vpmulld %ymm1, %ymm0, %ymm0
-; AVX512-VNNI-NEXT:    vextracti128 $1, %ymm0, %xmm1
-; AVX512-VNNI-NEXT:    vshufps {{.*#+}} xmm3 = xmm0[0,2],xmm1[0,2]
-; AVX512-VNNI-NEXT:    vshufps {{.*#+}} xmm0 = xmm0[1,3],xmm1[1,3]
-; AVX512-VNNI-NEXT:    vpaddd %xmm2, %xmm3, %xmm1
-; AVX512-VNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX512-VNNI-NEXT:    vzeroupper
+; AVX512-VNNI-NEXT:    vpmaddwd %xmm1, %xmm0, %xmm0
+; AVX512-VNNI-NEXT:    vpaddd %xmm2, %xmm0, %xmm0
 ; AVX512-VNNI-NEXT:    retq
 ;
 ; AVX512VL-VNNI-LABEL: vpdpwssd_v4i32_accumulate:
 ; AVX512VL-VNNI:       # %bb.0:
-; AVX512VL-VNNI-NEXT:    vpmovsxwd %xmm0, %ymm0
-; AVX512VL-VNNI-NEXT:    vpmovsxwd %xmm1, %ymm1
-; AVX512VL-VNNI-NEXT:    vpmulld %ymm1, %ymm0, %ymm0
-; AVX512VL-VNNI-NEXT:    vpmovqd %ymm0, %xmm1
-; AVX512VL-VNNI-NEXT:    vextracti128 $1, %ymm0, %xmm3
-; AVX512VL-VNNI-NEXT:    vshufps {{.*#+}} xmm0 = xmm0[1,3],xmm3[1,3]
-; AVX512VL-VNNI-NEXT:    vpaddd %xmm2, %xmm1, %xmm1
-; AVX512VL-VNNI-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX512VL-VNNI-NEXT:    vzeroupper
+; AVX512VL-VNNI-NEXT:    vpdpwssd %xmm1, %xmm0, %xmm2
+; AVX512VL-VNNI-NEXT:    vmovdqa %xmm2, %xmm0
 ; AVX512VL-VNNI-NEXT:    retq
   %x0 = sext <8 x i16> %a0 to <8 x i32>
   %x1 = sext <8 x i16> %a1 to <8 x i32>


        

