[llvm] b29ec28 - [X86] MatchVectorAllZeroTest - fix bug when splitting vectors of large elements

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 6 02:55:47 PDT 2023


Author: Simon Pilgrim
Date: 2023-04-06T10:55:38+01:00
New Revision: b29ec28fd67fb72b007d9d763d4c054e5fd9cd63

URL: https://github.com/llvm/llvm-project/commit/b29ec28fd67fb72b007d9d763d4c054e5fd9cd63
DIFF: https://github.com/llvm/llvm-project/commit/b29ec28fd67fb72b007d9d763d4c054e5fd9cd63.diff

LOG: [X86] MatchVectorAllZeroTest - fix bug when splitting vectors of large elements

DAG::SplitVector only works on vectors with an even number of elements. When splitting vectors with large (illegal) element widths, we are likely to split down to <1 x iXXX>, which cannot be split any further.

In such cases, pre-bitcast to a <X x i64> type so that splitting always succeeds.

Thanks to @alexfh for identifying this.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/setcc-wide-types.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 62cbc8bc128f9..eb3afd4cc1518 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -24338,12 +24338,12 @@ static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp,
 
 // Helper function for comparing all bits of two vectors.
 static SDValue LowerVectorAllEqual(const SDLoc &DL, SDValue LHS, SDValue RHS,
-                                   ISD::CondCode CC, const APInt &Mask,
+                                   ISD::CondCode CC, const APInt &OriginalMask,
                                    const X86Subtarget &Subtarget,
                                    SelectionDAG &DAG, X86::CondCode &X86CC) {
   EVT VT = LHS.getValueType();
   unsigned ScalarSize = VT.getScalarSizeInBits();
-  if (Mask.getBitWidth() != ScalarSize) {
+  if (OriginalMask.getBitWidth() != ScalarSize) {
     assert(ScalarSize == 1 && "Element Mask vs Vector bitwidth mismatch");
     return SDValue();
   }
@@ -24355,6 +24355,8 @@ static SDValue LowerVectorAllEqual(const SDLoc &DL, SDValue LHS, SDValue RHS,
   assert((CC == ISD::SETEQ || CC == ISD::SETNE) && "Unsupported ISD::CondCode");
   X86CC = (CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE);
 
+  APInt Mask = OriginalMask;
+
   auto MaskBits = [&](SDValue Src) {
     if (Mask.isAllOnes())
       return Src;
@@ -24395,6 +24397,18 @@ static SDValue LowerVectorAllEqual(const SDLoc &DL, SDValue LHS, SDValue RHS,
 
   // Split down to 128/256/512-bit vector.
   unsigned TestSize = UseKORTEST ? 512 : (Subtarget.hasAVX() ? 256 : 128);
+
+  // If the input vector has vector elements wider than the target test size,
+  // then cast to <X x i64> so it will safely split.
+  if (ScalarSize > TestSize) {
+    if (!Mask.isAllOnes())
+      return SDValue();
+    VT = EVT::getVectorVT(*DAG.getContext(), MVT::i64, VT.getSizeInBits() / 64);
+    LHS = DAG.getBitcast(VT, LHS);
+    RHS = DAG.getBitcast(VT, RHS);
+    Mask = APInt::getAllOnes(64);
+  }
+
   if (VT.getSizeInBits() > TestSize) {
     KnownBits KnownRHS = DAG.computeKnownBits(RHS);
     if (KnownRHS.isConstant() && KnownRHS.getConstant() == Mask) {

diff  --git a/llvm/test/CodeGen/X86/setcc-wide-types.ll b/llvm/test/CodeGen/X86/setcc-wide-types.ll
index 7d65da0319ee7..61254d5e5c2f4 100644
--- a/llvm/test/CodeGen/X86/setcc-wide-types.ll
+++ b/llvm/test/CodeGen/X86/setcc-wide-types.ll
@@ -596,6 +596,178 @@ define i32 @eq_i512(<8 x i64> %x, <8 x i64> %y) {
   ret i32 %zext
 }
 
+define i1 @ne_v4i256(<4 x i256> %a0) {
+; SSE2-LABEL: ne_v4i256:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    movq {{[0-9]+}}(%rsp), %rax
+; SSE2-NEXT:    movq {{[0-9]+}}(%rsp), %r10
+; SSE2-NEXT:    orq {{[0-9]+}}(%rsp), %r10
+; SSE2-NEXT:    movq %r10, %xmm0
+; SSE2-NEXT:    orq {{[0-9]+}}(%rsp), %rax
+; SSE2-NEXT:    movq %rax, %xmm1
+; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm0[0]
+; SSE2-NEXT:    orq {{[0-9]+}}(%rsp), %rcx
+; SSE2-NEXT:    movq %rcx, %xmm0
+; SSE2-NEXT:    orq {{[0-9]+}}(%rsp), %rdx
+; SSE2-NEXT:    movq %rdx, %xmm2
+; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm2 = xmm2[0],xmm0[0]
+; SSE2-NEXT:    por %xmm1, %xmm2
+; SSE2-NEXT:    orq {{[0-9]+}}(%rsp), %r9
+; SSE2-NEXT:    movq %r9, %xmm0
+; SSE2-NEXT:    orq {{[0-9]+}}(%rsp), %r8
+; SSE2-NEXT:    movq %r8, %xmm1
+; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm0[0]
+; SSE2-NEXT:    orq {{[0-9]+}}(%rsp), %rsi
+; SSE2-NEXT:    movq %rsi, %xmm0
+; SSE2-NEXT:    orq {{[0-9]+}}(%rsp), %rdi
+; SSE2-NEXT:    movq %rdi, %xmm3
+; SSE2-NEXT:    punpcklqdq {{.*#+}} xmm3 = xmm3[0],xmm0[0]
+; SSE2-NEXT:    por %xmm1, %xmm3
+; SSE2-NEXT:    por %xmm2, %xmm3
+; SSE2-NEXT:    pxor %xmm0, %xmm0
+; SSE2-NEXT:    pcmpeqd %xmm3, %xmm0
+; SSE2-NEXT:    movmskps %xmm0, %eax
+; SSE2-NEXT:    xorl $15, %eax
+; SSE2-NEXT:    sete %al
+; SSE2-NEXT:    retq
+;
+; SSE41-LABEL: ne_v4i256:
+; SSE41:       # %bb.0:
+; SSE41-NEXT:    movq {{[0-9]+}}(%rsp), %rax
+; SSE41-NEXT:    movq {{[0-9]+}}(%rsp), %r10
+; SSE41-NEXT:    orq {{[0-9]+}}(%rsp), %r10
+; SSE41-NEXT:    movq %r10, %xmm0
+; SSE41-NEXT:    orq {{[0-9]+}}(%rsp), %rax
+; SSE41-NEXT:    movq %rax, %xmm1
+; SSE41-NEXT:    punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm0[0]
+; SSE41-NEXT:    orq {{[0-9]+}}(%rsp), %rcx
+; SSE41-NEXT:    movq %rcx, %xmm0
+; SSE41-NEXT:    orq {{[0-9]+}}(%rsp), %rdx
+; SSE41-NEXT:    movq %rdx, %xmm2
+; SSE41-NEXT:    punpcklqdq {{.*#+}} xmm2 = xmm2[0],xmm0[0]
+; SSE41-NEXT:    por %xmm1, %xmm2
+; SSE41-NEXT:    orq {{[0-9]+}}(%rsp), %r9
+; SSE41-NEXT:    movq %r9, %xmm0
+; SSE41-NEXT:    orq {{[0-9]+}}(%rsp), %r8
+; SSE41-NEXT:    movq %r8, %xmm1
+; SSE41-NEXT:    punpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm0[0]
+; SSE41-NEXT:    orq {{[0-9]+}}(%rsp), %rsi
+; SSE41-NEXT:    movq %rsi, %xmm0
+; SSE41-NEXT:    orq {{[0-9]+}}(%rsp), %rdi
+; SSE41-NEXT:    movq %rdi, %xmm3
+; SSE41-NEXT:    punpcklqdq {{.*#+}} xmm3 = xmm3[0],xmm0[0]
+; SSE41-NEXT:    por %xmm1, %xmm3
+; SSE41-NEXT:    por %xmm2, %xmm3
+; SSE41-NEXT:    ptest %xmm3, %xmm3
+; SSE41-NEXT:    sete %al
+; SSE41-NEXT:    retq
+;
+; AVX1-LABEL: ne_v4i256:
+; AVX1:       # %bb.0:
+; AVX1-NEXT:    movq {{[0-9]+}}(%rsp), %rax
+; AVX1-NEXT:    movq {{[0-9]+}}(%rsp), %r10
+; AVX1-NEXT:    orq {{[0-9]+}}(%rsp), %r10
+; AVX1-NEXT:    orq {{[0-9]+}}(%rsp), %rcx
+; AVX1-NEXT:    orq %r10, %rcx
+; AVX1-NEXT:    vmovq %rcx, %xmm0
+; AVX1-NEXT:    orq {{[0-9]+}}(%rsp), %rax
+; AVX1-NEXT:    orq {{[0-9]+}}(%rsp), %rdx
+; AVX1-NEXT:    orq %rax, %rdx
+; AVX1-NEXT:    vmovq %rdx, %xmm1
+; AVX1-NEXT:    vpunpcklqdq {{.*#+}} xmm0 = xmm1[0],xmm0[0]
+; AVX1-NEXT:    orq {{[0-9]+}}(%rsp), %r9
+; AVX1-NEXT:    orq {{[0-9]+}}(%rsp), %rsi
+; AVX1-NEXT:    orq %r9, %rsi
+; AVX1-NEXT:    vmovq %rsi, %xmm1
+; AVX1-NEXT:    orq {{[0-9]+}}(%rsp), %r8
+; AVX1-NEXT:    orq {{[0-9]+}}(%rsp), %rdi
+; AVX1-NEXT:    orq %r8, %rdi
+; AVX1-NEXT:    vmovq %rdi, %xmm2
+; AVX1-NEXT:    vpunpcklqdq {{.*#+}} xmm1 = xmm2[0],xmm1[0]
+; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm1, %ymm0
+; AVX1-NEXT:    vptest %ymm0, %ymm0
+; AVX1-NEXT:    sete %al
+; AVX1-NEXT:    vzeroupper
+; AVX1-NEXT:    retq
+;
+; AVX2-LABEL: ne_v4i256:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    movq {{[0-9]+}}(%rsp), %rax
+; AVX2-NEXT:    movq {{[0-9]+}}(%rsp), %r10
+; AVX2-NEXT:    orq {{[0-9]+}}(%rsp), %r10
+; AVX2-NEXT:    orq {{[0-9]+}}(%rsp), %rcx
+; AVX2-NEXT:    orq %r10, %rcx
+; AVX2-NEXT:    vmovq %rcx, %xmm0
+; AVX2-NEXT:    orq {{[0-9]+}}(%rsp), %rax
+; AVX2-NEXT:    orq {{[0-9]+}}(%rsp), %rdx
+; AVX2-NEXT:    orq %rax, %rdx
+; AVX2-NEXT:    vmovq %rdx, %xmm1
+; AVX2-NEXT:    vpunpcklqdq {{.*#+}} xmm0 = xmm1[0],xmm0[0]
+; AVX2-NEXT:    orq {{[0-9]+}}(%rsp), %r9
+; AVX2-NEXT:    orq {{[0-9]+}}(%rsp), %rsi
+; AVX2-NEXT:    orq %r9, %rsi
+; AVX2-NEXT:    vmovq %rsi, %xmm1
+; AVX2-NEXT:    orq {{[0-9]+}}(%rsp), %r8
+; AVX2-NEXT:    orq {{[0-9]+}}(%rsp), %rdi
+; AVX2-NEXT:    orq %r8, %rdi
+; AVX2-NEXT:    vmovq %rdi, %xmm2
+; AVX2-NEXT:    vpunpcklqdq {{.*#+}} xmm1 = xmm2[0],xmm1[0]
+; AVX2-NEXT:    vinserti128 $1, %xmm0, %ymm1, %ymm0
+; AVX2-NEXT:    vptest %ymm0, %ymm0
+; AVX2-NEXT:    sete %al
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+;
+; AVX512-LABEL: ne_v4i256:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    movq {{[0-9]+}}(%rsp), %rax
+; AVX512-NEXT:    movq {{[0-9]+}}(%rsp), %r10
+; AVX512-NEXT:    orq {{[0-9]+}}(%rsp), %rax
+; AVX512-NEXT:    vmovd %eax, %xmm0
+; AVX512-NEXT:    shrq $32, %rax
+; AVX512-NEXT:    vpinsrd $1, %eax, %xmm0, %xmm0
+; AVX512-NEXT:    orq {{[0-9]+}}(%rsp), %r10
+; AVX512-NEXT:    vpinsrd $2, %r10d, %xmm0, %xmm0
+; AVX512-NEXT:    shrq $32, %r10
+; AVX512-NEXT:    vpinsrd $3, %r10d, %xmm0, %xmm0
+; AVX512-NEXT:    orq {{[0-9]+}}(%rsp), %r8
+; AVX512-NEXT:    vmovd %r8d, %xmm1
+; AVX512-NEXT:    shrq $32, %r8
+; AVX512-NEXT:    vpinsrd $1, %r8d, %xmm1, %xmm1
+; AVX512-NEXT:    orq {{[0-9]+}}(%rsp), %r9
+; AVX512-NEXT:    vpinsrd $2, %r9d, %xmm1, %xmm1
+; AVX512-NEXT:    shrq $32, %r9
+; AVX512-NEXT:    vpinsrd $3, %r9d, %xmm1, %xmm1
+; AVX512-NEXT:    vinserti128 $1, %xmm0, %ymm1, %ymm0
+; AVX512-NEXT:    orq {{[0-9]+}}(%rsp), %rdx
+; AVX512-NEXT:    vmovd %edx, %xmm1
+; AVX512-NEXT:    shrq $32, %rdx
+; AVX512-NEXT:    vpinsrd $1, %edx, %xmm1, %xmm1
+; AVX512-NEXT:    orq {{[0-9]+}}(%rsp), %rcx
+; AVX512-NEXT:    vpinsrd $2, %ecx, %xmm1, %xmm1
+; AVX512-NEXT:    shrq $32, %rcx
+; AVX512-NEXT:    vpinsrd $3, %ecx, %xmm1, %xmm1
+; AVX512-NEXT:    orq {{[0-9]+}}(%rsp), %rdi
+; AVX512-NEXT:    vmovd %edi, %xmm2
+; AVX512-NEXT:    shrq $32, %rdi
+; AVX512-NEXT:    vpinsrd $1, %edi, %xmm2, %xmm2
+; AVX512-NEXT:    orq {{[0-9]+}}(%rsp), %rsi
+; AVX512-NEXT:    vpinsrd $2, %esi, %xmm2, %xmm2
+; AVX512-NEXT:    shrq $32, %rsi
+; AVX512-NEXT:    vpinsrd $3, %esi, %xmm2, %xmm2
+; AVX512-NEXT:    vinserti128 $1, %xmm1, %ymm2, %ymm1
+; AVX512-NEXT:    vinserti64x4 $1, %ymm0, %zmm1, %zmm0
+; AVX512-NEXT:    vptestmd %zmm0, %zmm0, %k0
+; AVX512-NEXT:    kortestw %k0, %k0
+; AVX512-NEXT:    sete %al
+; AVX512-NEXT:    vzeroupper
+; AVX512-NEXT:    retq
+  %c = icmp ne <4 x i256> %a0, zeroinitializer
+  %b = bitcast <4 x i1> %c to i4
+  %r = icmp eq i4 %b, 0
+  ret i1 %r
+}
+
 ; This test models the expansion of 'memcmp(a, b, 32) != 0'
 ; if we allowed 2 pairs of 16-byte loads per block.
 


        


More information about the llvm-commits mailing list