[llvm] [X86][AVX512] Better lowering for `_mm512_maskz_shuffle_epi32` (PR #121147)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 26 04:12:39 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-x86
Author: None (abhishek-kaushik22)
<details>
<summary>Changes</summary>
For the following function (https://godbolt.org/z/4rTYeMY4b):
```
#include <immintrin.h>
__m512i foo(__m512i a){
__m512i r0 = _mm512_maskz_shuffle_epi32(0xaaaa, a, 0xab);
return r0;
}
```
the generated assembly is unnecessarily long:
```
.LCPI0_1:
.byte 0
.byte 18
.byte 2
.byte 18
.byte 4
.byte 22
.byte 6
.byte 22
.byte 8
.byte 26
.byte 10
.byte 26
.byte 12
.byte 30
.byte 14
.byte 30
foo(long long vector[8]):
vpmovsxbd zmm2, xmmword ptr [rip + .LCPI0_1]
vpxor xmm1, xmm1, xmm1
vpermt2d zmm1, zmm2, zmm0
vmovdqa64 zmm0, zmm1
ret
```
Instead, we could generate a single zero-masked `vpshufd` (`vpshufd $imm8, %zmm0, %zmm0 {%k1} {z}`), passing the mask in `%k1` and the shuffle control as the `imm8` operand.
The SelectionDAG built from the IR doesn't contain the mask and the `imm8` value directly, but there is a pattern here that can be matched:
```
t6: v16i32 = BUILD_VECTOR Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32
t2: v8i64,ch = CopyFromReg t0, Register:v8i64 %0
t3: v16i32 = bitcast t2
t7: v16i32 = vector_shuffle<0,18,2,18,4,22,6,22,8,26,10,26,12,30,14,30> t6, t3
t8: v8i64 = bitcast t7
```
This patch matches that pattern to recover the mask and `imm8` values, and emits a `VSELECT` node that blends a `PSHUFD` result with zero. The resulting assembly is:
```
movw $-21846, %ax # imm = 0xAAAA
kmovw %eax, %k1
vpshufd $136, %zmm0, %zmm0 {%k1} {z} # zmm0 {%k1} {z} = zmm0[0,2,0,2,4,6,4,6,8,10,8,10,12,14,12,14]
retq
```
---
Full diff: https://github.com/llvm/llvm-project/pull/121147.diff
1 Files Affected:
- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+55)
``````````diff
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index e7f6032ee7d749..ca07b81f3fb984 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -17172,6 +17172,58 @@ static SDValue lowerV8I64Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
return lowerShuffleWithPERMV(DL, MVT::v8i64, Mask, V1, V2, Subtarget, DAG);
}
+
+/// Try to lower a v16i32 shuffle that places 128-bit-lane-repeated elements of
+/// one operand into some elements and zeros into the rest as a PSHUFD whose
+/// result is blended with zero by a VSELECT, e.g. for
+/// _mm512_maskz_shuffle_epi32 this can select to
+///   vpshufd $imm8, %zmm0, %zmm0 {%k1} {z}
+///
+/// \p RepeatedMask is the 4-element per-128-bit-lane shuffle mask, in which
+/// elements of \p V1 are encoded as 0-3, elements of \p V2 as 4-7, and undef
+/// elements as negative values. Exactly one of \p V1 / \p V2 must be an
+/// all-zeros BUILD_VECTOR; the other one is the PSHUFD source.
+///
+/// The zeroing mask is derived from the shuffle mask itself (which output
+/// elements read the non-zero operand), NOT from the positions of undef
+/// operands inside the zero vector: those positions are unrelated to the
+/// output layout, and using them miscompiles e.g. mask <17,0,19,2,...>.
+static SDValue lowerShuffleAsVSELECT(const SDLoc &DL,
+                                     ArrayRef<int> RepeatedMask, SDValue V1,
+                                     SDValue V2, SelectionDAG &DAG) {
+  assert(RepeatedMask.size() == 4 && "Expected a repeated v4i32 lane mask");
+
+  // Exactly one operand must be all zeros; the other one gets shuffled.
+  bool V1IsZero = ISD::isBuildVectorAllZeros(V1.getNode());
+  bool V2IsZero = ISD::isBuildVectorAllZeros(V2.getNode());
+  if (V1IsZero == V2IsZero)
+    return SDValue();
+  SDValue Src = V1IsZero ? V2 : V1;
+  // First RepeatedMask index that refers to Src.
+  int SrcBase = V1IsZero ? 4 : 0;
+
+  // Build the PSHUFD immediate and the per-element "keep" bits. Elements that
+  // are undef (negative) or read from the zero vector are left zeroed by the
+  // blend, so their immediate field is irrelevant and stays 0. Filtering
+  // negatives here also avoids left-shifting a negative value.
+  unsigned Imm8 = 0;
+  APInt LaneMask(4, 0);
+  for (unsigned i = 0; i != 4; ++i) {
+    int M = RepeatedMask[i];
+    if (M < SrcBase || M >= SrcBase + 4)
+      continue;
+    LaneMask.setBit(i);
+    Imm8 |= unsigned(M - SrcBase) << (2 * i);
+  }
+  // Nothing is taken from Src: the result is all zeros, leave it to others.
+  if (LaneMask.isZero())
+    return SDValue();
+
+  // The mask repeats in every 128-bit lane, so splat the 4 keep bits to 16.
+  APInt DestMask = APInt::getSplat(16, LaneMask);
+  SDValue Sel =
+      DAG.getBitcast(MVT::v16i1, DAG.getConstant(DestMask, DL, MVT::i16));
+  SDValue Shuf = DAG.getNode(X86ISD::PSHUFD, DL, MVT::v16i32, Src,
+                             DAG.getTargetConstant(Imm8, DL, MVT::i8));
+  return DAG.getNode(ISD::VSELECT, DL, MVT::v16i32, Sel, Shuf,
+                     DAG.getConstant(0, DL, MVT::v16i32));
+}
+
/// Handle lowering of 16-lane 32-bit integer shuffles.
static SDValue lowerV16I32Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
const APInt &Zeroable, SDValue V1, SDValue V2,
@@ -17217,6 +17269,9 @@ static SDValue lowerV16I32Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
// Use dedicated unpack instructions for masks that match their pattern.
if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v16i32, V1, V2, Mask, DAG))
return V;
+
+ if (SDValue V = lowerShuffleAsVSELECT(DL, RepeatedMask, V1, V2, DAG))
+ return V;
}
// Try to use shift instructions.
``````````
</details>
https://github.com/llvm/llvm-project/pull/121147
More information about the llvm-commits
mailing list