[llvm] b29ec28 - [X86] MatchVectorAllZeroTest - fix bug when splitting vectors of large elements
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 6 02:55:47 PDT 2023
Author: Simon Pilgrim
Date: 2023-04-06T10:55:38+01:00
New Revision: b29ec28fd67fb72b007d9d763d4c054e5fd9cd63
URL: https://github.com/llvm/llvm-project/commit/b29ec28fd67fb72b007d9d763d4c054e5fd9cd63
DIFF: https://github.com/llvm/llvm-project/commit/b29ec28fd67fb72b007d9d763d4c054e5fd9cd63.diff
LOG: [X86] MatchVectorAllZeroTest - fix bug when splitting vectors of large elements
DAG::SplitVector only works on vectors with an even number of elements, so when splitting vectors with large (illegal) element widths we are likely to split down to an unsplittable <1 x iXXX> type.
In such cases, pre-bitcast to an <X x i64> type to ensure splitting will always succeed.
Thanks to @alexfh for identifying this.
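For context, a minimal standalone sketch of the pre-bitcast trick (the helper name splitForAllEqualTest is hypothetical and added only for illustration; the actual change lives inside LowerVectorAllEqual as shown in the diff below, and only applies when the element width exceeds the target test size and the comparison mask is all-ones):

#include "llvm/CodeGen/SelectionDAG.h"
using namespace llvm;

// For example, <4 x i256> would split to <2 x i256> and then to
// <1 x i256>, which SplitVector cannot split again. Bitcasting to
// i64 elements first (<4 x i256> -> <16 x i64>) avoids ever reaching
// an unsplittable single-element type.
static std::pair<SDValue, SDValue>
splitForAllEqualTest(SelectionDAG &DAG, const SDLoc &DL, SDValue V) {
  EVT VT = V.getValueType();
  if (VT.getScalarSizeInBits() > 64) {
    EVT CastVT =
        EVT::getVectorVT(*DAG.getContext(), MVT::i64, VT.getSizeInBits() / 64);
    V = DAG.getBitcast(CastVT, V);
  }
  return DAG.SplitVector(V, DL);
}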
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/setcc-wide-types.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 62cbc8bc128f9..eb3afd4cc1518 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -24338,12 +24338,12 @@ static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp,
// Helper function for comparing all bits of two vectors.
static SDValue LowerVectorAllEqual(const SDLoc &DL, SDValue LHS, SDValue RHS,
- ISD::CondCode CC, const APInt &Mask,
+ ISD::CondCode CC, const APInt &OriginalMask,
const X86Subtarget &Subtarget,
SelectionDAG &DAG, X86::CondCode &X86CC) {
EVT VT = LHS.getValueType();
unsigned ScalarSize = VT.getScalarSizeInBits();
- if (Mask.getBitWidth() != ScalarSize) {
+ if (OriginalMask.getBitWidth() != ScalarSize) {
assert(ScalarSize == 1 && "Element Mask vs Vector bitwidth mismatch");
return SDValue();
}
@@ -24355,6 +24355,8 @@ static SDValue LowerVectorAllEqual(const SDLoc &DL, SDValue LHS, SDValue RHS,
assert((CC == ISD::SETEQ || CC == ISD::SETNE) && "Unsupported ISD::CondCode");
X86CC = (CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE);
+ APInt Mask = OriginalMask;
+
auto MaskBits = [&](SDValue Src) {
if (Mask.isAllOnes())
return Src;
@@ -24395,6 +24397,18 @@ static SDValue LowerVectorAllEqual(const SDLoc &DL, SDValue LHS, SDValue RHS,
// Split down to 128/256/512-bit vector.
unsigned TestSize = UseKORTEST ? 512 : (Subtarget.hasAVX() ? 256 : 128);
+
+ // If the input vector has vector elements wider than the target test size,
+ // then cast to <X x i64> so it will safely split.
+ if (ScalarSize > TestSize) {
+ if (!Mask.isAllOnes())
+ return SDValue();
+ VT = EVT::getVectorVT(*DAG.getContext(), MVT::i64, VT.getSizeInBits() / 64);
+ LHS = DAG.getBitcast(VT, LHS);
+ RHS = DAG.getBitcast(VT, RHS);
+ Mask = APInt::getAllOnes(64);
+ }
+
if (VT.getSizeInBits() > TestSize) {
KnownBits KnownRHS = DAG.computeKnownBits(RHS);
if (KnownRHS.isConstant() && KnownRHS.getConstant() == Mask) {
diff --git a/llvm/test/CodeGen/X86/setcc-wide-types.ll b/llvm/test/CodeGen/X86/setcc-wide-types.ll
index 7d65da0319ee7..61254d5e5c2f4 100644
--- a/llvm/test/CodeGen/X86/setcc-wide-types.ll
+++ b/llvm/test/CodeGen/X86/setcc-wide-types.ll
@@ -596,6 +596,178 @@ define i32 @eq_i512(<8 x i64> %x, <8 x i64> %y) {
ret i32 %zext
}
+define i1 @ne_v4i256(<4 x i256> %a0) {
+; SSE2-LABEL: ne_v4i256:
+; SSE2: # %bb.0:
+; SSE2-NEXT: movq {{[0-9]+}}(%rsp), %rax
+; SSE2-NEXT: movq {{[0-9]+}}(%rsp), %r10
+; SSE2-NEXT: orq {{[0-9]+}}(%rsp), %r10
+; SSE2-NEXT: movq %r10, %xmm0
+; SSE2-NEXT: orq {{[0-9]+}}(%rsp), %rax
+; SSE2-NEXT: movq %rax, %xmm1
+; SSE2-NEXT: punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm0[0]
+; SSE2-NEXT: orq {{[0-9]+}}(%rsp), %rcx
+; SSE2-NEXT: movq %rcx, %xmm0
+; SSE2-NEXT: orq {{[0-9]+}}(%rsp), %rdx
+; SSE2-NEXT: movq %rdx, %xmm2
+; SSE2-NEXT: punpcklqdq {{.*#+}} xmm2 = xmm2[0],xmm0[0]
+; SSE2-NEXT: por %xmm1, %xmm2
+; SSE2-NEXT: orq {{[0-9]+}}(%rsp), %r9
+; SSE2-NEXT: movq %r9, %xmm0
+; SSE2-NEXT: orq {{[0-9]+}}(%rsp), %r8
+; SSE2-NEXT: movq %r8, %xmm1
+; SSE2-NEXT: punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm0[0]
+; SSE2-NEXT: orq {{[0-9]+}}(%rsp), %rsi
+; SSE2-NEXT: movq %rsi, %xmm0
+; SSE2-NEXT: orq {{[0-9]+}}(%rsp), %rdi
+; SSE2-NEXT: movq %rdi, %xmm3
+; SSE2-NEXT: punpcklqdq {{.*#+}} xmm3 = xmm3[0],xmm0[0]
+; SSE2-NEXT: por %xmm1, %xmm3
+; SSE2-NEXT: por %xmm2, %xmm3
+; SSE2-NEXT: pxor %xmm0, %xmm0
+; SSE2-NEXT: pcmpeqd %xmm3, %xmm0
+; SSE2-NEXT: movmskps %xmm0, %eax
+; SSE2-NEXT: xorl $15, %eax
+; SSE2-NEXT: sete %al
+; SSE2-NEXT: retq
+;
+; SSE41-LABEL: ne_v4i256:
+; SSE41: # %bb.0:
+; SSE41-NEXT: movq {{[0-9]+}}(%rsp), %rax
+; SSE41-NEXT: movq {{[0-9]+}}(%rsp), %r10
+; SSE41-NEXT: orq {{[0-9]+}}(%rsp), %r10
+; SSE41-NEXT: movq %r10, %xmm0
+; SSE41-NEXT: orq {{[0-9]+}}(%rsp), %rax
+; SSE41-NEXT: movq %rax, %xmm1
+; SSE41-NEXT: punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm0[0]
+; SSE41-NEXT: orq {{[0-9]+}}(%rsp), %rcx
+; SSE41-NEXT: movq %rcx, %xmm0
+; SSE41-NEXT: orq {{[0-9]+}}(%rsp), %rdx
+; SSE41-NEXT: movq %rdx, %xmm2
+; SSE41-NEXT: punpcklqdq {{.*#+}} xmm2 = xmm2[0],xmm0[0]
+; SSE41-NEXT: por %xmm1, %xmm2
+; SSE41-NEXT: orq {{[0-9]+}}(%rsp), %r9
+; SSE41-NEXT: movq %r9, %xmm0
+; SSE41-NEXT: orq {{[0-9]+}}(%rsp), %r8
+; SSE41-NEXT: movq %r8, %xmm1
+; SSE41-NEXT: punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm0[0]
+; SSE41-NEXT: orq {{[0-9]+}}(%rsp), %rsi
+; SSE41-NEXT: movq %rsi, %xmm0
+; SSE41-NEXT: orq {{[0-9]+}}(%rsp), %rdi
+; SSE41-NEXT: movq %rdi, %xmm3
+; SSE41-NEXT: punpcklqdq {{.*#+}} xmm3 = xmm3[0],xmm0[0]
+; SSE41-NEXT: por %xmm1, %xmm3
+; SSE41-NEXT: por %xmm2, %xmm3
+; SSE41-NEXT: ptest %xmm3, %xmm3
+; SSE41-NEXT: sete %al
+; SSE41-NEXT: retq
+;
+; AVX1-LABEL: ne_v4i256:
+; AVX1: # %bb.0:
+; AVX1-NEXT: movq {{[0-9]+}}(%rsp), %rax
+; AVX1-NEXT: movq {{[0-9]+}}(%rsp), %r10
+; AVX1-NEXT: orq {{[0-9]+}}(%rsp), %r10
+; AVX1-NEXT: orq {{[0-9]+}}(%rsp), %rcx
+; AVX1-NEXT: orq %r10, %rcx
+; AVX1-NEXT: vmovq %rcx, %xmm0
+; AVX1-NEXT: orq {{[0-9]+}}(%rsp), %rax
+; AVX1-NEXT: orq {{[0-9]+}}(%rsp), %rdx
+; AVX1-NEXT: orq %rax, %rdx
+; AVX1-NEXT: vmovq %rdx, %xmm1
+; AVX1-NEXT: vpunpcklqdq {{.*#+}} xmm0 = xmm1[0],xmm0[0]
+; AVX1-NEXT: orq {{[0-9]+}}(%rsp), %r9
+; AVX1-NEXT: orq {{[0-9]+}}(%rsp), %rsi
+; AVX1-NEXT: orq %r9, %rsi
+; AVX1-NEXT: vmovq %rsi, %xmm1
+; AVX1-NEXT: orq {{[0-9]+}}(%rsp), %r8
+; AVX1-NEXT: orq {{[0-9]+}}(%rsp), %rdi
+; AVX1-NEXT: orq %r8, %rdi
+; AVX1-NEXT: vmovq %rdi, %xmm2
+; AVX1-NEXT: vpunpcklqdq {{.*#+}} xmm1 = xmm2[0],xmm1[0]
+; AVX1-NEXT: vinsertf128 $1, %xmm0, %ymm1, %ymm0
+; AVX1-NEXT: vptest %ymm0, %ymm0
+; AVX1-NEXT: sete %al
+; AVX1-NEXT: vzeroupper
+; AVX1-NEXT: retq
+;
+; AVX2-LABEL: ne_v4i256:
+; AVX2: # %bb.0:
+; AVX2-NEXT: movq {{[0-9]+}}(%rsp), %rax
+; AVX2-NEXT: movq {{[0-9]+}}(%rsp), %r10
+; AVX2-NEXT: orq {{[0-9]+}}(%rsp), %r10
+; AVX2-NEXT: orq {{[0-9]+}}(%rsp), %rcx
+; AVX2-NEXT: orq %r10, %rcx
+; AVX2-NEXT: vmovq %rcx, %xmm0
+; AVX2-NEXT: orq {{[0-9]+}}(%rsp), %rax
+; AVX2-NEXT: orq {{[0-9]+}}(%rsp), %rdx
+; AVX2-NEXT: orq %rax, %rdx
+; AVX2-NEXT: vmovq %rdx, %xmm1
+; AVX2-NEXT: vpunpcklqdq {{.*#+}} xmm0 = xmm1[0],xmm0[0]
+; AVX2-NEXT: orq {{[0-9]+}}(%rsp), %r9
+; AVX2-NEXT: orq {{[0-9]+}}(%rsp), %rsi
+; AVX2-NEXT: orq %r9, %rsi
+; AVX2-NEXT: vmovq %rsi, %xmm1
+; AVX2-NEXT: orq {{[0-9]+}}(%rsp), %r8
+; AVX2-NEXT: orq {{[0-9]+}}(%rsp), %rdi
+; AVX2-NEXT: orq %r8, %rdi
+; AVX2-NEXT: vmovq %rdi, %xmm2
+; AVX2-NEXT: vpunpcklqdq {{.*#+}} xmm1 = xmm2[0],xmm1[0]
+; AVX2-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm0
+; AVX2-NEXT: vptest %ymm0, %ymm0
+; AVX2-NEXT: sete %al
+; AVX2-NEXT: vzeroupper
+; AVX2-NEXT: retq
+;
+; AVX512-LABEL: ne_v4i256:
+; AVX512: # %bb.0:
+; AVX512-NEXT: movq {{[0-9]+}}(%rsp), %rax
+; AVX512-NEXT: movq {{[0-9]+}}(%rsp), %r10
+; AVX512-NEXT: orq {{[0-9]+}}(%rsp), %rax
+; AVX512-NEXT: vmovd %eax, %xmm0
+; AVX512-NEXT: shrq $32, %rax
+; AVX512-NEXT: vpinsrd $1, %eax, %xmm0, %xmm0
+; AVX512-NEXT: orq {{[0-9]+}}(%rsp), %r10
+; AVX512-NEXT: vpinsrd $2, %r10d, %xmm0, %xmm0
+; AVX512-NEXT: shrq $32, %r10
+; AVX512-NEXT: vpinsrd $3, %r10d, %xmm0, %xmm0
+; AVX512-NEXT: orq {{[0-9]+}}(%rsp), %r8
+; AVX512-NEXT: vmovd %r8d, %xmm1
+; AVX512-NEXT: shrq $32, %r8
+; AVX512-NEXT: vpinsrd $1, %r8d, %xmm1, %xmm1
+; AVX512-NEXT: orq {{[0-9]+}}(%rsp), %r9
+; AVX512-NEXT: vpinsrd $2, %r9d, %xmm1, %xmm1
+; AVX512-NEXT: shrq $32, %r9
+; AVX512-NEXT: vpinsrd $3, %r9d, %xmm1, %xmm1
+; AVX512-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm0
+; AVX512-NEXT: orq {{[0-9]+}}(%rsp), %rdx
+; AVX512-NEXT: vmovd %edx, %xmm1
+; AVX512-NEXT: shrq $32, %rdx
+; AVX512-NEXT: vpinsrd $1, %edx, %xmm1, %xmm1
+; AVX512-NEXT: orq {{[0-9]+}}(%rsp), %rcx
+; AVX512-NEXT: vpinsrd $2, %ecx, %xmm1, %xmm1
+; AVX512-NEXT: shrq $32, %rcx
+; AVX512-NEXT: vpinsrd $3, %ecx, %xmm1, %xmm1
+; AVX512-NEXT: orq {{[0-9]+}}(%rsp), %rdi
+; AVX512-NEXT: vmovd %edi, %xmm2
+; AVX512-NEXT: shrq $32, %rdi
+; AVX512-NEXT: vpinsrd $1, %edi, %xmm2, %xmm2
+; AVX512-NEXT: orq {{[0-9]+}}(%rsp), %rsi
+; AVX512-NEXT: vpinsrd $2, %esi, %xmm2, %xmm2
+; AVX512-NEXT: shrq $32, %rsi
+; AVX512-NEXT: vpinsrd $3, %esi, %xmm2, %xmm2
+; AVX512-NEXT: vinserti128 $1, %xmm1, %ymm2, %ymm1
+; AVX512-NEXT: vinserti64x4 $1, %ymm0, %zmm1, %zmm0
+; AVX512-NEXT: vptestmd %zmm0, %zmm0, %k0
+; AVX512-NEXT: kortestw %k0, %k0
+; AVX512-NEXT: sete %al
+; AVX512-NEXT: vzeroupper
+; AVX512-NEXT: retq
+ %c = icmp ne <4 x i256> %a0, zeroinitializer
+ %b = bitcast <4 x i1> %c to i4
+ %r = icmp eq i4 %b, 0
+ ret i1 %r
+}
+
; This test models the expansion of 'memcmp(a, b, 32) != 0'
; if we allowed 2 pairs of 16-byte loads per block.