[llvm] 2b7b8bd - [X86] Accept the canonical form of a sign bit test in MatchVectorAllEqualTest. (#154421)

via llvm-commits <llvm-commits at lists.llvm.org>
Wed Aug 20 09:09:58 PDT 2025


Author: Craig Topper
Date: 2025-08-20T09:09:55-07:00
New Revision: 2b7b8bdc165a1f9fa933fe531d6a5b152d066297

URL: https://github.com/llvm/llvm-project/commit/2b7b8bdc165a1f9fa933fe531d6a5b152d066297
DIFF: https://github.com/llvm/llvm-project/commit/2b7b8bdc165a1f9fa933fe531d6a5b152d066297.diff

LOG: [X86] Accept the canonical form of a sign bit test in MatchVectorAllEqualTest. (#154421)

This function looks for (seteq (and (reduce_or), mask), 0). If the
mask is the sign bit, InstCombine will have turned it into
(setgt (reduce_or), -1). We should handle that case too.

I'm looking into adding the same canonicalization to SimplifySetCC,
and this change is needed to prevent test regressions.
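
As a minimal sketch, the two equivalent IR forms look like this
(value names are hypothetical; the canonical form matches the
signtest_v8i32 test added below):

    ; Masked sign-bit test this function already matched:
    %r = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %v)
    %m = and i32 %r, -2147483648   ; mask covers only the sign bit
    %c = icmp eq i32 %m, 0

    ; Canonical form after InstCombine:
    %r = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %v)
    %c = icmp sgt i32 %r, -1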

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/vector-reduce-or-cmp.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 2c726a9f7f6c9..19131fbd4102b 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -23185,43 +23185,51 @@ static SDValue LowerVectorAllEqual(const SDLoc &DL, SDValue LHS, SDValue RHS,
 
 // Check whether an AND/OR'd reduction tree is PTEST-able, or if we can fallback
 // to CMP(MOVMSK(PCMPEQB(X,Y))).
-static SDValue MatchVectorAllEqualTest(SDValue LHS, SDValue RHS,
+static SDValue MatchVectorAllEqualTest(SDValue OrigLHS, SDValue OrigRHS,
                                        ISD::CondCode CC, const SDLoc &DL,
                                        const X86Subtarget &Subtarget,
                                        SelectionDAG &DAG,
                                        X86::CondCode &X86CC) {
-  assert((CC == ISD::SETEQ || CC == ISD::SETNE) && "Unsupported ISD::CondCode");
+  SDValue Op = OrigLHS;
 
-  bool CmpNull = isNullConstant(RHS);
-  bool CmpAllOnes = isAllOnesConstant(RHS);
-  if (!CmpNull && !CmpAllOnes)
-    return SDValue();
+  bool CmpNull;
+  APInt Mask;
+  if (CC == ISD::SETEQ || CC == ISD::SETNE) {
+    CmpNull = isNullConstant(OrigRHS);
+    if (!CmpNull && !isAllOnesConstant(OrigRHS))
+      return SDValue();
 
-  SDValue Op = LHS;
-  if (!Subtarget.hasSSE2() || !Op->hasOneUse())
-    return SDValue();
+    if (!Subtarget.hasSSE2() || !Op->hasOneUse())
+      return SDValue();
 
-  // Check whether we're masking/truncating an OR-reduction result, in which
-  // case track the masked bits.
-  // TODO: Add CmpAllOnes support.
-  APInt Mask = APInt::getAllOnes(Op.getScalarValueSizeInBits());
-  if (CmpNull) {
-    switch (Op.getOpcode()) {
-    case ISD::TRUNCATE: {
-      SDValue Src = Op.getOperand(0);
-      Mask = APInt::getLowBitsSet(Src.getScalarValueSizeInBits(),
-                                  Op.getScalarValueSizeInBits());
-      Op = Src;
-      break;
-    }
-    case ISD::AND: {
-      if (auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
-        Mask = Cst->getAPIntValue();
-        Op = Op.getOperand(0);
+    // Check whether we're masking/truncating an OR-reduction result, in which
+    // case track the masked bits.
+    // TODO: Add CmpAllOnes support.
+    Mask = APInt::getAllOnes(Op.getScalarValueSizeInBits());
+    if (CmpNull) {
+      switch (Op.getOpcode()) {
+      case ISD::TRUNCATE: {
+        SDValue Src = Op.getOperand(0);
+        Mask = APInt::getLowBitsSet(Src.getScalarValueSizeInBits(),
+                                    Op.getScalarValueSizeInBits());
+        Op = Src;
+        break;
+      }
+      case ISD::AND: {
+        if (auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
+          Mask = Cst->getAPIntValue();
+          Op = Op.getOperand(0);
+        }
+        break;
+      }
       }
-      break;
-    }
     }
+  } else if (CC == ISD::SETGT && isAllOnesConstant(OrigRHS)) {
+    CC = ISD::SETEQ;
+    CmpNull = true;
+    Mask = APInt::getSignMask(Op.getScalarValueSizeInBits());
+  } else {
+    return SDValue();
   }
 
   ISD::NodeType LogicOp = CmpNull ? ISD::OR : ISD::AND;
@@ -56274,14 +56282,16 @@ static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,
     if (SDValue V = combineVectorSizedSetCCEquality(VT, LHS, RHS, CC, DL, DAG,
                                                     Subtarget))
       return V;
+  }
 
-    if (VT == MVT::i1) {
-      X86::CondCode X86CC;
-      if (SDValue V =
-              MatchVectorAllEqualTest(LHS, RHS, CC, DL, Subtarget, DAG, X86CC))
-        return DAG.getNode(ISD::TRUNCATE, DL, VT, getSETCC(X86CC, V, DL, DAG));
-    }
+  if (VT == MVT::i1) {
+    X86::CondCode X86CC;
+    if (SDValue V =
+            MatchVectorAllEqualTest(LHS, RHS, CC, DL, Subtarget, DAG, X86CC))
+      return DAG.getNode(ISD::TRUNCATE, DL, VT, getSETCC(X86CC, V, DL, DAG));
+  }
 
+  if (CC == ISD::SETNE || CC == ISD::SETEQ) {
     if (OpVT.isScalarInteger()) {
       // cmpeq(or(X,Y),X) --> cmpeq(and(~X,Y),0)
       // cmpne(or(X,Y),X) --> cmpne(and(~X,Y),0)

diff --git a/llvm/test/CodeGen/X86/vector-reduce-or-cmp.ll b/llvm/test/CodeGen/X86/vector-reduce-or-cmp.ll
index 9cd0f4d12e15a..227e000c6be7f 100644
--- a/llvm/test/CodeGen/X86/vector-reduce-or-cmp.ll
+++ b/llvm/test/CodeGen/X86/vector-reduce-or-cmp.ll
@@ -903,6 +903,95 @@ define i1 @mask_v8i32(<8 x i32> %a0) {
   ret i1 %3
 }
 
+define i1 @mask_v8i32_2(<8 x i32> %a0) {
+; SSE2-LABEL: mask_v8i32_2:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    por %xmm1, %xmm0
+; SSE2-NEXT:    pslld $1, %xmm0
+; SSE2-NEXT:    movmskps %xmm0, %eax
+; SSE2-NEXT:    testl %eax, %eax
+; SSE2-NEXT:    sete %al
+; SSE2-NEXT:    retq
+;
+; SSE41-LABEL: mask_v8i32_2:
+; SSE41:       # %bb.0:
+; SSE41-NEXT:    por %xmm1, %xmm0
+; SSE41-NEXT:    ptest {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; SSE41-NEXT:    sete %al
+; SSE41-NEXT:    retq
+;
+; AVX1-LABEL: mask_v8i32_2:
+; AVX1:       # %bb.0:
+; AVX1-NEXT:    vptest {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0
+; AVX1-NEXT:    sete %al
+; AVX1-NEXT:    vzeroupper
+; AVX1-NEXT:    retq
+;
+; AVX2-LABEL: mask_v8i32_2:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpbroadcastq {{.*#+}} ymm1 = [4611686019501129728,4611686019501129728,4611686019501129728,4611686019501129728]
+; AVX2-NEXT:    vptest %ymm1, %ymm0
+; AVX2-NEXT:    sete %al
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+;
+; AVX512-LABEL: mask_v8i32_2:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpbroadcastq {{.*#+}} ymm1 = [4611686019501129728,4611686019501129728,4611686019501129728,4611686019501129728]
+; AVX512-NEXT:    vptest %ymm1, %ymm0
+; AVX512-NEXT:    sete %al
+; AVX512-NEXT:    vzeroupper
+; AVX512-NEXT:    retq
+  %1 = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %a0)
+  %2 = and i32 %1, 1073741824
+  %3 = icmp eq i32 %2, 0
+  ret i1 %3
+}
+
+
+define i1 @signtest_v8i32(<8 x i32> %a0) {
+; SSE2-LABEL: signtest_v8i32:
+; SSE2:       # %bb.0:
+; SSE2-NEXT:    orps %xmm1, %xmm0
+; SSE2-NEXT:    movmskps %xmm0, %eax
+; SSE2-NEXT:    testl %eax, %eax
+; SSE2-NEXT:    sete %al
+; SSE2-NEXT:    retq
+;
+; SSE41-LABEL: signtest_v8i32:
+; SSE41:       # %bb.0:
+; SSE41-NEXT:    por %xmm1, %xmm0
+; SSE41-NEXT:    ptest {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; SSE41-NEXT:    sete %al
+; SSE41-NEXT:    retq
+;
+; AVX1-LABEL: signtest_v8i32:
+; AVX1:       # %bb.0:
+; AVX1-NEXT:    vptest {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0
+; AVX1-NEXT:    sete %al
+; AVX1-NEXT:    vzeroupper
+; AVX1-NEXT:    retq
+;
+; AVX2-LABEL: signtest_v8i32:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpbroadcastq {{.*#+}} ymm1 = [9223372039002259456,9223372039002259456,9223372039002259456,9223372039002259456]
+; AVX2-NEXT:    vptest %ymm1, %ymm0
+; AVX2-NEXT:    sete %al
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+;
+; AVX512-LABEL: signtest_v8i32:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpbroadcastq {{.*#+}} ymm1 = [9223372039002259456,9223372039002259456,9223372039002259456,9223372039002259456]
+; AVX512-NEXT:    vptest %ymm1, %ymm0
+; AVX512-NEXT:    sete %al
+; AVX512-NEXT:    vzeroupper
+; AVX512-NEXT:    retq
+  %1 = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %a0)
+  %2 = icmp sgt i32 %1, -1
+  ret i1 %2
+}
+
 define i1 @trunc_v16i16(<16 x i16> %a0) {
 ; SSE2-LABEL: trunc_v16i16:
 ; SSE2:       # %bb.0:
@@ -1073,11 +1162,11 @@ define i32 @mask_v3i1(<3 x i32> %a, <3 x i32> %b) {
 ; SSE2-NEXT:    movd %xmm0, %eax
 ; SSE2-NEXT:    orl %ecx, %eax
 ; SSE2-NEXT:    testb $1, %al
-; SSE2-NEXT:    je .LBB27_2
+; SSE2-NEXT:    je .LBB29_2
 ; SSE2-NEXT:  # %bb.1:
 ; SSE2-NEXT:    xorl %eax, %eax
 ; SSE2-NEXT:    retq
-; SSE2-NEXT:  .LBB27_2:
+; SSE2-NEXT:  .LBB29_2:
 ; SSE2-NEXT:    movl $1, %eax
 ; SSE2-NEXT:    retq
 ;
@@ -1092,11 +1181,11 @@ define i32 @mask_v3i1(<3 x i32> %a, <3 x i32> %b) {
 ; SSE41-NEXT:    pextrd $2, %xmm1, %eax
 ; SSE41-NEXT:    orl %ecx, %eax
 ; SSE41-NEXT:    testb $1, %al
-; SSE41-NEXT:    je .LBB27_2
+; SSE41-NEXT:    je .LBB29_2
 ; SSE41-NEXT:  # %bb.1:
 ; SSE41-NEXT:    xorl %eax, %eax
 ; SSE41-NEXT:    retq
-; SSE41-NEXT:  .LBB27_2:
+; SSE41-NEXT:  .LBB29_2:
 ; SSE41-NEXT:    movl $1, %eax
 ; SSE41-NEXT:    retq
 ;
@@ -1111,11 +1200,11 @@ define i32 @mask_v3i1(<3 x i32> %a, <3 x i32> %b) {
 ; AVX1OR2-NEXT:    vpextrd $2, %xmm0, %eax
 ; AVX1OR2-NEXT:    orl %ecx, %eax
 ; AVX1OR2-NEXT:    testb $1, %al
-; AVX1OR2-NEXT:    je .LBB27_2
+; AVX1OR2-NEXT:    je .LBB29_2
 ; AVX1OR2-NEXT:  # %bb.1:
 ; AVX1OR2-NEXT:    xorl %eax, %eax
 ; AVX1OR2-NEXT:    retq
-; AVX1OR2-NEXT:  .LBB27_2:
+; AVX1OR2-NEXT:  .LBB29_2:
 ; AVX1OR2-NEXT:    movl $1, %eax
 ; AVX1OR2-NEXT:    retq
 ;
@@ -1130,12 +1219,12 @@ define i32 @mask_v3i1(<3 x i32> %a, <3 x i32> %b) {
 ; AVX512F-NEXT:    korw %k0, %k1, %k0
 ; AVX512F-NEXT:    kmovw %k0, %eax
 ; AVX512F-NEXT:    testb $1, %al
-; AVX512F-NEXT:    je .LBB27_2
+; AVX512F-NEXT:    je .LBB29_2
 ; AVX512F-NEXT:  # %bb.1:
 ; AVX512F-NEXT:    xorl %eax, %eax
 ; AVX512F-NEXT:    vzeroupper
 ; AVX512F-NEXT:    retq
-; AVX512F-NEXT:  .LBB27_2:
+; AVX512F-NEXT:  .LBB29_2:
 ; AVX512F-NEXT:    movl $1, %eax
 ; AVX512F-NEXT:    vzeroupper
 ; AVX512F-NEXT:    retq
@@ -1151,12 +1240,12 @@ define i32 @mask_v3i1(<3 x i32> %a, <3 x i32> %b) {
 ; AVX512BW-NEXT:    korw %k0, %k1, %k0
 ; AVX512BW-NEXT:    kmovd %k0, %eax
 ; AVX512BW-NEXT:    testb $1, %al
-; AVX512BW-NEXT:    je .LBB27_2
+; AVX512BW-NEXT:    je .LBB29_2
 ; AVX512BW-NEXT:  # %bb.1:
 ; AVX512BW-NEXT:    xorl %eax, %eax
 ; AVX512BW-NEXT:    vzeroupper
 ; AVX512BW-NEXT:    retq
-; AVX512BW-NEXT:  .LBB27_2:
+; AVX512BW-NEXT:  .LBB29_2:
 ; AVX512BW-NEXT:    movl $1, %eax
 ; AVX512BW-NEXT:    vzeroupper
 ; AVX512BW-NEXT:    retq
@@ -1170,11 +1259,11 @@ define i32 @mask_v3i1(<3 x i32> %a, <3 x i32> %b) {
 ; AVX512BWVL-NEXT:    korw %k0, %k1, %k0
 ; AVX512BWVL-NEXT:    kmovd %k0, %eax
 ; AVX512BWVL-NEXT:    testb $1, %al
-; AVX512BWVL-NEXT:    je .LBB27_2
+; AVX512BWVL-NEXT:    je .LBB29_2
 ; AVX512BWVL-NEXT:  # %bb.1:
 ; AVX512BWVL-NEXT:    xorl %eax, %eax
 ; AVX512BWVL-NEXT:    retq
-; AVX512BWVL-NEXT:  .LBB27_2:
+; AVX512BWVL-NEXT:  .LBB29_2:
 ; AVX512BWVL-NEXT:    movl $1, %eax
 ; AVX512BWVL-NEXT:    retq
   %1 = icmp ne <3 x i32> %a, %b


        

