[llvm] d891968 - [X86] combineMOVMSK - merge movmsk(icmp_eq(and(x,c1),c1)) and movmsk(icmp_eq(and(x,c1),0)) folds

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 3 09:18:11 PDT 2023


Author: Simon Pilgrim
Date: 2023-04-03T17:17:52+01:00
New Revision: d891968f10c42568b2c5d19fa22802ee85d2bd7d

URL: https://github.com/llvm/llvm-project/commit/d891968f10c42568b2c5d19fa22802ee85d2bd7d
DIFF: https://github.com/llvm/llvm-project/commit/d891968f10c42568b2c5d19fa22802ee85d2bd7d.diff

LOG: [X86] combineMOVMSK - merge movmsk(icmp_eq(and(x,c1),c1)) and movmsk(icmp_eq(and(x,c1),0)) folds

Use the same value tracking implementation for both, removing the hardcoded PCMPEQ(AND(X,C),C) pattern so as to handle bitcasted logic/constants.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/bitcast-vector-bool.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 13a59ab7fb0e..1d8327f35004 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -54538,52 +54538,35 @@ static SDValue combineMOVMSK(SDNode *N, SelectionDAG &DAG,
   }
 
   // Fold movmsk(icmp_eq(and(x,c1),c1)) -> movmsk(shl(x,c2))
-  // iff pow2splat(c1).
-  // Use KnownBits to determine if only a single bit is non-zero
-  // in each element (pow2 or zero), and shift that bit to the msb.
-  // TODO: Merge with the movmsk(icmp_eq(and(x,c1),0)) fold below?
-  if (Src.getOpcode() == X86ISD::PCMPEQ &&
-      Src.getOperand(0).getOpcode() == ISD::AND &&
-      Src.getOperand(1) == Src.getOperand(0).getOperand(1)) {
-    KnownBits KnownSrc = DAG.computeKnownBits(Src.getOperand(1));
-    if (KnownSrc.countMaxPopulation() == 1) {
-      SDLoc DL(N);
-      MVT ShiftVT = SrcVT;
-      SDValue ShiftSrc = Src.getOperand(0);
-      if (ShiftVT.getScalarType() == MVT::i8) {
-        // vXi8 shifts - we only care about the signbit so can use PSLLW.
-        ShiftVT = MVT::getVectorVT(MVT::i16, NumElts / 2);
-        ShiftSrc = DAG.getBitcast(ShiftVT, ShiftSrc);
-      }
-      unsigned ShiftAmt = KnownSrc.countMinLeadingZeros();
-      ShiftSrc = getTargetVShiftByConstNode(X86ISD::VSHLI, DL, ShiftVT,
-                                            ShiftSrc, ShiftAmt, DAG);
-      ShiftSrc = DAG.getBitcast(SrcVT, ShiftSrc);
-      return DAG.getNode(X86ISD::MOVMSK, DL, VT, ShiftSrc);
-    }
-  }
-
   // Fold movmsk(icmp_eq(and(x,c1),0)) -> movmsk(not(shl(x,c2)))
   // iff pow2splat(c1).
   // Use KnownBits to determine if only a single bit is non-zero
   // in each element (pow2 or zero), and shift that bit to the msb.
-  if (Src.getOpcode() == X86ISD::PCMPEQ &&
-      ISD::isBuildVectorAllZeros(Src.getOperand(1).getNode())) {
-    KnownBits KnownSrc = DAG.computeKnownBits(Src.getOperand(0));
-    if (KnownSrc.countMaxPopulation() == 1) {
+  if (Src.getOpcode() == X86ISD::PCMPEQ) {
+    KnownBits KnownLHS = DAG.computeKnownBits(Src.getOperand(0));
+    KnownBits KnownRHS = DAG.computeKnownBits(Src.getOperand(1));
+    unsigned ShiftAmt = KnownLHS.countMinLeadingZeros();
+    if (KnownLHS.countMaxPopulation() == 1 &&
+        (KnownRHS.isZero() || (KnownRHS.countMaxPopulation() == 1 &&
+                               ShiftAmt == KnownRHS.countMinLeadingZeros()))) {
       SDLoc DL(N);
       MVT ShiftVT = SrcVT;
-      SDValue ShiftSrc = Src.getOperand(0);
+      SDValue ShiftLHS = Src.getOperand(0);
+      SDValue ShiftRHS = Src.getOperand(1);
       if (ShiftVT.getScalarType() == MVT::i8) {
         // vXi8 shifts - we only care about the signbit so can use PSLLW.
         ShiftVT = MVT::getVectorVT(MVT::i16, NumElts / 2);
-        ShiftSrc = DAG.getBitcast(ShiftVT, ShiftSrc);
-      }
-      unsigned ShiftAmt = KnownSrc.countMinLeadingZeros();
-      ShiftSrc = getTargetVShiftByConstNode(X86ISD::VSHLI, DL, ShiftVT,
-                                            ShiftSrc, ShiftAmt, DAG);
-      ShiftSrc = DAG.getNOT(DL, DAG.getBitcast(SrcVT, ShiftSrc), SrcVT);
-      return DAG.getNode(X86ISD::MOVMSK, DL, VT, ShiftSrc);
+        ShiftLHS = DAG.getBitcast(ShiftVT, ShiftLHS);
+        ShiftRHS = DAG.getBitcast(ShiftVT, ShiftRHS);
+      }
+      ShiftLHS = getTargetVShiftByConstNode(X86ISD::VSHLI, DL, ShiftVT,
+                                            ShiftLHS, ShiftAmt, DAG);
+      ShiftRHS = getTargetVShiftByConstNode(X86ISD::VSHLI, DL, ShiftVT,
+                                            ShiftRHS, ShiftAmt, DAG);
+      ShiftLHS = DAG.getBitcast(SrcVT, ShiftLHS);
+      ShiftRHS = DAG.getBitcast(SrcVT, ShiftRHS);
+      SDValue Res = DAG.getNode(ISD::XOR, DL, SrcVT, ShiftLHS, ShiftRHS);
+      return DAG.getNode(X86ISD::MOVMSK, DL, VT, DAG.getNOT(DL, Res, SrcVT));
     }
   }
 

diff  --git a/llvm/test/CodeGen/X86/bitcast-vector-bool.ll b/llvm/test/CodeGen/X86/bitcast-vector-bool.ll
index 7477044c86a7..049754119702 100644
--- a/llvm/test/CodeGen/X86/bitcast-vector-bool.ll
+++ b/llvm/test/CodeGen/X86/bitcast-vector-bool.ll
@@ -927,11 +927,10 @@ define i1 @trunc_v32i16_cmp(<32 x i16> %a0) nounwind {
 ; SSE2-SSSE3-NEXT:    pand %xmm3, %xmm1
 ; SSE2-SSSE3-NEXT:    pand %xmm2, %xmm0
 ; SSE2-SSSE3-NEXT:    pand %xmm1, %xmm0
-; SSE2-SSSE3-NEXT:    movdqa {{.*#+}} xmm1 = [1,1,1,1,1,1,1,1]
-; SSE2-SSSE3-NEXT:    pand %xmm1, %xmm0
-; SSE2-SSSE3-NEXT:    pcmpeqb %xmm1, %xmm0
+; SSE2-SSSE3-NEXT:    psllw $7, %xmm0
 ; SSE2-SSSE3-NEXT:    pmovmskb %xmm0, %eax
-; SSE2-SSSE3-NEXT:    xorl $65535, %eax # imm = 0xFFFF
+; SSE2-SSSE3-NEXT:    notl %eax
+; SSE2-SSSE3-NEXT:    testl $21845, %eax # imm = 0x5555
 ; SSE2-SSSE3-NEXT:    setne %al
 ; SSE2-SSSE3-NEXT:    retq
 ;


        


More information about the llvm-commits mailing list