[llvm] [X86][AVX512] Better lowering for `_mm512_maskz_shuffle_epi32` (PR #121147)

Abhishek Kaushik via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 10 09:49:29 PDT 2025


https://github.com/abhishek-kaushik22 updated https://github.com/llvm/llvm-project/pull/121147

>From 4f2e5e43d117283a27309bd5661128b28fc466fe Mon Sep 17 00:00:00 2001
From: abhishek-kaushik22 <abhishek.kaushik at intel.com>
Date: Thu, 26 Dec 2024 17:40:36 +0530
Subject: [PATCH 1/3] [X86][AVX512] Better lowering for
 `_mm512_maskz_shuffle_epi32`

For the function (https://godbolt.org/z/4rTYeMY4b)
```
#include <immintrin.h>

__m512i foo(__m512i a){
    __m512i r0 = _mm512_maskz_shuffle_epi32(0xaaaa, a, 0xab);
    return r0;
}
```
The generated assembly is unnecessarily long:
```
.LCPI0_1:
        .byte   0
        .byte   18
        .byte   2
        .byte   18
        .byte   4
        .byte   22
        .byte   6
        .byte   22
        .byte   8
        .byte   26
        .byte   10
        .byte   26
        .byte   12
        .byte   30
        .byte   14
        .byte   30
foo(long long vector[8]):
        vpmovsxbd       zmm2, xmmword ptr [rip + .LCPI0_1]
        vpxor   xmm1, xmm1, xmm1
        vpermt2d        zmm1, zmm2, zmm0
        vmovdqa64       zmm0, zmm1
        ret
```
Instead, we could simply generate a single zero-masked `vpshufd` instruction (`vpshufd $imm8, %zmm0, %zmm0 {%k1} {z}`), passing it the mask and the `imm8` value.

The SelectionDAG generated from the IR doesn't contain the mask and the `imm8` value directly, but there is a pattern that can be matched here:
```
t6: v16i32 = BUILD_VECTOR Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32, Constant:i32<0>, undef:i32
          t2: v8i64,ch = CopyFromReg t0, Register:v8i64 %0
        t3: v16i32 = bitcast t2
      t7: v16i32 = vector_shuffle<0,18,2,18,4,22,6,22,8,26,10,26,12,30,14,30> t6, t3
    t8: v8i64 = bitcast t7
```
I've matched this pattern to recover the mask and the `imm8` value, and to generate a `VSELECT` node instead.
The resulting assembly looks like this:
```
        movw    $-21846, %ax                    # imm = 0xAAAA
        kmovw   %eax, %k1
        vpshufd $136, %zmm0, %zmm0 {%k1} {z}    # zmm0 {%k1} {z} = zmm0[0,2,0,2,4,6,4,6,8,10,8,10,12,14,12,14]
        retq
```
---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 55 +++++++++++++++++++++++++
 1 file changed, 55 insertions(+)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index e7f6032ee7d74..ca07b81f3fb98 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -17172,6 +17172,58 @@ static SDValue lowerV8I64Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
   return lowerShuffleWithPERMV(DL, MVT::v8i64, Mask, V1, V2, Subtarget, DAG);
 }
 
+static SDValue lowerShuffleAsVSELECT(const SDLoc &DL,
+                                     ArrayRef<int> RepeatedMask, SDValue V1,
+                                     SDValue V2, SelectionDAG &DAG) {
+  if (V1.getOpcode() != ISD::BUILD_VECTOR &&
+      V2.getOpcode() != ISD::BUILD_VECTOR)
+    return SDValue();
+  SDValue BuildVector;
+  if (V1.getOpcode() == ISD::BUILD_VECTOR) {
+    BuildVector = V1;
+    if (V2.getOpcode() != ISD::BITCAST)
+      return SDValue();
+  } else {
+    BuildVector = V2;
+    if (V1.getOpcode() != ISD::BITCAST)
+      return SDValue();
+  }
+  if (!ISD::isBuildVectorAllZeros(BuildVector.getNode()))
+    return SDValue();
+  APInt DestMask(16, 0);
+  for (unsigned i = 0; i < 16; ++i) {
+    SDValue Op = BuildVector->getOperand(i);
+    if (Op.isUndef())
+      DestMask.setBit(i);
+  }
+  if (DestMask.isZero())
+    return SDValue();
+
+  unsigned Imm8 = 0;
+  for (unsigned i = 0; i < 4; ++i) {
+    if (V1.getOpcode() != ISD::BUILD_VECTOR) {
+      if (RepeatedMask[i] >= 4) {
+        continue;
+      }
+    } else if (RepeatedMask[i] < 4) {
+      continue;
+    }
+    Imm8 += (RepeatedMask[i] % 4) << (2 * i);
+  }
+
+  SDValue Bitcast = DAG.getNode(ISD::BITCAST, DL, MVT::v16i1,
+                                DAG.getConstant(DestMask, DL, MVT::i16));
+
+  std::vector<SDValue> ZeroElements(16, DAG.getConstant(0, DL, MVT::i32));
+  SDValue Zeros = DAG.getBuildVector(MVT::v16i32, DL, ZeroElements);
+
+  return DAG.getNode(ISD::VSELECT, DL, MVT::v16i32, Bitcast,
+                     DAG.getNode(X86ISD::PSHUFD, DL, MVT::v16i32,
+                                 V1.getOpcode() != ISD::BUILD_VECTOR ? V1 : V2,
+                                 DAG.getTargetConstant(Imm8, DL, MVT::i8)),
+                     Zeros);
+}
+
 /// Handle lowering of 16-lane 32-bit integer shuffles.
 static SDValue lowerV16I32Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
                                   const APInt &Zeroable, SDValue V1, SDValue V2,
@@ -17217,6 +17269,9 @@ static SDValue lowerV16I32Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
     // Use dedicated unpack instructions for masks that match their pattern.
     if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v16i32, V1, V2, Mask, DAG))
       return V;
+
+    if (SDValue V = lowerShuffleAsVSELECT(DL, RepeatedMask, V1, V2, DAG))
+      return V;
   }
 
   // Try to use shift instructions.

>From 32d0603bb6c8f897ce644bf1f165de82e5de6ace Mon Sep 17 00:00:00 2001
From: abhishek-kaushik22 <abhishek.kaushik at intel.com>
Date: Thu, 26 Dec 2024 17:44:32 +0530
Subject: [PATCH 2/3] Add test

---
 .../CodeGen/X86/vector-shuffle-512-v16.ll     | 20 +++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/llvm/test/CodeGen/X86/vector-shuffle-512-v16.ll b/llvm/test/CodeGen/X86/vector-shuffle-512-v16.ll
index d3b04878dc06d..3a2a96fcd5de3 100644
--- a/llvm/test/CodeGen/X86/vector-shuffle-512-v16.ll
+++ b/llvm/test/CodeGen/X86/vector-shuffle-512-v16.ll
@@ -990,3 +990,23 @@ bb:
   ret void
 }
 
+define <8 x i64> @pr121147(<8 x i64> %a) {
+; AVX512F-LABEL: pr121147:
+; AVX512F:       # %bb.0: # %entry
+; AVX512F-NEXT:    movw $-21846, %ax # imm = 0xAAAA
+; AVX512F-NEXT:    kmovw %eax, %k1
+; AVX512F-NEXT:    vpshufd {{.*#+}} zmm0 {%k1} {z} = zmm0[0,2,0,2,4,6,4,6,8,10,8,10,12,14,12,14]
+; AVX512F-NEXT:    retq
+;
+; AVX512BW-LABEL: pr121147:
+; AVX512BW:       # %bb.0: # %entry
+; AVX512BW-NEXT:    movw $-21846, %ax # imm = 0xAAAA
+; AVX512BW-NEXT:    kmovd %eax, %k1
+; AVX512BW-NEXT:    vpshufd {{.*#+}} zmm0 {%k1} {z} = zmm0[0,2,0,2,4,6,4,6,8,10,8,10,12,14,12,14]
+; AVX512BW-NEXT:    retq
+entry:
+  %0 = bitcast <8 x i64> %a to <16 x i32>
+  %1 = shufflevector <16 x i32> <i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison>, <16 x i32> %0, <16 x i32> <i32 0, i32 18, i32 2, i32 18, i32 4, i32 22, i32 6, i32 22, i32 8, i32 26, i32 10, i32 26, i32 12, i32 30, i32 14, i32 30>
+  %2 = bitcast <16 x i32> %1 to <8 x i64>
+  ret <8 x i64> %2
+}

>From 24989a2f02dce6e83650faf81e8c1516ba98c710 Mon Sep 17 00:00:00 2001
From: abhishek-kaushik22 <abhishek.kaushik at intel.com>
Date: Mon, 10 Mar 2025 22:19:13 +0530
Subject: [PATCH 3/3] Fix reviews

---
 llvm/lib/Target/X86/X86ISelLowering.cpp       | 56 +++++++++----------
 .../CodeGen/X86/vector-shuffle-512-v16.ll     | 39 +++++++++----
 2 files changed, 53 insertions(+), 42 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 288fe35d41d97..48adbbebf1bf7 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -17350,24 +17350,23 @@ static SDValue lowerV8I64Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
   return lowerShuffleWithPERMV(DL, MVT::v8i64, Mask, V1, V2, Subtarget, DAG);
 }
 
-static SDValue lowerShuffleAsVSELECT(const SDLoc &DL,
-                                     ArrayRef<int> RepeatedMask, SDValue V1,
-                                     SDValue V2, SelectionDAG &DAG) {
+static SDValue lowerShuffleAsVSELECT(const SDLoc &DL, ArrayRef<int> Mask,
+                                     SDValue V1, SDValue V2,
+                                     SelectionDAG &DAG) {
   if (V1.getOpcode() != ISD::BUILD_VECTOR &&
       V2.getOpcode() != ISD::BUILD_VECTOR)
     return SDValue();
-  SDValue BuildVector;
-  if (V1.getOpcode() == ISD::BUILD_VECTOR) {
-    BuildVector = V1;
-    if (V2.getOpcode() != ISD::BITCAST)
-      return SDValue();
-  } else {
-    BuildVector = V2;
-    if (V1.getOpcode() != ISD::BITCAST)
-      return SDValue();
-  }
+
+  bool IsV1BuildVector = V1.getOpcode() == ISD::BUILD_VECTOR;
+  SDValue BuildVector = IsV1BuildVector ? V1 : V2;
+
   if (!ISD::isBuildVectorAllZeros(BuildVector.getNode()))
     return SDValue();
+
+  // For `_mm512_maskz_shuffle_epi32`, `BUILD_VECTOR` encodes the zeroing
+  // mask: an UNDEF element sets the corresponding mask bit, a zero clears it.
+  // NOTE(review): this assumes the lanes `Mask` takes from the other operand
+  // are exactly the UNDEF positions of `BuildVector`; confirm or check `Mask`.
   APInt DestMask(16, 0);
   for (unsigned i = 0; i < 16; ++i) {
     SDValue Op = BuildVector->getOperand(i);
@@ -17377,28 +17376,25 @@ static SDValue lowerShuffleAsVSELECT(const SDLoc &DL,
   if (DestMask.isZero())
     return SDValue();
 
-  unsigned Imm8 = 0;
-  for (unsigned i = 0; i < 4; ++i) {
-    if (V1.getOpcode() != ISD::BUILD_VECTOR) {
-      if (RepeatedMask[i] >= 4) {
-        continue;
-      }
-    } else if (RepeatedMask[i] < 4) {
-      continue;
-    }
-    Imm8 += (RepeatedMask[i] % 4) << (2 * i);
-  }
-
   SDValue Bitcast = DAG.getNode(ISD::BITCAST, DL, MVT::v16i1,
                                 DAG.getConstant(DestMask, DL, MVT::i16));
 
-  std::vector<SDValue> ZeroElements(16, DAG.getConstant(0, DL, MVT::i32));
+  SmallVector<SDValue, 16> ZeroElements(16, DAG.getConstant(0, DL, MVT::i32));
   SDValue Zeros = DAG.getBuildVector(MVT::v16i32, DL, ZeroElements);
 
+  SmallVector<int, 16> NewMask(16);
+  for (int I = 0; I < 16; ++I) {
+    if (IsV1BuildVector) {
+      NewMask[I] = Mask[I] >= 16 ? Mask[I] - 16 : Mask[I] + 16;
+    } else {
+      NewMask[I] = Mask[I];
+    }
+  }
+
   return DAG.getNode(ISD::VSELECT, DL, MVT::v16i32, Bitcast,
-                     DAG.getNode(X86ISD::PSHUFD, DL, MVT::v16i32,
-                                 V1.getOpcode() != ISD::BUILD_VECTOR ? V1 : V2,
-                                 DAG.getTargetConstant(Imm8, DL, MVT::i8)),
+                     DAG.getVectorShuffle(MVT::v16i32, DL,
+                                          IsV1BuildVector ? V2 : V1,
+                                          DAG.getUNDEF(MVT::v16i32), NewMask),
                      Zeros);
 }
 
@@ -17448,7 +17444,7 @@ static SDValue lowerV16I32Shuffle(const SDLoc &DL, ArrayRef<int> Mask,
     if (SDValue V = lowerShuffleWithUNPCK(DL, MVT::v16i32, V1, V2, Mask, DAG))
       return V;
 
-    if (SDValue V = lowerShuffleAsVSELECT(DL, RepeatedMask, V1, V2, DAG))
+    if (SDValue V = lowerShuffleAsVSELECT(DL, Mask, V1, V2, DAG))
       return V;
   }
 
diff --git a/llvm/test/CodeGen/X86/vector-shuffle-512-v16.ll b/llvm/test/CodeGen/X86/vector-shuffle-512-v16.ll
index 889836be3f750..b08bf63b6d3d6 100644
--- a/llvm/test/CodeGen/X86/vector-shuffle-512-v16.ll
+++ b/llvm/test/CodeGen/X86/vector-shuffle-512-v16.ll
@@ -990,23 +990,38 @@ bb:
   ret void
 }
 
-define <8 x i64> @pr121147(<8 x i64> %a) {
-; AVX512F-LABEL: pr121147:
-; AVX512F:       # %bb.0: # %entry
+define <16 x i32> @gen_VPSHUFD_AVX512_0(<16 x i32> %a) {
+; AVX512F-LABEL: gen_VPSHUFD_AVX512_0:
+; AVX512F:       # %bb.0:
 ; AVX512F-NEXT:    movw $-21846, %ax # imm = 0xAAAA
 ; AVX512F-NEXT:    kmovw %eax, %k1
-; AVX512F-NEXT:    vpshufd {{.*#+}} zmm0 {%k1} {z} = zmm0[0,2,0,2,4,6,4,6,8,10,8,10,12,14,12,14]
+; AVX512F-NEXT:    vpshufd {{.*#+}} zmm0 {%k1} {z} = zmm0[2,2,2,2,6,6,6,6,10,10,10,10,14,14,14,14]
 ; AVX512F-NEXT:    retq
 ;
-; AVX512BW-LABEL: pr121147:
-; AVX512BW:       # %bb.0: # %entry
+; AVX512BW-LABEL: gen_VPSHUFD_AVX512_0:
+; AVX512BW:       # %bb.0:
 ; AVX512BW-NEXT:    movw $-21846, %ax # imm = 0xAAAA
 ; AVX512BW-NEXT:    kmovd %eax, %k1
-; AVX512BW-NEXT:    vpshufd {{.*#+}} zmm0 {%k1} {z} = zmm0[0,2,0,2,4,6,4,6,8,10,8,10,12,14,12,14]
+; AVX512BW-NEXT:    vpshufd {{.*#+}} zmm0 {%k1} {z} = zmm0[2,2,2,2,6,6,6,6,10,10,10,10,14,14,14,14]
 ; AVX512BW-NEXT:    retq
-entry:
-  %0 = bitcast <8 x i64> %a to <16 x i32>
-  %1 = shufflevector <16 x i32> <i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison>, <16 x i32> %0, <16 x i32> <i32 0, i32 18, i32 2, i32 18, i32 4, i32 22, i32 6, i32 22, i32 8, i32 26, i32 10, i32 26, i32 12, i32 30, i32 14, i32 30>
-  %2 = bitcast <16 x i32> %1 to <8 x i64>
-  ret <8 x i64> %2
+  %res = shufflevector <16 x i32> <i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison>, <16 x i32> %a, <16 x i32> <i32 0, i32 18, i32 2, i32 18, i32 4, i32 22, i32 6, i32 22, i32 8, i32 26, i32 10, i32 26, i32 12, i32 30, i32 14, i32 30>
+  ret <16 x i32> %res
+}
+
+define <16 x i32> @gen_VPSHUFD_AVX512_1(<16 x i32> %a) {
+; AVX512F-LABEL: gen_VPSHUFD_AVX512_1:
+; AVX512F:       # %bb.0:
+; AVX512F-NEXT:    movw $-21846, %ax # imm = 0xAAAA
+; AVX512F-NEXT:    kmovw %eax, %k1
+; AVX512F-NEXT:    vpshufd {{.*#+}} zmm0 {%k1} {z} = zmm0[2,2,2,2,6,6,6,6,10,10,10,10,14,14,14,14]
+; AVX512F-NEXT:    retq
+;
+; AVX512BW-LABEL: gen_VPSHUFD_AVX512_1:
+; AVX512BW:       # %bb.0:
+; AVX512BW-NEXT:    movw $-21846, %ax # imm = 0xAAAA
+; AVX512BW-NEXT:    kmovd %eax, %k1
+; AVX512BW-NEXT:    vpshufd {{.*#+}} zmm0 {%k1} {z} = zmm0[2,2,2,2,6,6,6,6,10,10,10,10,14,14,14,14]
+; AVX512BW-NEXT:    retq
+  %res = shufflevector <16 x i32> %a , <16 x i32> <i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison, i32 0, i32 poison>, <16 x i32> <i32 16, i32 2, i32 18, i32 2, i32 20, i32 6, i32 22, i32 6, i32 24, i32 10, i32 26, i32 10, i32 28, i32 14, i32 30, i32 14>
+  ret <16 x i32> %res
 }



More information about the llvm-commits mailing list