[llvm] 4e8f847 - [X86][AVX512] Fold extract_element(bitcast(<X x i1>)) -> bitcast(extract_subvector())
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Sun Oct 23 06:48:30 PDT 2022
Author: Simon Pilgrim
Date: 2022-10-23T14:47:24+01:00
New Revision: 4e8f847676fd589c456abb266afe968c20fb63e3
URL: https://github.com/llvm/llvm-project/commit/4e8f847676fd589c456abb266afe968c20fb63e3
DIFF: https://github.com/llvm/llvm-project/commit/4e8f847676fd589c456abb266afe968c20fb63e3.diff
LOG: [X86][AVX512] Fold extract_element(bitcast(<X x i1>)) -> bitcast(extract_subvector())
On AVX512, extract legal bool vectors as bool subvectors before bitcasting to scalars, to avoid spilling them to the stack.
This helps Rust, which internally represents bool vectors as bool arrays.
It also exposes more missed opportunities to use the KADD instruction to add masks together before moving them to a GPR.
Fixes #58546
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/bitcast-vector-bool.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 71e643e8c937..aafbe7b716c5 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44324,6 +44324,8 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
SDLoc dl(InputVector);
bool IsPextr = N->getOpcode() != ISD::EXTRACT_VECTOR_ELT;
unsigned NumSrcElts = SrcVT.getVectorNumElements();
+ unsigned NumEltBits = VT.getScalarSizeInBits();
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (CIdx && CIdx->getAPIntValue().uge(NumSrcElts))
return IsPextr ? DAG.getConstant(0, dl, VT) : DAG.getUNDEF(VT);
@@ -44338,15 +44340,26 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
uint64_t Idx = CIdx->getZExtValue();
if (UndefVecElts[Idx])
return IsPextr ? DAG.getConstant(0, dl, VT) : DAG.getUNDEF(VT);
- return DAG.getConstant(EltBits[Idx].zext(VT.getScalarSizeInBits()), dl,
- VT);
+ return DAG.getConstant(EltBits[Idx].zext(NumEltBits), dl, VT);
+ }
+
+ // Convert extract_element(bitcast(<X x i1>) -> bitcast(extract_subvector()).
+ // Improves lowering of bool masks on rust which splits them into byte array.
+ if (InputVector.getOpcode() == ISD::BITCAST && (NumEltBits % 8) == 0) {
+ SDValue Src = peekThroughBitcasts(InputVector);
+ if (Src.getValueType().getScalarType() == MVT::i1 &&
+ TLI.isTypeLegal(Src.getValueType())) {
+ MVT SubVT = MVT::getVectorVT(MVT::i1, NumEltBits);
+ SDValue Sub = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SubVT, Src,
+ DAG.getIntPtrConstant(CIdx->getZExtValue() * NumEltBits, dl));
+ return DAG.getBitcast(VT, Sub);
+ }
}
}
if (IsPextr) {
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- if (TLI.SimplifyDemandedBits(SDValue(N, 0),
- APInt::getAllOnes(VT.getSizeInBits()), DCI))
+ if (TLI.SimplifyDemandedBits(SDValue(N, 0), APInt::getAllOnes(NumEltBits),
+ DCI))
return SDValue(N, 0);
// PEXTR*(PINSR*(v, s, c), c) -> s (with implicit zext handling).
diff --git a/llvm/test/CodeGen/X86/bitcast-vector-bool.ll b/llvm/test/CodeGen/X86/bitcast-vector-bool.ll
index 94901e665d56..de132c1c7ef4 100644
--- a/llvm/test/CodeGen/X86/bitcast-vector-bool.ll
+++ b/llvm/test/CodeGen/X86/bitcast-vector-bool.ll
@@ -123,14 +123,24 @@ define i8 @bitcast_v16i8_to_v2i8(<16 x i8> %a0) nounwind {
; SSE2-SSSE3-NEXT: addb -{{[0-9]+}}(%rsp), %al
; SSE2-SSSE3-NEXT: retq
;
-; AVX-LABEL: bitcast_v16i8_to_v2i8:
-; AVX: # %bb.0:
-; AVX-NEXT: vpmovmskb %xmm0, %ecx
-; AVX-NEXT: movl %ecx, %eax
-; AVX-NEXT: shrl $8, %eax
-; AVX-NEXT: addb %cl, %al
-; AVX-NEXT: # kill: def $al killed $al killed $eax
-; AVX-NEXT: retq
+; AVX12-LABEL: bitcast_v16i8_to_v2i8:
+; AVX12: # %bb.0:
+; AVX12-NEXT: vpmovmskb %xmm0, %ecx
+; AVX12-NEXT: movl %ecx, %eax
+; AVX12-NEXT: shrl $8, %eax
+; AVX12-NEXT: addb %cl, %al
+; AVX12-NEXT: # kill: def $al killed $al killed $eax
+; AVX12-NEXT: retq
+;
+; AVX512-LABEL: bitcast_v16i8_to_v2i8:
+; AVX512: # %bb.0:
+; AVX512-NEXT: vpmovb2m %xmm0, %k0
+; AVX512-NEXT: kshiftrw $8, %k0, %k1
+; AVX512-NEXT: kmovd %k0, %ecx
+; AVX512-NEXT: kmovd %k1, %eax
+; AVX512-NEXT: addb %cl, %al
+; AVX512-NEXT: # kill: def $al killed $al killed $eax
+; AVX512-NEXT: retq
%1 = icmp slt <16 x i8> %a0, zeroinitializer
%2 = bitcast <16 x i1> %1 to <2 x i8>
%3 = extractelement <2 x i8> %2, i32 0
@@ -242,10 +252,9 @@ define i8 @bitcast_v16i16_to_v2i8(<16 x i16> %a0) nounwind {
; AVX512-LABEL: bitcast_v16i16_to_v2i8:
; AVX512: # %bb.0:
; AVX512-NEXT: vpmovw2m %ymm0, %k0
-; AVX512-NEXT: kmovw %k0, -{{[0-9]+}}(%rsp)
-; AVX512-NEXT: vmovdqa -{{[0-9]+}}(%rsp), %xmm0
-; AVX512-NEXT: vmovd %xmm0, %ecx
-; AVX512-NEXT: vpextrb $1, %xmm0, %eax
+; AVX512-NEXT: kshiftrw $8, %k0, %k1
+; AVX512-NEXT: kmovd %k0, %ecx
+; AVX512-NEXT: kmovd %k1, %eax
; AVX512-NEXT: addb %cl, %al
; AVX512-NEXT: # kill: def $al killed $al killed $eax
; AVX512-NEXT: vzeroupper
@@ -289,9 +298,10 @@ define i16 @bitcast_v32i8_to_v2i16(<32 x i8> %a0) nounwind {
;
; AVX512-LABEL: bitcast_v32i8_to_v2i16:
; AVX512: # %bb.0:
-; AVX512-NEXT: vpmovmskb %ymm0, %ecx
-; AVX512-NEXT: movl %ecx, %eax
-; AVX512-NEXT: shrl $16, %eax
+; AVX512-NEXT: vpmovb2m %ymm0, %k0
+; AVX512-NEXT: kshiftrd $16, %k0, %k1
+; AVX512-NEXT: kmovd %k0, %ecx
+; AVX512-NEXT: kmovd %k1, %eax
; AVX512-NEXT: addl %ecx, %eax
; AVX512-NEXT: # kill: def $ax killed $ax killed $eax
; AVX512-NEXT: vzeroupper
@@ -424,10 +434,9 @@ define i8 @bitcast_v16i32_to_v2i8(<16 x i32> %a0) nounwind {
; AVX512: # %bb.0:
; AVX512-NEXT: vpxor %xmm1, %xmm1, %xmm1
; AVX512-NEXT: vpcmpgtd %zmm0, %zmm1, %k0
-; AVX512-NEXT: kmovw %k0, -{{[0-9]+}}(%rsp)
-; AVX512-NEXT: vmovdqa -{{[0-9]+}}(%rsp), %xmm0
-; AVX512-NEXT: vmovd %xmm0, %ecx
-; AVX512-NEXT: vpextrb $1, %xmm0, %eax
+; AVX512-NEXT: kshiftrw $8, %k0, %k1
+; AVX512-NEXT: kmovd %k0, %ecx
+; AVX512-NEXT: kmovd %k1, %eax
; AVX512-NEXT: addb %cl, %al
; AVX512-NEXT: # kill: def $al killed $al killed $eax
; AVX512-NEXT: vzeroupper
@@ -479,10 +488,9 @@ define i16 @bitcast_v32i16_to_v2i16(<32 x i16> %a0) nounwind {
; AVX512-LABEL: bitcast_v32i16_to_v2i16:
; AVX512: # %bb.0:
; AVX512-NEXT: vpmovw2m %zmm0, %k0
-; AVX512-NEXT: kmovd %k0, -{{[0-9]+}}(%rsp)
-; AVX512-NEXT: vmovdqa -{{[0-9]+}}(%rsp), %xmm0
-; AVX512-NEXT: vmovd %xmm0, %ecx
-; AVX512-NEXT: vpextrw $1, %xmm0, %eax
+; AVX512-NEXT: kshiftrd $16, %k0, %k1
+; AVX512-NEXT: kmovd %k0, %ecx
+; AVX512-NEXT: kmovd %k1, %eax
; AVX512-NEXT: addl %ecx, %eax
; AVX512-NEXT: # kill: def $ax killed $ax killed $eax
; AVX512-NEXT: vzeroupper
@@ -541,9 +549,10 @@ define i32 @bitcast_v64i8_to_v2i32(<64 x i8> %a0) nounwind {
; AVX512-LABEL: bitcast_v64i8_to_v2i32:
; AVX512: # %bb.0:
; AVX512-NEXT: vpmovb2m %zmm0, %k0
-; AVX512-NEXT: kmovq %k0, -{{[0-9]+}}(%rsp)
-; AVX512-NEXT: movl -{{[0-9]+}}(%rsp), %eax
-; AVX512-NEXT: addl -{{[0-9]+}}(%rsp), %eax
+; AVX512-NEXT: kshiftrq $32, %k0, %k1
+; AVX512-NEXT: kmovd %k0, %ecx
+; AVX512-NEXT: kmovd %k1, %eax
+; AVX512-NEXT: addl %ecx, %eax
; AVX512-NEXT: vzeroupper
; AVX512-NEXT: retq
%1 = icmp slt <64 x i8> %a0, zeroinitializer
@@ -698,10 +707,9 @@ define [2 x i8] @PR58546(<16 x float> %a0) {
; AVX512: # %bb.0:
; AVX512-NEXT: vxorps %xmm1, %xmm1, %xmm1
; AVX512-NEXT: vcmpunordps %zmm1, %zmm0, %k0
-; AVX512-NEXT: kmovw %k0, -{{[0-9]+}}(%rsp)
-; AVX512-NEXT: vmovdqa -{{[0-9]+}}(%rsp), %xmm0
-; AVX512-NEXT: vmovd %xmm0, %eax
-; AVX512-NEXT: vpextrb $1, %xmm0, %edx
+; AVX512-NEXT: kshiftrw $8, %k0, %k1
+; AVX512-NEXT: kmovd %k0, %eax
+; AVX512-NEXT: kmovd %k1, %edx
; AVX512-NEXT: # kill: def $al killed $al killed $eax
; AVX512-NEXT: # kill: def $dl killed $dl killed $edx
; AVX512-NEXT: vzeroupper
More information about the llvm-commits
mailing list