[llvm] [X86][ISel] Improve logic for optimizing `movmsk(bitcast(shuffle(x)))` (PR #68369)
    via llvm-commits 
    llvm-commits at lists.llvm.org
       
    Fri Oct  6 09:47:12 PDT 2023
    
    
  
https://github.com/goldsteinn updated https://github.com/llvm/llvm-project/pull/68369
>From bc637320c1f5d5859a00a1d30a406d560d79b2d4 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Thu, 5 Oct 2023 19:54:57 -0500
Subject: [PATCH 1/2] [X86] Add tests for incorrectly optimizing out shuffle
 used in `movmsk`; PR67287
---
 llvm/test/CodeGen/X86/movmsk-cmp.ll | 137 ++++++++++++++++++++++++++++
 1 file changed, 137 insertions(+)
diff --git a/llvm/test/CodeGen/X86/movmsk-cmp.ll b/llvm/test/CodeGen/X86/movmsk-cmp.ll
index a0901e265f5ae97..a2af976869dc7c6 100644
--- a/llvm/test/CodeGen/X86/movmsk-cmp.ll
+++ b/llvm/test/CodeGen/X86/movmsk-cmp.ll
@@ -4430,3 +4430,140 @@ define i32 @PR39665_c_ray_opt(<2 x double> %x, <2 x double> %y) {
   %r = select i1 %u, i32 42, i32 99
   ret i32 %r
 }
+
+define i32 @pr67287(<2 x i64> %broadcast.splatinsert25) {
+; SSE2-LABEL: pr67287:
+; SSE2:       # %bb.0: # %entry
+; SSE2-NEXT:    movl $3, %eax
+; SSE2-NEXT:    testl %eax, %eax
+; SSE2-NEXT:    jne .LBB97_2
+; SSE2-NEXT:  # %bb.1: # %entry
+; SSE2-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; SSE2-NEXT:    pxor %xmm1, %xmm1
+; SSE2-NEXT:    pcmpeqd %xmm0, %xmm1
+; SSE2-NEXT:    movd %xmm1, %eax
+; SSE2-NEXT:    testb $1, %al
+; SSE2-NEXT:    jne .LBB97_2
+; SSE2-NEXT:  # %bb.3: # %middle.block
+; SSE2-NEXT:    xorl %eax, %eax
+; SSE2-NEXT:    retq
+; SSE2-NEXT:  .LBB97_2:
+; SSE2-NEXT:    movw $0, 0
+; SSE2-NEXT:    xorl %eax, %eax
+; SSE2-NEXT:    retq
+;
+; SSE41-LABEL: pr67287:
+; SSE41:       # %bb.0: # %entry
+; SSE41-NEXT:    pxor %xmm1, %xmm1
+; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7]
+; SSE41-NEXT:    pcmpeqq %xmm1, %xmm0
+; SSE41-NEXT:    movmskpd %xmm0, %eax
+; SSE41-NEXT:    testl %eax, %eax
+; SSE41-NEXT:    jne .LBB97_2
+; SSE41-NEXT:  # %bb.1: # %entry
+; SSE41-NEXT:    movd %xmm0, %eax
+; SSE41-NEXT:    testb $1, %al
+; SSE41-NEXT:    jne .LBB97_2
+; SSE41-NEXT:  # %bb.3: # %middle.block
+; SSE41-NEXT:    xorl %eax, %eax
+; SSE41-NEXT:    retq
+; SSE41-NEXT:  .LBB97_2:
+; SSE41-NEXT:    movw $0, 0
+; SSE41-NEXT:    xorl %eax, %eax
+; SSE41-NEXT:    retq
+;
+; AVX1-LABEL: pr67287:
+; AVX1:       # %bb.0: # %entry
+; AVX1-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX1-NEXT:    vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5],xmm1[6,7]
+; AVX1-NEXT:    vpcmpeqq %xmm1, %xmm0, %xmm0
+; AVX1-NEXT:    vtestpd %xmm0, %xmm0
+; AVX1-NEXT:    jne .LBB97_2
+; AVX1-NEXT:  # %bb.1: # %entry
+; AVX1-NEXT:    vmovd %xmm0, %eax
+; AVX1-NEXT:    testb $1, %al
+; AVX1-NEXT:    jne .LBB97_2
+; AVX1-NEXT:  # %bb.3: # %middle.block
+; AVX1-NEXT:    xorl %eax, %eax
+; AVX1-NEXT:    retq
+; AVX1-NEXT:  .LBB97_2:
+; AVX1-NEXT:    movw $0, 0
+; AVX1-NEXT:    xorl %eax, %eax
+; AVX1-NEXT:    retq
+;
+; AVX2-LABEL: pr67287:
+; AVX2:       # %bb.0: # %entry
+; AVX2-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX2-NEXT:    vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
+; AVX2-NEXT:    vpcmpeqq %xmm1, %xmm0, %xmm0
+; AVX2-NEXT:    vtestpd %xmm0, %xmm0
+; AVX2-NEXT:    jne .LBB97_2
+; AVX2-NEXT:  # %bb.1: # %entry
+; AVX2-NEXT:    vmovd %xmm0, %eax
+; AVX2-NEXT:    testb $1, %al
+; AVX2-NEXT:    jne .LBB97_2
+; AVX2-NEXT:  # %bb.3: # %middle.block
+; AVX2-NEXT:    xorl %eax, %eax
+; AVX2-NEXT:    retq
+; AVX2-NEXT:  .LBB97_2:
+; AVX2-NEXT:    movw $0, 0
+; AVX2-NEXT:    xorl %eax, %eax
+; AVX2-NEXT:    retq
+;
+; KNL-LABEL: pr67287:
+; KNL:       # %bb.0: # %entry
+; KNL-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; KNL-NEXT:    vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
+; KNL-NEXT:    vptestnmq %zmm0, %zmm0, %k0
+; KNL-NEXT:    kmovw %k0, %eax
+; KNL-NEXT:    testb $3, %al
+; KNL-NEXT:    jne .LBB97_2
+; KNL-NEXT:  # %bb.1: # %entry
+; KNL-NEXT:    kmovw %k0, %eax
+; KNL-NEXT:    testb $1, %al
+; KNL-NEXT:    jne .LBB97_2
+; KNL-NEXT:  # %bb.3: # %middle.block
+; KNL-NEXT:    xorl %eax, %eax
+; KNL-NEXT:    vzeroupper
+; KNL-NEXT:    retq
+; KNL-NEXT:  .LBB97_2:
+; KNL-NEXT:    movw $0, 0
+; KNL-NEXT:    xorl %eax, %eax
+; KNL-NEXT:    vzeroupper
+; KNL-NEXT:    retq
+;
+; SKX-LABEL: pr67287:
+; SKX:       # %bb.0: # %entry
+; SKX-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; SKX-NEXT:    vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
+; SKX-NEXT:    vptestnmq %xmm0, %xmm0, %k0
+; SKX-NEXT:    kortestb %k0, %k0
+; SKX-NEXT:    jne .LBB97_2
+; SKX-NEXT:  # %bb.1: # %entry
+; SKX-NEXT:    kmovd %k0, %eax
+; SKX-NEXT:    testb $1, %al
+; SKX-NEXT:    jne .LBB97_2
+; SKX-NEXT:  # %bb.3: # %middle.block
+; SKX-NEXT:    xorl %eax, %eax
+; SKX-NEXT:    retq
+; SKX-NEXT:  .LBB97_2:
+; SKX-NEXT:    movw $0, 0
+; SKX-NEXT:    xorl %eax, %eax
+; SKX-NEXT:    retq
+entry:
+  %0 = and <2 x i64> %broadcast.splatinsert25, <i64 4294967295, i64 4294967295>
+  %1 = icmp eq <2 x i64> %0, zeroinitializer
+  %shift = shufflevector <2 x i1> %1, <2 x i1> zeroinitializer, <2 x i32> <i32 1, i32 poison>
+  %2 = or <2 x i1> %1, %shift
+  %3 = extractelement <2 x i1> %2, i64 0
+  %4 = extractelement <2 x i1> %1, i64 0
+  %5 = or i1 %3, %4
+  br i1 %5, label %6, label %middle.block
+
+6:                                                ; preds = %entry
+  store i16 0, ptr null, align 2
+  br label %middle.block
+
+middle.block:                                     ; preds = %6, %entry
+  ret i32 0
+}
>From 9de12e02bf0c976e53404c3e808d72feb56e1723 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Thu, 5 Oct 2023 19:28:52 -0500
Subject: [PATCH 2/2] [X86] Fix/improve logic for optimizing
 `movmsk(bitcast(shuffle(x)))`; PR67287
Prior logic would remove the shuffle iff all of the elements in `x`
were used. This is both incorrect and suboptimal.
The issue is that `movmsk` only cares about the high bits, so if the width
of the elements in `x` is smaller than the width of the elements
for the `movmsk`, then the shuffle, even if it preserves all the elements,
may change which of them supply the high bits.
For example:
`movmsk64(bitcast(shuffle32(x, (1,0,3,2))))`
Even though the shuffle mask `(1,0,3,2)` preserves all the elements, it
flips which ones are relevant to the `movmsk64` (x[1] and x[3]
before, x[0] and x[2] after).
The fix is, instead of checking whether all the elements are
preserved, to check that all of the important "high" elements in the
shuffle are moved to other "high" positions.
This widens the range of the optimization, as a shuffle mask like
`(1,1,3,3)` is now optimizable (as it should be), and it also fixes
the bug with masks like `(1,0,3,2)`.
---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 48 +++++++++++++++++++++----
 llvm/test/CodeGen/X86/movmsk-cmp.ll     |  9 ++---
 2 files changed, 46 insertions(+), 11 deletions(-)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c4cd2a672fe7b26..aba0311b6c83c9f 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -45836,18 +45836,52 @@ static SDValue combineSetCCMOVMSK(SDValue EFLAGS, X86::CondCode &CC,
   // MOVMSK(SHUFFLE(X,u)) -> MOVMSK(X) iff every element is referenced.
   SmallVector<int, 32> ShuffleMask;
   SmallVector<SDValue, 2> ShuffleInputs;
+  SDValue BaseVec = peekThroughBitcasts(Vec);
   if (NumElts <= CmpBits &&
-      getTargetShuffleInputs(peekThroughBitcasts(Vec), ShuffleInputs,
-                             ShuffleMask, DAG) &&
+      getTargetShuffleInputs(BaseVec, ShuffleInputs, ShuffleMask, DAG) &&
       ShuffleInputs.size() == 1 && !isAnyZeroOrUndef(ShuffleMask) &&
       ShuffleInputs[0].getValueSizeInBits() == VecVT.getSizeInBits()) {
     unsigned NumShuffleElts = ShuffleMask.size();
-    APInt DemandedElts = APInt::getZero(NumShuffleElts);
-    for (int M : ShuffleMask) {
-      assert(0 <= M && M < (int)NumShuffleElts && "Bad unary shuffle index");
-      DemandedElts.setBit(M);
+
+    APInt Result = APInt::getZero(NumShuffleElts);
+    APInt ImportantLocs;
+    // Since we peek through a bitcast, we need to be careful if the base vector
+    // type has smaller elements than the MOVMSK type.  In that case, even if
+    // all the elements are demanded by the shuffle mask, only the "high"
+    // elements which have highbits that align with highbits in the MOVMSK vec
+    // elements are actually demanded. A simplification of spurious operations
+    // on the "low" elements takes place during other simplifications.
+    //
+    // For example:
+    // MOVMSK64(BITCAST(SHUF32 X, (1,0,3,2))): even though all the elements
+    // are demanded, the swap changes which sign bits MOVMSK64 reads.
+    //
+    // To address this, we need to make sure all the "high" elements are moved
+    // to other "high" locations.
+    MVT BaseVT = BaseVec.getSimpleValueType();
+    unsigned BaseNumElts = BaseVT.getVectorNumElements();
+    if (BaseNumElts > NumElts) {
+      ImportantLocs = APInt::getZero(NumShuffleElts);
+      assert((BaseNumElts % NumElts) == 0 &&
+             "Vec with unsupported element size");
+      unsigned Scale = BaseNumElts / NumElts;
+      for (unsigned i = 0; i < BaseNumElts; ++i) {
+        if ((i % Scale) == (Scale - 1))
+          ImportantLocs.setBit(i);
+      }
+    } else {
+      ImportantLocs = APInt::getAllOnes(NumShuffleElts);
+    }
+
+    for (unsigned ShufDst = 0; ShufDst < ShuffleMask.size(); ++ShufDst) {
+      int ShufSrc = ShuffleMask[ShufDst];
+      assert(0 <= ShufSrc && ShufSrc < (int)NumShuffleElts &&
+             "Bad unary shuffle index");
+      if (ImportantLocs[ShufSrc] && ImportantLocs[ShufDst])
+        Result.setBit(ShufSrc);
     }
-    if (DemandedElts.isAllOnes()) {
+
+    if (Result == ImportantLocs) {
       SDLoc DL(EFLAGS);
       SDValue Result = DAG.getBitcast(VecVT, ShuffleInputs[0]);
       Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result);
diff --git a/llvm/test/CodeGen/X86/movmsk-cmp.ll b/llvm/test/CodeGen/X86/movmsk-cmp.ll
index a2af976869dc7c6..f26bbb7e5c2bdac 100644
--- a/llvm/test/CodeGen/X86/movmsk-cmp.ll
+++ b/llvm/test/CodeGen/X86/movmsk-cmp.ll
@@ -4434,13 +4434,14 @@ define i32 @PR39665_c_ray_opt(<2 x double> %x, <2 x double> %y) {
 define i32 @pr67287(<2 x i64> %broadcast.splatinsert25) {
 ; SSE2-LABEL: pr67287:
 ; SSE2:       # %bb.0: # %entry
-; SSE2-NEXT:    movl $3, %eax
-; SSE2-NEXT:    testl %eax, %eax
-; SSE2-NEXT:    jne .LBB97_2
-; SSE2-NEXT:  # %bb.1: # %entry
 ; SSE2-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
 ; SSE2-NEXT:    pxor %xmm1, %xmm1
 ; SSE2-NEXT:    pcmpeqd %xmm0, %xmm1
+; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[1,0,3,2]
+; SSE2-NEXT:    movmskpd %xmm0, %eax
+; SSE2-NEXT:    testl %eax, %eax
+; SSE2-NEXT:    jne .LBB97_2
+; SSE2-NEXT:  # %bb.1: # %entry
 ; SSE2-NEXT:    movd %xmm1, %eax
 ; SSE2-NEXT:    testb $1, %al
 ; SSE2-NEXT:    jne .LBB97_2
    
    
More information about the llvm-commits
mailing list