[llvm] [X86] use lowerShuffleWithPERMV helper to create VPERMV/VPERMV3 nodes (PR #129882)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 5 05:16:29 PST 2025


https://github.com/RKSimon created https://github.com/llvm/llvm-project/pull/129882

Creating the VPERMV/VPERMV3 nodes through the lowerShuffleWithPERMV helper allows us to make use of the extra canonicalization that lowerShuffleWithPERMV performs.

From 5d08aa3580ef578b8a97f4374857be59da040fa5 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Wed, 5 Mar 2025 12:51:15 +0000
Subject: [PATCH] [X86] use lowerShuffleWithPERMV helper to create
 VPERMV/VPERMV3 nodes

This allows us to make use of the extra canonicalization that lowerShuffleWithPERMV performs.
---
 llvm/lib/Target/X86/X86ISelLowering.cpp       | 39 +++++++------------
 .../any_extend_vector_inreg_of_broadcast.ll   |  4 +-
 2 files changed, 16 insertions(+), 27 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 882833f15b432..40bac0c20035b 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -42610,7 +42610,6 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL,
     // Combine VPERMV3 to widened VPERMV if the two source operands can be
     // freely concatenated.
     MVT WideVT = VT.getDoubleNumVectorElementsVT();
-    MVT MaskVT = N.getOperand(1).getSimpleValueType();
     bool CanConcat = VT.is128BitVector() ||
                      (VT.is256BitVector() && Subtarget.useAVX512Regs());
     if (CanConcat) {
@@ -42634,12 +42633,10 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL,
                 DL, WideVT, {N.getOperand(2), N.getOperand(0)}, DAG, DCI,
                 Subtarget)) {
           ShuffleVectorSDNode::commuteMask(Mask);
-          SDValue NewMask =
-              getConstVector(Mask, MaskVT, DAG, DL, /*IsMask=*/true);
-          NewMask = widenSubVector(NewMask, false, Subtarget, DAG, DL,
-                                   WideVT.getSizeInBits());
+          Mask.append(NumElts, SM_SentinelUndef);
           SDValue Perm =
-              DAG.getNode(X86ISD::VPERMV, DL, WideVT, NewMask, ConcatSrc);
+              lowerShuffleWithPERMV(DL, WideVT, Mask, ConcatSrc,
+                                    DAG.getUNDEF(WideVT), Subtarget, DAG);
           return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Perm,
                              DAG.getVectorIdxConstant(0, DL));
         }
@@ -42649,10 +42646,9 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL,
       // Canonicalize to VPERMV if both sources are the same.
       if (V1 == V2) {
         for (int &M : Mask)
-          M = (M < 0 ? M : M & (Mask.size() - 1));
-        SDValue NewMask = getConstVector(Mask, MaskVT, DAG, DL,
-                                         /*IsMask=*/true);
-        return DAG.getNode(X86ISD::VPERMV, DL, VT, NewMask, N.getOperand(0));
+          M = (M < 0 ? M : (M & (NumElts - 1)));
+        return lowerShuffleWithPERMV(DL, VT, Mask, N.getOperand(0),
+                                     DAG.getUNDEF(VT), Subtarget, DAG);
       }
       // If sources are half width, then concat and use VPERMV with adjusted
       // mask.
@@ -42667,19 +42663,16 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL,
                 combineConcatVectorOps(DL, VT, Ops, DAG, DCI, Subtarget)) {
           for (int &M : Mask)
             M = (M < (int)NumElts ? M : (M - (NumElts / 2)));
-          SDValue NewMask = getConstVector(Mask, MaskVT, DAG, DL,
-                                           /*IsMask=*/true);
-          return DAG.getNode(X86ISD::VPERMV, DL, VT, NewMask, ConcatSrc);
+          return lowerShuffleWithPERMV(DL, VT, Mask, ConcatSrc,
+                                       DAG.getUNDEF(VT), Subtarget, DAG);
         }
       }
       // Commute foldable source to the RHS.
       if (isShuffleFoldableLoad(N.getOperand(0)) &&
           !isShuffleFoldableLoad(N.getOperand(2))) {
         ShuffleVectorSDNode::commuteMask(Mask);
-        SDValue NewMask =
-            getConstVector(Mask, MaskVT, DAG, DL, /*IsMask=*/true);
-        return DAG.getNode(X86ISD::VPERMV3, DL, VT, N.getOperand(2), NewMask,
-                           N.getOperand(0));
+        return lowerShuffleWithPERMV(DL, VT, Mask, N.getOperand(2),
+                                     N.getOperand(0), Subtarget, DAG);
       }
     }
     return SDValue();
@@ -58048,10 +58041,8 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
         if (ConcatMask.size() == (NumOps * NumSrcElts)) {
           SDValue Src = concatSubVectors(Ops[0].getOperand(1),
                                          Ops[1].getOperand(1), DAG, DL);
-          MVT IntMaskSVT = MVT::getIntegerVT(EltSizeInBits);
-          MVT IntMaskVT = MVT::getVectorVT(IntMaskSVT, NumOps * NumSrcElts);
-          SDValue Mask = getConstVector(ConcatMask, IntMaskVT, DAG, DL, true);
-          return DAG.getNode(X86ISD::VPERMV, DL, VT, Mask, Src);
+          return lowerShuffleWithPERMV(DL, VT, ConcatMask, Src,
+                                       DAG.getUNDEF(VT), Subtarget, DAG);
         }
       }
       break;
@@ -58080,10 +58071,8 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
         if (ConcatMask.size() == (NumOps * NumSrcElts)) {
           SDValue Src0 = ConcatSubOperand(VT, Ops, 0);
           SDValue Src1 = ConcatSubOperand(VT, Ops, 2);
-          MVT IntMaskSVT = MVT::getIntegerVT(EltSizeInBits);
-          MVT IntMaskVT = MVT::getVectorVT(IntMaskSVT, NumOps * NumSrcElts);
-          SDValue Mask = getConstVector(ConcatMask, IntMaskVT, DAG, DL, true);
-          return DAG.getNode(X86ISD::VPERMV3, DL, VT, Src0, Mask, Src1);
+          return lowerShuffleWithPERMV(DL, VT, ConcatMask, Src0, Src1,
+                                       Subtarget, DAG);
         }
       }
       break;
diff --git a/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast.ll b/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast.ll
index 951a2b4cafa26..facdbad50164a 100644
--- a/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast.ll
+++ b/llvm/test/CodeGen/X86/any_extend_vector_inreg_of_broadcast.ll
@@ -3908,7 +3908,7 @@ define void @vec384_i16_widen_to_i96_factor6_broadcast_to_v4i96_factor4(ptr %in.
 ; AVX512BW:       # %bb.0:
 ; AVX512BW-NEXT:    vmovdqa64 (%rdi), %zmm0
 ; AVX512BW-NEXT:    vpaddb (%rsi), %zmm0, %zmm0
-; AVX512BW-NEXT:    vpmovsxbw {{.*#+}} ymm1 = [0,25,26,27,28,29,0,31,0,0,0,0,0,0,0,0]
+; AVX512BW-NEXT:    vpmovsxbw {{.*#+}} xmm1 = [0,25,26,27,28,29,0,31]
 ; AVX512BW-NEXT:    vpermw %zmm0, %zmm1, %zmm1
 ; AVX512BW-NEXT:    vpbroadcastw %xmm0, %ymm0
 ; AVX512BW-NEXT:    vinserti64x4 $1, %ymm0, %zmm1, %zmm0
@@ -4146,7 +4146,7 @@ define void @vec384_i16_widen_to_i192_factor12_broadcast_to_v2i192_factor2(ptr %
 ; AVX512BW:       # %bb.0:
 ; AVX512BW-NEXT:    vmovdqa64 (%rdi), %zmm0
 ; AVX512BW-NEXT:    vpaddb (%rsi), %zmm0, %zmm0
-; AVX512BW-NEXT:    vpmovsxbw {{.*#+}} ymm1 = [0,25,26,27,28,29,30,31,0,0,0,0,0,0,0,0]
+; AVX512BW-NEXT:    vpmovsxbw {{.*#+}} xmm1 = [0,25,26,27,28,29,30,31]
 ; AVX512BW-NEXT:    vpermw %zmm0, %zmm1, %zmm0
 ; AVX512BW-NEXT:    vpaddb (%rdx), %zmm0, %zmm0
 ; AVX512BW-NEXT:    vmovdqa64 %zmm0, (%rcx)



More information about the llvm-commits mailing list