[llvm] 4e8f847 - [X86][AVX512] Fold extract_element(bitcast(<X x i1>) -> bitcast(extract_subvector())

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 23 06:48:30 PDT 2022


Author: Simon Pilgrim
Date: 2022-10-23T14:47:24+01:00
New Revision: 4e8f847676fd589c456abb266afe968c20fb63e3

URL: https://github.com/llvm/llvm-project/commit/4e8f847676fd589c456abb266afe968c20fb63e3
DIFF: https://github.com/llvm/llvm-project/commit/4e8f847676fd589c456abb266afe968c20fb63e3.diff

LOG: [X86][AVX512] Fold extract_element(bitcast(<X x i1>) -> bitcast(extract_subvector())

On AVX512, extract legal bool vectors as bool subvectors before bitcasting to scalars to avoid spilling to the stack.

This helps Rust, which internally represents bool vectors as bool arrays.

It also exposes more missed opportunities to use the KADD instruction to add masks together before moving them to a GPR.

Fixes #58546

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/bitcast-vector-bool.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 71e643e8c937..aafbe7b716c5 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44324,6 +44324,8 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
   SDLoc dl(InputVector);
   bool IsPextr = N->getOpcode() != ISD::EXTRACT_VECTOR_ELT;
   unsigned NumSrcElts = SrcVT.getVectorNumElements();
+  unsigned NumEltBits = VT.getScalarSizeInBits();
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
   if (CIdx && CIdx->getAPIntValue().uge(NumSrcElts))
     return IsPextr ? DAG.getConstant(0, dl, VT) : DAG.getUNDEF(VT);
@@ -44338,15 +44340,26 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
       uint64_t Idx = CIdx->getZExtValue();
       if (UndefVecElts[Idx])
         return IsPextr ? DAG.getConstant(0, dl, VT) : DAG.getUNDEF(VT);
-      return DAG.getConstant(EltBits[Idx].zext(VT.getScalarSizeInBits()), dl,
-                             VT);
+      return DAG.getConstant(EltBits[Idx].zext(NumEltBits), dl, VT);
+    }
+
+    // Convert extract_element(bitcast(<X x i1>) -> bitcast(extract_subvector()).
+    // Improves lowering of bool masks on rust which splits them into byte array.
+    if (InputVector.getOpcode() == ISD::BITCAST && (NumEltBits % 8) == 0) {
+      SDValue Src = peekThroughBitcasts(InputVector);
+      if (Src.getValueType().getScalarType() == MVT::i1 &&
+          TLI.isTypeLegal(Src.getValueType())) {
+        MVT SubVT = MVT::getVectorVT(MVT::i1, NumEltBits);
+        SDValue Sub = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SubVT, Src,
+            DAG.getIntPtrConstant(CIdx->getZExtValue() * NumEltBits, dl));
+        return DAG.getBitcast(VT, Sub);
+      }
     }
   }
 
   if (IsPextr) {
-    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-    if (TLI.SimplifyDemandedBits(SDValue(N, 0),
-                                 APInt::getAllOnes(VT.getSizeInBits()), DCI))
+    if (TLI.SimplifyDemandedBits(SDValue(N, 0), APInt::getAllOnes(NumEltBits),
+                                 DCI))
       return SDValue(N, 0);
 
     // PEXTR*(PINSR*(v, s, c), c) -> s (with implicit zext handling).

diff  --git a/llvm/test/CodeGen/X86/bitcast-vector-bool.ll b/llvm/test/CodeGen/X86/bitcast-vector-bool.ll
index 94901e665d56..de132c1c7ef4 100644
--- a/llvm/test/CodeGen/X86/bitcast-vector-bool.ll
+++ b/llvm/test/CodeGen/X86/bitcast-vector-bool.ll
@@ -123,14 +123,24 @@ define i8 @bitcast_v16i8_to_v2i8(<16 x i8> %a0) nounwind {
 ; SSE2-SSSE3-NEXT:    addb -{{[0-9]+}}(%rsp), %al
 ; SSE2-SSSE3-NEXT:    retq
 ;
-; AVX-LABEL: bitcast_v16i8_to_v2i8:
-; AVX:       # %bb.0:
-; AVX-NEXT:    vpmovmskb %xmm0, %ecx
-; AVX-NEXT:    movl %ecx, %eax
-; AVX-NEXT:    shrl $8, %eax
-; AVX-NEXT:    addb %cl, %al
-; AVX-NEXT:    # kill: def $al killed $al killed $eax
-; AVX-NEXT:    retq
+; AVX12-LABEL: bitcast_v16i8_to_v2i8:
+; AVX12:       # %bb.0:
+; AVX12-NEXT:    vpmovmskb %xmm0, %ecx
+; AVX12-NEXT:    movl %ecx, %eax
+; AVX12-NEXT:    shrl $8, %eax
+; AVX12-NEXT:    addb %cl, %al
+; AVX12-NEXT:    # kill: def $al killed $al killed $eax
+; AVX12-NEXT:    retq
+;
+; AVX512-LABEL: bitcast_v16i8_to_v2i8:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpmovb2m %xmm0, %k0
+; AVX512-NEXT:    kshiftrw $8, %k0, %k1
+; AVX512-NEXT:    kmovd %k0, %ecx
+; AVX512-NEXT:    kmovd %k1, %eax
+; AVX512-NEXT:    addb %cl, %al
+; AVX512-NEXT:    # kill: def $al killed $al killed $eax
+; AVX512-NEXT:    retq
   %1 = icmp slt <16 x i8> %a0, zeroinitializer
   %2 = bitcast <16 x i1> %1 to <2 x i8>
   %3 = extractelement <2 x i8> %2, i32 0
@@ -242,10 +252,9 @@ define i8 @bitcast_v16i16_to_v2i8(<16 x i16> %a0) nounwind {
 ; AVX512-LABEL: bitcast_v16i16_to_v2i8:
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    vpmovw2m %ymm0, %k0
-; AVX512-NEXT:    kmovw %k0, -{{[0-9]+}}(%rsp)
-; AVX512-NEXT:    vmovdqa -{{[0-9]+}}(%rsp), %xmm0
-; AVX512-NEXT:    vmovd %xmm0, %ecx
-; AVX512-NEXT:    vpextrb $1, %xmm0, %eax
+; AVX512-NEXT:    kshiftrw $8, %k0, %k1
+; AVX512-NEXT:    kmovd %k0, %ecx
+; AVX512-NEXT:    kmovd %k1, %eax
 ; AVX512-NEXT:    addb %cl, %al
 ; AVX512-NEXT:    # kill: def $al killed $al killed $eax
 ; AVX512-NEXT:    vzeroupper
@@ -289,9 +298,10 @@ define i16 @bitcast_v32i8_to_v2i16(<32 x i8> %a0) nounwind {
 ;
 ; AVX512-LABEL: bitcast_v32i8_to_v2i16:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vpmovmskb %ymm0, %ecx
-; AVX512-NEXT:    movl %ecx, %eax
-; AVX512-NEXT:    shrl $16, %eax
+; AVX512-NEXT:    vpmovb2m %ymm0, %k0
+; AVX512-NEXT:    kshiftrd $16, %k0, %k1
+; AVX512-NEXT:    kmovd %k0, %ecx
+; AVX512-NEXT:    kmovd %k1, %eax
 ; AVX512-NEXT:    addl %ecx, %eax
 ; AVX512-NEXT:    # kill: def $ax killed $ax killed $eax
 ; AVX512-NEXT:    vzeroupper
@@ -424,10 +434,9 @@ define i8 @bitcast_v16i32_to_v2i8(<16 x i32> %a0) nounwind {
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    vpxor %xmm1, %xmm1, %xmm1
 ; AVX512-NEXT:    vpcmpgtd %zmm0, %zmm1, %k0
-; AVX512-NEXT:    kmovw %k0, -{{[0-9]+}}(%rsp)
-; AVX512-NEXT:    vmovdqa -{{[0-9]+}}(%rsp), %xmm0
-; AVX512-NEXT:    vmovd %xmm0, %ecx
-; AVX512-NEXT:    vpextrb $1, %xmm0, %eax
+; AVX512-NEXT:    kshiftrw $8, %k0, %k1
+; AVX512-NEXT:    kmovd %k0, %ecx
+; AVX512-NEXT:    kmovd %k1, %eax
 ; AVX512-NEXT:    addb %cl, %al
 ; AVX512-NEXT:    # kill: def $al killed $al killed $eax
 ; AVX512-NEXT:    vzeroupper
@@ -479,10 +488,9 @@ define i16 @bitcast_v32i16_to_v2i16(<32 x i16> %a0) nounwind {
 ; AVX512-LABEL: bitcast_v32i16_to_v2i16:
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    vpmovw2m %zmm0, %k0
-; AVX512-NEXT:    kmovd %k0, -{{[0-9]+}}(%rsp)
-; AVX512-NEXT:    vmovdqa -{{[0-9]+}}(%rsp), %xmm0
-; AVX512-NEXT:    vmovd %xmm0, %ecx
-; AVX512-NEXT:    vpextrw $1, %xmm0, %eax
+; AVX512-NEXT:    kshiftrd $16, %k0, %k1
+; AVX512-NEXT:    kmovd %k0, %ecx
+; AVX512-NEXT:    kmovd %k1, %eax
 ; AVX512-NEXT:    addl %ecx, %eax
 ; AVX512-NEXT:    # kill: def $ax killed $ax killed $eax
 ; AVX512-NEXT:    vzeroupper
@@ -541,9 +549,10 @@ define i32 @bitcast_v64i8_to_v2i32(<64 x i8> %a0) nounwind {
 ; AVX512-LABEL: bitcast_v64i8_to_v2i32:
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    vpmovb2m %zmm0, %k0
-; AVX512-NEXT:    kmovq %k0, -{{[0-9]+}}(%rsp)
-; AVX512-NEXT:    movl -{{[0-9]+}}(%rsp), %eax
-; AVX512-NEXT:    addl -{{[0-9]+}}(%rsp), %eax
+; AVX512-NEXT:    kshiftrq $32, %k0, %k1
+; AVX512-NEXT:    kmovd %k0, %ecx
+; AVX512-NEXT:    kmovd %k1, %eax
+; AVX512-NEXT:    addl %ecx, %eax
 ; AVX512-NEXT:    vzeroupper
 ; AVX512-NEXT:    retq
   %1 = icmp slt <64 x i8> %a0, zeroinitializer
@@ -698,10 +707,9 @@ define [2 x i8] @PR58546(<16 x float> %a0) {
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    vxorps %xmm1, %xmm1, %xmm1
 ; AVX512-NEXT:    vcmpunordps %zmm1, %zmm0, %k0
-; AVX512-NEXT:    kmovw %k0, -{{[0-9]+}}(%rsp)
-; AVX512-NEXT:    vmovdqa -{{[0-9]+}}(%rsp), %xmm0
-; AVX512-NEXT:    vmovd %xmm0, %eax
-; AVX512-NEXT:    vpextrb $1, %xmm0, %edx
+; AVX512-NEXT:    kshiftrw $8, %k0, %k1
+; AVX512-NEXT:    kmovd %k0, %eax
+; AVX512-NEXT:    kmovd %k1, %edx
 ; AVX512-NEXT:    # kill: def $al killed $al killed $eax
 ; AVX512-NEXT:    # kill: def $dl killed $dl killed $edx
 ; AVX512-NEXT:    vzeroupper


        


More information about the llvm-commits mailing list