[llvm] 92083e8 - [X86] Allow VPERMV3 -> VPERMV folds to handle extraction from a wider source vector (e.g. v16i32 -> v4i32)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 8 05:13:24 PDT 2024


Author: Simon Pilgrim
Date: 2024-07-08T13:10:45+01:00
New Revision: 92083e855b4d4ce7e3f7633cd35a4fcb90e2c24f

URL: https://github.com/llvm/llvm-project/commit/92083e855b4d4ce7e3f7633cd35a4fcb90e2c24f
DIFF: https://github.com/llvm/llvm-project/commit/92083e855b4d4ce7e3f7633cd35a4fcb90e2c24f.diff

LOG: [X86] Allow VPERMV3 -> VPERMV folds to handle extraction from a wider source vector (e.g. v16i32 -> v4i32)

We don't need to restrict this fold to double-width source vectors, as long as we correctly bitcast the types.

Improves the fix for #97968

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/pr97968.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 14d287d2d5e90..e116285d043c0 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -41336,29 +41336,27 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL,
   case X86ISD::VPERMV3: {
     // Combine VPERMV3 to widened VPERMV if the two source operands are split
     // from the same vector.
-    // TODO: Handle extraction from a wider source vector (e.g. v16i32 -> v4i32).
     SDValue V1 = peekThroughBitcasts(N.getOperand(0));
     SDValue V2 = peekThroughBitcasts(N.getOperand(2));
     MVT SVT = V1.getSimpleValueType();
-    MVT NVT = VT.getDoubleNumVectorElementsVT();
-    if ((NVT.is256BitVector() ||
-         (NVT.is512BitVector() && Subtarget.hasEVEX512())) &&
-        V1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+    if (V1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
         V1.getConstantOperandVal(1) == 0 &&
         V2.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
         V2.getConstantOperandVal(1) == SVT.getVectorNumElements() &&
-        V1.getOperand(0) == V2.getOperand(0) &&
-        V1.getOperand(0).getValueSizeInBits() == NVT.getSizeInBits()) {
-      SDValue Mask =
-          DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NVT, DAG.getUNDEF(NVT),
-                      N.getOperand(1), DAG.getIntPtrConstant(0, DL));
-      return DAG.getNode(
-          ISD::EXTRACT_SUBVECTOR, DL, VT,
-          DAG.getNode(X86ISD::VPERMV, DL, NVT, Mask,
-                      DAG.getBitcast(NVT, V1.getOperand(0))),
-          DAG.getIntPtrConstant(0, DL));
+        V1.getOperand(0) == V2.getOperand(0)) {
+      EVT NVT = V1.getOperand(0).getValueType();
+      if (NVT.is256BitVector() ||
+          (NVT.is512BitVector() && Subtarget.hasEVEX512())) {
+        MVT WideVT = MVT::getVectorVT(
+            VT.getScalarType(), NVT.getSizeInBits() / VT.getScalarSizeInBits());
+        SDValue Mask = widenSubVector(N.getOperand(1), false, Subtarget, DAG,
+                                      DL, WideVT.getSizeInBits());
+        SDValue Perm = DAG.getNode(X86ISD::VPERMV, DL, WideVT, Mask,
+                                   DAG.getBitcast(WideVT, V1.getOperand(0)));
+        return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Perm,
+                           DAG.getIntPtrConstant(0, DL));
+      }
     }
-
     return SDValue();
   }
   default:

diff  --git a/llvm/test/CodeGen/X86/pr97968.ll b/llvm/test/CodeGen/X86/pr97968.ll
index 103c60bfdfc7f..c8a0536ac4316 100644
--- a/llvm/test/CodeGen/X86/pr97968.ll
+++ b/llvm/test/CodeGen/X86/pr97968.ll
@@ -4,10 +4,10 @@
 define <2 x i32> @PR97968(<16 x i32> %a0) {
 ; CHECK-LABEL: PR97968:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vpbroadcastq {{.*#+}} xmm1 = [2,7,2,7]
-; CHECK-NEXT:    vextracti128 $1, %ymm0, %xmm2
-; CHECK-NEXT:    vpermi2d %xmm2, %xmm0, %xmm1
-; CHECK-NEXT:    vmovdqa %xmm1, %xmm0
+; CHECK-NEXT:    vmovddup {{.*#+}} xmm1 = [2,7,2,7]
+; CHECK-NEXT:    # xmm1 = mem[0,0]
+; CHECK-NEXT:    vpermps %zmm0, %zmm1, %zmm0
+; CHECK-NEXT:    # kill: def $xmm0 killed $xmm0 killed $zmm0
 ; CHECK-NEXT:    vzeroupper
 ; CHECK-NEXT:    retq
   %sub0 = shufflevector <16 x i32> %a0, <16 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>


        


More information about the llvm-commits mailing list