[llvm] [X86][AVX512] Better lowering for `_mm512_maskz_shuffle_epi32` (PR #121147)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 26 04:14:45 PST 2024
https://github.com/abhishek-kaushik22 updated https://github.com/llvm/llvm-project/pull/121147
>From 4f2e5e43d117283a27309bd5661128b28fc466fe Mon Sep 17 00:00:00 2001
From: abhishek-kaushik22 <abhishek.kaushik at intel.com>
Date: Thu, 26 Dec 2024 17:40:36 +0530
Subject: [PATCH 1/2] [X86][AVX512] Better lowering for
`_mm512_maskz_shuffle_epi32`
For the function (https://godbolt.org/z/4rTYeMY4b)
```
#include <immintrin.h>
__m512i foo(__m512i a){
__m512i r0 = _mm512_maskz_shuffle_epi32(0xaaaa, a, 0xab);
return r0;
}
```
The generated assembly is unnecessarily long:
```
.LCPI0_1:
.byte 0
.byte 18
.byte 2
.byte 18
.byte 4
.byte 22
.byte 6
.byte 22
.byte 8
.byte 26
.byte 10
.byte 26
.byte 12
.byte 30
.byte 14
.byte 30
foo(long long vector[8]):
vpmovsxbd zmm2, xmmword ptr [rip + .LCPI0_1]
vpxor xmm1, xmm1, xmm1
vpermt2d zmm1, zmm2, zmm0
vmovdqa64 zmm0, zmm1
ret
```
Instead, we could simply generate a single zero-masked `vpshufd` instruction (`vpshufd $imm8, %zmm0, %zmm0 {%k1} {z}`) and pass the mask and the `imm8` value to it.
The SelectionDAG generated from the IR doesn't contain the mask and the `imm8` value directly, but there is a pattern that can be matched here:
```
t6: v16i32 = BUILD_VECTOR Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32
t2: v8i64,ch = CopyFromReg t0, Register:v8i64 %0
t3: v16i32 = bitcast t2
t7: v16i32 = vector_shuffle<0,18,2,18,4,22,6,22,8,26,10,26,12,30,14,30> t6, t3
t8: v8i64 = bitcast t7
```
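I've matched this pattern to recover the mask and the `imm8` value, and to generate a `VSELECT` node from them.
The resulting assembly is shown below. Note that the immediate is 0x88 (136) rather than the original 0xab: the two differ only in the selector bits for the even lanes, which the 0xaaaa mask zeroes, so the result is the same.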
```
movw $-21846, %ax # imm = 0xAAAA
kmovw %eax, %k1
vpshufd $136, %zmm0, %zmm0 {%k1} {z} # zmm0 {%k1} {z} = zmm0[0,2,0,2,4,6,4,6,8,10,8,10,12,14,12,14]
retq
```
---
llvm/lib/Target/X86/X86ISelLowering.cpp | 55 +++++++++++++++++++++++++
1 file changed, 55 insertions(+)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index e7f6032ee7d749..ca07b81f3fb984 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -17172,6 +17172,58 @@ static SDValue lowerV8I64Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
return lowerShuffleWithPERMV(DL, MVT::v8i64, Mask, V1, V2, Subtarget, DAG);
}
+static SDValue lowerShuffleAsVSELECT(const SDLoc &DL,
+ ArrayRef<int> RepeatedMask, SDValue V1,
+ SDValue V2, SelectionDAG &DAG) {
+ if (V1.getOpcode() != ISD::BUILD_VECTOR &&
+ V2.getOpcode() != ISD::BUILD_VECTOR)
+ return SDValue();
+ SDValue BuildVector;
+ if (V1.getOpcode() == ISD::BUILD_VECTOR) {
+ BuildVector = V1;
+ if (V2.getOpcode() != ISD::BITCAST)
+ return SDValue();
+ } else {
+ BuildVector = V2;
+ if (V1.getOpcode() != ISD::BITCAST)
+ return SDValue();
+ }
+ if (!ISD::isBuildVectorAllZeros(BuildVector.getNode()))
+ return SDValue();
+ APInt DestMask(16, 0);
+ for (unsigned i = 0; i < 16; ++i) {
+ SDValue Op = BuildVector->getOperand(i);
+ if (Op.isUndef())
+ DestMask.setBit(i);
+ }
+ if (DestMask.isZero())
+ return SDValue();
+
+ unsigned Imm8 = 0;
+ for (unsigned i = 0; i < 4; ++i) {
+ if (V1.getOpcode() != ISD::BUILD_VECTOR) {
+ if (RepeatedMask[i] >= 4) {
+ continue;
+ }
+ } else if (RepeatedMask[i] < 4) {
+ continue;
+ }
+ Imm8 += (RepeatedMask[i] % 4) << (2 * i);
+ }
+
+ SDValue Bitcast = DAG.getNode(ISD::BITCAST, DL, MVT::v16i1,
+ DAG.getConstant(DestMask, DL, MVT::i16));
+
+ std::vector<SDValue> ZeroElements(16, DAG.getConstant(0, DL, MVT::i32));
+ SDValue Zeros = DAG.getBuildVector(MVT::v16i32, DL, ZeroElements);
+
+ return DAG.getNode(ISD::VSELECT, DL, MVT::v16i32, Bitcast,
+ DAG.getNode(X86ISD::PSHUFD, DL, MVT::v16i32,
+ V1.getOpcode() != ISD::BUILD_VECTOR ? V1 : V2,
+ DAG.getTargetConstant(Imm8, DL, MVT::i8)),
+ Zeros);
+}
+
/// Handle lowering of 16-lane 32-bit integer shuffles.
static SDValue lowerV16I32Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
const APInt &Zeroable, SDValue V1, SDValue V2,
@@ -17217,6 +17269,9 @@ static SDValue lowerV16I32Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
// Use dedicated unpack instructions for masks that match their pattern.
if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v16i32, V1, V2, Mask, DAG))
return V;
+
+ if (SDValue V = lowerShuffleAsVSELECT(DL, RepeatedMask, V1, V2, DAG))
+ return V;
}
// Try to use shift instructions.
>From 32d0603bb6c8f897ce644bf1f165de82e5de6ace Mon Sep 17 00:00:00 2001
From: abhishek-kaushik22 <abhishek.kaushik at intel.com>
Date: Thu, 26 Dec 2024 17:44:32 +0530
Subject: [PATCH 2/2] Add test
---
.../CodeGen/X86/vector-shuffle-512-v16.ll | 20 +++++++++++++++++++
1 file changed, 20 insertions(+)
diff --git a/llvm/test/CodeGen/X86/vector-shuffle-512-v16.ll b/llvm/test/CodeGen/X86/vector-shuffle-512-v16.ll
index d3b04878dc06d4..3a2a96fcd5de33 100644
--- a/llvm/test/CodeGen/X86/vector-shuffle-512-v16.ll
+++ b/llvm/test/CodeGen/X86/vector-shuffle-512-v16.ll
@@ -990,3 +990,23 @@ bb:
ret void
}
+define <8 x i64> @pr121147(<8 x i64> %a) {
+; AVX512F-LABEL: pr121147:
+; AVX512F: # %bb.0: # %entry
+; AVX512F-NEXT: movw $-21846, %ax # imm = 0xAAAA
+; AVX512F-NEXT: kmovw %eax, %k1
+; AVX512F-NEXT: vpshufd {{.*#+}} zmm0 {%k1} {z} = zmm0[0,2,0,2,4,6,4,6,8,10,8,10,12,14,12,14]
+; AVX512F-NEXT: retq
+;
+; AVX512BW-LABEL: pr121147:
+; AVX512BW: # %bb.0: # %entry
+; AVX512BW-NEXT: movw $-21846, %ax # imm = 0xAAAA
+; AVX512BW-NEXT: kmovd %eax, %k1
+; AVX512BW-NEXT: vpshufd {{.*#+}} zmm0 {%k1} {z} = zmm0[0,2,0,2,4,6,4,6,8,10,8,10,12,14,12,14]
+; AVX512BW-NEXT: retq
+entry:
+ %0 = bitcast <8 x i64> %a to <16 x i32>
+ %1 = shufflevector <16 x i32> <i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison>, <16 x i32> %0, <16 x i32> <i32 0, i32 18, i32 2, i32 18, i32 4, i32 22, i32 6, i32 22, i32 8, i32 26, i32 10, i32 26, i32 12, i32 30, i32 14, i32 30>
+ %2 = bitcast <16 x i32> %1 to <8 x i64>
+ ret <8 x i64> %2
+}
More information about the llvm-commits
mailing list