[llvm] 9d2351a - [X86] matchPMADDWD - add matching for (add X, (pmaddwd Y, Z)) reassociation patterns.
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 6 03:34:14 PST 2024
Author: Simon Pilgrim
Date: 2024-12-06T11:34:01Z
New Revision: 9d2351ab9aff3741e3f4e10ab7ebabc77a6079d6
URL: https://github.com/llvm/llvm-project/commit/9d2351ab9aff3741e3f4e10ab7ebabc77a6079d6
DIFF: https://github.com/llvm/llvm-project/commit/9d2351ab9aff3741e3f4e10ab7ebabc77a6079d6.diff
LOG: [X86] matchPMADDWD - add matching for (add X, (pmaddwd Y, Z)) reassociation patterns.
Allows us to match pmaddwd accumulation patterns and fold them to vpdpwssd instructions on VNNI targets.
Fixes #118433
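For illustration, here is a hand-written sketch (not taken from the test file below) of the shape this combine now recognizes: a multiply of sign-extended i16 lanes, a horizontal add of the even/odd products, with the accumulator added in between the two halves. That interleaved accumulator add is what the previous matcher (a plain add of two build_vectors) could not see, and what the new reassociation pattern handles:

  ; Hypothetical v4i32 variant in the spirit of the vpdpwssd_v4i32_accumulate test.
  define <4 x i32> @pmaddwd_accumulate(<8 x i16> %a0, <8 x i16> %a1, <4 x i32> %acc) {
    %x0 = sext <8 x i16> %a0 to <8 x i32>
    %x1 = sext <8 x i16> %a1 to <8 x i32>
    %m = mul <8 x i32> %x0, %x1
    ; even/odd products of adjacent lane pairs - the PMADDWD shape
    %even = shufflevector <8 x i32> %m, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
    %odd = shufflevector <8 x i32> %m, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
    ; the accumulator lands between the two halves, so matching requires reassociation
    %r0 = add <4 x i32> %even, %acc
    %r1 = add <4 x i32> %r0, %odd
    ret <4 x i32> %r1
  }

On an AVX512VL+VNNI target this can now fold to a single vpdpwssd, as the updated checks below show.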
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/vpdpwssd.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index f713f2ed209e1c..ff21aa975033cf 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56471,9 +56471,12 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDNode *N,
!isPowerOf2_32(VT.getVectorNumElements()))
return SDValue();
- SDValue Op0, Op1;
+ SDValue Op0, Op1, Accum;
if (!sd_match(N, m_Add(m_AllOf(m_Opc(ISD::BUILD_VECTOR), m_Value(Op0)),
- m_AllOf(m_Opc(ISD::BUILD_VECTOR), m_Value(Op1)))))
+ m_AllOf(m_Opc(ISD::BUILD_VECTOR), m_Value(Op1)))) &&
+ !sd_match(N, m_Add(m_AllOf(m_Opc(ISD::BUILD_VECTOR), m_Value(Op0)),
+ m_Add(m_Value(Accum), m_AllOf(m_Opc(ISD::BUILD_VECTOR),
+ m_Value(Op1))))))
return SDValue();
// Check if one of Op0,Op1 is of the form:
@@ -56549,7 +56552,10 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDNode *N,
InVT.getVectorNumElements() / 2);
return DAG.getNode(X86ISD::VPMADDWD, DL, ResVT, Ops[0], Ops[1]);
};
- return SplitOpsAndApply(DAG, Subtarget, DL, VT, { N0, N1 }, PMADDBuilder);
+ SDValue R = SplitOpsAndApply(DAG, Subtarget, DL, VT, {N0, N1}, PMADDBuilder);
+ if (Accum)
+ R = DAG.getNode(ISD::ADD, DL, VT, R, Accum);
+ return R;
}
// Attempt to turn this pattern into PMADDWD.
diff --git a/llvm/test/CodeGen/X86/vpdpwssd.ll b/llvm/test/CodeGen/X86/vpdpwssd.ll
index c2c59e6be87977..f7cd6f8f1b8961 100644
--- a/llvm/test/CodeGen/X86/vpdpwssd.ll
+++ b/llvm/test/CodeGen/X86/vpdpwssd.ll
@@ -1,6 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mcpu=znver4 | FileCheck %s --check-prefixes=CHECK,ZNVER
-; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mcpu=znver5 | FileCheck %s --check-prefixes=CHECK,ZNVER
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mcpu=znver4 | FileCheck %s --check-prefixes=CHECK,ZNVER,AVX512BW-VNNI
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mcpu=znver5 | FileCheck %s --check-prefixes=CHECK,ZNVER,AVX-VNNI
; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512vnni,+fast-dpwssd | FileCheck %s --check-prefixes=CHECK,AVX512-VNNI
; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512vnni,+avx512vl,+fast-dpwssd | FileCheck %s --check-prefixes=CHECK,AVX512VL-VNNI
@@ -16,56 +16,28 @@ define <16 x i32> @vpdpwssd_test(<16 x i32> %0, <16 x i32> %1, <16 x i32> %2) {
define <16 x i32> @vpdpwssd_v16i32_accumulate(<32 x i16> %a0, <32 x i16> %a1, <16 x i32> %a2) {
; ZNVER-LABEL: vpdpwssd_v16i32_accumulate:
; ZNVER: # %bb.0:
-; ZNVER-NEXT: vpmovsxwd %ymm0, %zmm3
-; ZNVER-NEXT: vpmovsxwd %ymm1, %zmm4
-; ZNVER-NEXT: vextracti64x4 $1, %zmm0, %ymm0
-; ZNVER-NEXT: vextracti64x4 $1, %zmm1, %ymm1
-; ZNVER-NEXT: vpmovsxbd {{.*#+}} zmm5 = [0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30]
-; ZNVER-NEXT: vpmovsxwd %ymm0, %zmm0
-; ZNVER-NEXT: vpmovsxwd %ymm1, %zmm1
-; ZNVER-NEXT: vpmulld %zmm4, %zmm3, %zmm3
-; ZNVER-NEXT: vpmovsxbd {{.*#+}} zmm4 = [1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31]
-; ZNVER-NEXT: vpmulld %zmm1, %zmm0, %zmm0
-; ZNVER-NEXT: vpermi2d %zmm0, %zmm3, %zmm5
-; ZNVER-NEXT: vpermi2d %zmm0, %zmm3, %zmm4
-; ZNVER-NEXT: vpaddd %zmm2, %zmm5, %zmm0
-; ZNVER-NEXT: vpaddd %zmm4, %zmm0, %zmm0
+; ZNVER-NEXT: vpdpwssd %zmm1, %zmm0, %zmm2
+; ZNVER-NEXT: vmovdqa64 %zmm2, %zmm0
; ZNVER-NEXT: retq
;
; AVX512-VNNI-LABEL: vpdpwssd_v16i32_accumulate:
; AVX512-VNNI: # %bb.0:
-; AVX512-VNNI-NEXT: vpmovsxwd %ymm0, %zmm3
-; AVX512-VNNI-NEXT: vextracti64x4 $1, %zmm0, %ymm0
-; AVX512-VNNI-NEXT: vpmovsxwd %ymm0, %zmm0
-; AVX512-VNNI-NEXT: vpmovsxwd %ymm1, %zmm4
-; AVX512-VNNI-NEXT: vpmulld %zmm4, %zmm3, %zmm3
-; AVX512-VNNI-NEXT: vextracti64x4 $1, %zmm1, %ymm1
-; AVX512-VNNI-NEXT: vpmovsxwd %ymm1, %zmm1
-; AVX512-VNNI-NEXT: vpmulld %zmm1, %zmm0, %zmm0
-; AVX512-VNNI-NEXT: vpmovsxbd {{.*#+}} zmm1 = [0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30]
-; AVX512-VNNI-NEXT: vpermi2d %zmm0, %zmm3, %zmm1
-; AVX512-VNNI-NEXT: vpmovsxbd {{.*#+}} zmm4 = [1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31]
-; AVX512-VNNI-NEXT: vpermi2d %zmm0, %zmm3, %zmm4
-; AVX512-VNNI-NEXT: vpaddd %zmm2, %zmm1, %zmm0
-; AVX512-VNNI-NEXT: vpaddd %zmm4, %zmm0, %zmm0
+; AVX512-VNNI-NEXT: vextracti64x4 $1, %zmm1, %ymm3
+; AVX512-VNNI-NEXT: vextracti64x4 $1, %zmm0, %ymm4
+; AVX512-VNNI-NEXT: vpmaddwd %ymm3, %ymm4, %ymm3
+; AVX512-VNNI-NEXT: vpmaddwd %ymm1, %ymm0, %ymm0
+; AVX512-VNNI-NEXT: vinserti64x4 $1, %ymm3, %zmm0, %zmm0
+; AVX512-VNNI-NEXT: vpaddd %zmm2, %zmm0, %zmm0
; AVX512-VNNI-NEXT: retq
;
; AVX512VL-VNNI-LABEL: vpdpwssd_v16i32_accumulate:
; AVX512VL-VNNI: # %bb.0:
-; AVX512VL-VNNI-NEXT: vpmovsxwd %ymm0, %zmm3
-; AVX512VL-VNNI-NEXT: vextracti64x4 $1, %zmm0, %ymm0
-; AVX512VL-VNNI-NEXT: vpmovsxwd %ymm0, %zmm0
-; AVX512VL-VNNI-NEXT: vpmovsxwd %ymm1, %zmm4
-; AVX512VL-VNNI-NEXT: vpmulld %zmm4, %zmm3, %zmm3
-; AVX512VL-VNNI-NEXT: vextracti64x4 $1, %zmm1, %ymm1
-; AVX512VL-VNNI-NEXT: vpmovsxwd %ymm1, %zmm1
-; AVX512VL-VNNI-NEXT: vpmulld %zmm1, %zmm0, %zmm0
-; AVX512VL-VNNI-NEXT: vpmovsxbd {{.*#+}} zmm1 = [0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30]
-; AVX512VL-VNNI-NEXT: vpermi2d %zmm0, %zmm3, %zmm1
-; AVX512VL-VNNI-NEXT: vpmovsxbd {{.*#+}} zmm4 = [1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31]
-; AVX512VL-VNNI-NEXT: vpermi2d %zmm0, %zmm3, %zmm4
-; AVX512VL-VNNI-NEXT: vpaddd %zmm2, %zmm1, %zmm0
-; AVX512VL-VNNI-NEXT: vpaddd %zmm4, %zmm0, %zmm0
+; AVX512VL-VNNI-NEXT: vextracti64x4 $1, %zmm1, %ymm3
+; AVX512VL-VNNI-NEXT: vextracti64x4 $1, %zmm0, %ymm4
+; AVX512VL-VNNI-NEXT: vpmaddwd %ymm3, %ymm4, %ymm3
+; AVX512VL-VNNI-NEXT: vpmaddwd %ymm1, %ymm0, %ymm0
+; AVX512VL-VNNI-NEXT: vinserti64x4 $1, %ymm3, %zmm0, %zmm0
+; AVX512VL-VNNI-NEXT: vpaddd %zmm2, %zmm0, %zmm0
; AVX512VL-VNNI-NEXT: retq
%x0 = sext <32 x i16> %a0 to <32 x i32>
%x1 = sext <32 x i16> %a1 to <32 x i32>
@@ -78,43 +50,28 @@ define <16 x i32> @vpdpwssd_v16i32_accumulate(<32 x i16> %a0, <32 x i16> %a1, <1
}
define <8 x i32> @vpdpwssd_v8i32_accumulate(<16 x i16> %a0, <16 x i16> %a1, <8 x i32> %a2) {
-; ZNVER-LABEL: vpdpwssd_v8i32_accumulate:
-; ZNVER: # %bb.0:
-; ZNVER-NEXT: vpmovsxwd %ymm0, %zmm0
-; ZNVER-NEXT: vpmovsxwd %ymm1, %zmm1
-; ZNVER-NEXT: vpmulld %zmm1, %zmm0, %zmm0
-; ZNVER-NEXT: vpmovqd %zmm0, %ymm1
-; ZNVER-NEXT: vextracti64x4 $1, %zmm0, %ymm3
-; ZNVER-NEXT: vshufps {{.*#+}} ymm0 = ymm0[1,3],ymm3[1,3],ymm0[5,7],ymm3[5,7]
-; ZNVER-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[0,2,1,3]
-; ZNVER-NEXT: vpaddd %ymm2, %ymm0, %ymm0
-; ZNVER-NEXT: vpaddd %ymm0, %ymm1, %ymm0
-; ZNVER-NEXT: retq
+; AVX512BW-VNNI-LABEL: vpdpwssd_v8i32_accumulate:
+; AVX512BW-VNNI: # %bb.0:
+; AVX512BW-VNNI-NEXT: vpdpwssd %ymm1, %ymm0, %ymm2
+; AVX512BW-VNNI-NEXT: vmovdqa %ymm2, %ymm0
+; AVX512BW-VNNI-NEXT: retq
+;
+; AVX-VNNI-LABEL: vpdpwssd_v8i32_accumulate:
+; AVX-VNNI: # %bb.0:
+; AVX-VNNI-NEXT: {vex} vpdpwssd %ymm1, %ymm0, %ymm2
+; AVX-VNNI-NEXT: vmovdqa %ymm2, %ymm0
+; AVX-VNNI-NEXT: retq
;
; AVX512-VNNI-LABEL: vpdpwssd_v8i32_accumulate:
; AVX512-VNNI: # %bb.0:
-; AVX512-VNNI-NEXT: vpmovsxwd %ymm0, %zmm0
-; AVX512-VNNI-NEXT: vpmovsxwd %ymm1, %zmm1
-; AVX512-VNNI-NEXT: vpmulld %zmm1, %zmm0, %zmm0
-; AVX512-VNNI-NEXT: vpmovqd %zmm0, %ymm1
-; AVX512-VNNI-NEXT: vextracti64x4 $1, %zmm0, %ymm3
-; AVX512-VNNI-NEXT: vshufps {{.*#+}} ymm0 = ymm0[1,3],ymm3[1,3],ymm0[5,7],ymm3[5,7]
-; AVX512-VNNI-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[0,2,1,3]
-; AVX512-VNNI-NEXT: vpaddd %ymm2, %ymm1, %ymm1
-; AVX512-VNNI-NEXT: vpaddd %ymm0, %ymm1, %ymm0
+; AVX512-VNNI-NEXT: vpmaddwd %ymm1, %ymm0, %ymm0
+; AVX512-VNNI-NEXT: vpaddd %ymm2, %ymm0, %ymm0
; AVX512-VNNI-NEXT: retq
;
; AVX512VL-VNNI-LABEL: vpdpwssd_v8i32_accumulate:
; AVX512VL-VNNI: # %bb.0:
-; AVX512VL-VNNI-NEXT: vpmovsxwd %ymm0, %zmm0
-; AVX512VL-VNNI-NEXT: vpmovsxwd %ymm1, %zmm1
-; AVX512VL-VNNI-NEXT: vpmulld %zmm1, %zmm0, %zmm0
-; AVX512VL-VNNI-NEXT: vpmovqd %zmm0, %ymm1
-; AVX512VL-VNNI-NEXT: vextracti64x4 $1, %zmm0, %ymm3
-; AVX512VL-VNNI-NEXT: vshufps {{.*#+}} ymm0 = ymm0[1,3],ymm3[1,3],ymm0[5,7],ymm3[5,7]
-; AVX512VL-VNNI-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[0,2,1,3]
-; AVX512VL-VNNI-NEXT: vpaddd %ymm2, %ymm1, %ymm1
-; AVX512VL-VNNI-NEXT: vpaddd %ymm0, %ymm1, %ymm0
+; AVX512VL-VNNI-NEXT: vpdpwssd %ymm1, %ymm0, %ymm2
+; AVX512VL-VNNI-NEXT: vmovdqa %ymm2, %ymm0
; AVX512VL-VNNI-NEXT: retq
%x0 = sext <16 x i16> %a0 to <16 x i32>
%x1 = sext <16 x i16> %a1 to <16 x i32>
@@ -127,43 +84,28 @@ define <8 x i32> @vpdpwssd_v8i32_accumulate(<16 x i16> %a0, <16 x i16> %a1, <8 x
}
define <4 x i32> @vpdpwssd_v4i32_accumulate(<8 x i16> %a0, <8 x i16> %a1, <4 x i32> %a2) {
-; ZNVER-LABEL: vpdpwssd_v4i32_accumulate:
-; ZNVER: # %bb.0:
-; ZNVER-NEXT: vpmovsxwd %xmm0, %ymm0
-; ZNVER-NEXT: vpmovsxwd %xmm1, %ymm1
-; ZNVER-NEXT: vpmulld %ymm1, %ymm0, %ymm0
-; ZNVER-NEXT: vpmovqd %ymm0, %xmm1
-; ZNVER-NEXT: vextracti128 $1, %ymm0, %xmm3
-; ZNVER-NEXT: vshufps {{.*#+}} xmm0 = xmm0[1,3],xmm3[1,3]
-; ZNVER-NEXT: vpaddd %xmm2, %xmm1, %xmm1
-; ZNVER-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; ZNVER-NEXT: vzeroupper
-; ZNVER-NEXT: retq
+; AVX512BW-VNNI-LABEL: vpdpwssd_v4i32_accumulate:
+; AVX512BW-VNNI: # %bb.0:
+; AVX512BW-VNNI-NEXT: vpdpwssd %xmm1, %xmm0, %xmm2
+; AVX512BW-VNNI-NEXT: vmovdqa %xmm2, %xmm0
+; AVX512BW-VNNI-NEXT: retq
+;
+; AVX-VNNI-LABEL: vpdpwssd_v4i32_accumulate:
+; AVX-VNNI: # %bb.0:
+; AVX-VNNI-NEXT: {vex} vpdpwssd %xmm1, %xmm0, %xmm2
+; AVX-VNNI-NEXT: vmovdqa %xmm2, %xmm0
+; AVX-VNNI-NEXT: retq
;
; AVX512-VNNI-LABEL: vpdpwssd_v4i32_accumulate:
; AVX512-VNNI: # %bb.0:
-; AVX512-VNNI-NEXT: vpmovsxwd %xmm0, %ymm0
-; AVX512-VNNI-NEXT: vpmovsxwd %xmm1, %ymm1
-; AVX512-VNNI-NEXT: vpmulld %ymm1, %ymm0, %ymm0
-; AVX512-VNNI-NEXT: vextracti128 $1, %ymm0, %xmm1
-; AVX512-VNNI-NEXT: vshufps {{.*#+}} xmm3 = xmm0[0,2],xmm1[0,2]
-; AVX512-VNNI-NEXT: vshufps {{.*#+}} xmm0 = xmm0[1,3],xmm1[1,3]
-; AVX512-VNNI-NEXT: vpaddd %xmm2, %xmm3, %xmm1
-; AVX512-VNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX512-VNNI-NEXT: vzeroupper
+; AVX512-VNNI-NEXT: vpmaddwd %xmm1, %xmm0, %xmm0
+; AVX512-VNNI-NEXT: vpaddd %xmm2, %xmm0, %xmm0
; AVX512-VNNI-NEXT: retq
;
; AVX512VL-VNNI-LABEL: vpdpwssd_v4i32_accumulate:
; AVX512VL-VNNI: # %bb.0:
-; AVX512VL-VNNI-NEXT: vpmovsxwd %xmm0, %ymm0
-; AVX512VL-VNNI-NEXT: vpmovsxwd %xmm1, %ymm1
-; AVX512VL-VNNI-NEXT: vpmulld %ymm1, %ymm0, %ymm0
-; AVX512VL-VNNI-NEXT: vpmovqd %ymm0, %xmm1
-; AVX512VL-VNNI-NEXT: vextracti128 $1, %ymm0, %xmm3
-; AVX512VL-VNNI-NEXT: vshufps {{.*#+}} xmm0 = xmm0[1,3],xmm3[1,3]
-; AVX512VL-VNNI-NEXT: vpaddd %xmm2, %xmm1, %xmm1
-; AVX512VL-VNNI-NEXT: vpaddd %xmm1, %xmm0, %xmm0
-; AVX512VL-VNNI-NEXT: vzeroupper
+; AVX512VL-VNNI-NEXT: vpdpwssd %xmm1, %xmm0, %xmm2
+; AVX512VL-VNNI-NEXT: vmovdqa %xmm2, %xmm0
; AVX512VL-VNNI-NEXT: retq
%x0 = sext <8 x i16> %a0 to <8 x i32>
%x1 = sext <8 x i16> %a1 to <8 x i32>