[llvm] [X86][AVX512] Better lowering for `_mm512_maskz_shuffle_epi32` (PR #121147)

via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 26 04:12:39 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-x86

Author: None (abhishek-kaushik22)

<details>
<summary>Changes</summary>

For the function (https://godbolt.org/z/4rTYeMY4b)
```
#include <immintrin.h>

__m512i foo(__m512i a){
    __m512i r0 = _mm512_maskz_shuffle_epi32(0xaaaa, a, 0xab);
    return r0;
}
```
The generated assembly is unnecessarily long:
```
.LCPI0_1:
        .byte   0
        .byte   18
        .byte   2
        .byte   18
        .byte   4
        .byte   22
        .byte   6
        .byte   22
        .byte   8
        .byte   26
        .byte   10
        .byte   26
        .byte   12
        .byte   30
        .byte   14
        .byte   30
foo(long long vector[8]):
        vpmovsxbd       zmm2, xmmword ptr [rip + .LCPI0_1]
        vpxor   xmm1, xmm1, xmm1
        vpermt2d        zmm1, zmm2, zmm0
        vmovdqa64       zmm0, zmm1
        ret
```
Instead, we could simply generate a single zero-masked `vpshufd` instruction (`vpshufd $imm8, %zmm0, %zmm0 {%k1} {z}`), passing it the write-mask and the `imm8` value.

The SelectionDAG generated from the IR doesn't contain the mask and the `imm8` value directly, but there is a pattern that can be matched here:
```
t6: v16i32 = BUILD_VECTOR Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32
          t2: v8i64,ch = CopyFromReg t0, Register:v8i64 %0
        t3: v16i32 = bitcast t2
      t7: v16i32 = vector_shuffle<0,18,2,18,4,22,6,22,8,26,10,26,12,30,14,30> t6, t3
    t8: v8i64 = bitcast t7
```
I've matched this pattern to recover the mask and `imm8` values and generate a `VSELECT` node. The resulting assembly looks like this:
```
        movw    $-21846, %ax                    # imm = 0xAAAA
        kmovw   %eax, %k1
        vpshufd $136, %zmm0, %zmm0 {%k1} {z}    # zmm0 {%k1} {z} = zmm0[0,2,0,2,4,6,4,6,8,10,8,10,12,14,12,14]
        retq
```

---
Full diff: https://github.com/llvm/llvm-project/pull/121147.diff


1 Files Affected:

- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+55) 


``````````diff
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index e7f6032ee7d749..ca07b81f3fb984 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -17172,6 +17172,58 @@ static SDValue lowerV8I64Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
   return lowerShuffleWithPERMV(DL, MVT::v8i64, Mask, V1, V2, Subtarget, DAG);
 }
 
+/// Try to lower a v16i32 shuffle that blends one operand with an all-zeros
+/// vector, and whose pattern repeats in every 128-bit lane, as a zero-masked
+/// PSHUFD: VSELECT(K, PSHUFD(Src, Imm8), 0), i.e. `vpshufd $Imm8 {K} {z}`.
+/// This is the DAG produced for e.g. _mm512_maskz_shuffle_epi32.
+///
+/// \p RepeatedMask is the 4-element per-lane mask, where indices in [0, 4)
+/// read \p V1 and indices in [4, 8) read \p V2; negative entries are undef.
+/// Returns an empty SDValue if the pattern does not match.
+static SDValue lowerShuffleAsVSELECT(const SDLoc &DL,
+                                     ArrayRef<int> RepeatedMask, SDValue V1,
+                                     SDValue V2, SelectionDAG &DAG) {
+  assert(RepeatedMask.size() == 4 && "Expected a 4-element lane mask");
+  // Exactly one operand must be all zeros; the other supplies the data.
+  bool V1IsZero = ISD::isBuildVectorAllZeros(V1.getNode());
+  bool V2IsZero = ISD::isBuildVectorAllZeros(V2.getNode());
+  if (V1IsZero == V2IsZero)
+    return SDValue();
+  SDValue Src = V1IsZero ? V2 : V1;
+  // Mask indices in [SrcBase, SrcBase + 4) read Src; the rest read zero.
+  int SrcBase = V1IsZero ? 4 : 0;
+
+  // Derive the write-mask from the shuffle mask itself rather than from
+  // which BUILD_VECTOR operands are undef: an element is kept only when
+  // the mask actually reads it from Src, regardless of undef zero lanes.
+  unsigned KMask = 0;
+  unsigned Imm8 = 0;
+  for (unsigned i = 0; i != 4; ++i) {
+    int M = RepeatedMask[i];
+    // Undef and zero-operand elements stay zeroed; their immediate bits
+    // are left as 0 because the write-mask discards those elements anyway.
+    // Checking the range first also avoids `%`/shifts on negative values.
+    if (M < SrcBase || M >= SrcBase + 4)
+      continue;
+    Imm8 |= unsigned(M - SrcBase) << (2 * i);
+    // The same position is written in each of the four 128-bit lanes.
+    for (unsigned Lane = 0; Lane != 4; ++Lane)
+      KMask |= 1u << (Lane * 4 + i);
+  }
+  // Nothing is read from Src: the result is all zeros, so leave it alone.
+  if (KMask == 0)
+    return SDValue();
+
+  // NOTE(review): the caller tries this before its shift lowering, so
+  // zero-filling shift masks (e.g. per-lane [Z,0,Z,2], Z = zero) may now
+  // land here instead of a cheaper VPSLLQ; consider calling this later.
+  SDValue K = DAG.getBitcast(MVT::v16i1, DAG.getConstant(KMask, DL, MVT::i16));
+  SDValue Shuf = DAG.getNode(X86ISD::PSHUFD, DL, MVT::v16i32, Src,
+                             DAG.getTargetConstant(Imm8, DL, MVT::i8));
+  return DAG.getNode(ISD::VSELECT, DL, MVT::v16i32, K, Shuf,
+                     DAG.getConstant(0, DL, MVT::v16i32));
+}
+
 /// Handle lowering of 16-lane 32-bit integer shuffles.
 static SDValue lowerV16I32Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
                                   const APInt &Zeroable, SDValue V1, SDValue V2,
@@ -17217,6 +17269,9 @@ static SDValue lowerV16I32Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
     // Use dedicated unpack instructions for masks that match their pattern.
     if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v16i32, V1, V2, Mask, DAG))
       return V;
+
+    if (SDValue V = lowerShuffleAsVSELECT(DL, RepeatedMask, V1, V2, DAG))
+      return V;
   }
 
   // Try to use shift instructions.

``````````

</details>


https://github.com/llvm/llvm-project/pull/121147


More information about the llvm-commits mailing list