[llvm] [X86] Handle shifts + and in `LowerSELECTWithCmpZero` (PR #107910)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 9 13:06:32 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-x86

Author: None (goldsteinn)

<details>
<summary>Changes</summary>

- **[X86] Add tests for shifts + and in `LowerSELECTWithCmpZero`; NFC**
- **[X86] Handle shifts + and in `LowerSELECTWithCmpZero`**

Shifts are handled the same way as `sub`: a right-hand side of 0 is the identity.
`and` is the inverted case, where:
    `SELECT (AND(X,1) == 0), (AND Y, Z), Y`
        -> `(AND Y, (OR NEG(AND(X, 1)), Z))`
with -1 (all ones) as the identity.

---
Full diff: https://github.com/llvm/llvm-project/pull/107910.diff


2 Files Affected:

- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+45-15) 
- (modified) llvm/test/CodeGen/X86/pull-conditional-binop-through-shift.ll (+182) 


``````````diff
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 99477d76f50bce..47d77acecb387c 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -24086,36 +24086,38 @@ static SDValue LowerSELECTWithCmpZero(SDValue CmpVal, SDValue LHS, SDValue RHS,
 
   if (X86CC == X86::COND_E && CmpVal.getOpcode() == ISD::AND &&
       isOneConstant(CmpVal.getOperand(1))) {
-    auto SplatLSB = [&]() {
+    auto SplatLSB = [&](EVT SplatVT) {
       // we need mask of all zeros or ones with same size of the other
       // operands.
       SDValue Neg = CmpVal;
-      if (CmpVT.bitsGT(VT))
-        Neg = DAG.getNode(ISD::TRUNCATE, DL, VT, CmpVal);
-      else if (CmpVT.bitsLT(VT))
+      if (CmpVT.bitsGT(SplatVT))
+        Neg = DAG.getNode(ISD::TRUNCATE, DL, SplatVT, CmpVal);
+      else if (CmpVT.bitsLT(SplatVT))
         Neg = DAG.getNode(
-            ISD::AND, DL, VT,
-            DAG.getNode(ISD::ANY_EXTEND, DL, VT, CmpVal.getOperand(0)),
-            DAG.getConstant(1, DL, VT));
-      return DAG.getNegative(Neg, DL, VT); // -(and (x, 0x1))
+            ISD::AND, DL, SplatVT,
+            DAG.getNode(ISD::ANY_EXTEND, DL, SplatVT, CmpVal.getOperand(0)),
+            DAG.getConstant(1, DL, SplatVT));
+      return DAG.getNegative(Neg, DL, SplatVT); // -(and (x, 0x1))
     };
 
     // SELECT (AND(X,1) == 0), 0, -1 -> NEG(AND(X,1))
     if (isNullConstant(LHS) && isAllOnesConstant(RHS))
-      return SplatLSB();
+      return SplatLSB(VT);
 
     // SELECT (AND(X,1) == 0), C1, C2 -> XOR(C1,AND(NEG(AND(X,1)),XOR(C1,C2))
     if (!Subtarget.canUseCMOV() && isa<ConstantSDNode>(LHS) &&
         isa<ConstantSDNode>(RHS)) {
-      SDValue Mask = SplatLSB();
+      SDValue Mask = SplatLSB(VT);
       SDValue Diff = DAG.getNode(ISD::XOR, DL, VT, LHS, RHS);
       SDValue Flip = DAG.getNode(ISD::AND, DL, VT, Mask, Diff);
       return DAG.getNode(ISD::XOR, DL, VT, LHS, Flip);
     }
 
     SDValue Src1, Src2;
-    auto isIdentityPattern = [&]() {
+    auto isIdentityPatternZero = [&]() {
       switch (RHS.getOpcode()) {
+      default:
+        break;
       case ISD::OR:
       case ISD::XOR:
       case ISD::ADD:
@@ -24125,6 +24127,9 @@ static SDValue LowerSELECTWithCmpZero(SDValue CmpVal, SDValue LHS, SDValue RHS,
           return true;
         }
         break;
+      case ISD::SHL:
+      case ISD::SRA:
+      case ISD::SRL:
       case ISD::SUB:
         if (RHS.getOperand(0) == LHS) {
           Src1 = RHS.getOperand(1);
@@ -24136,15 +24141,40 @@ static SDValue LowerSELECTWithCmpZero(SDValue CmpVal, SDValue LHS, SDValue RHS,
       return false;
     };
 
+    auto isIdentityPatternOnes = [&]() {
+      switch (LHS.getOpcode()) {
+      default:
+        break;
+      case ISD::AND:
+        if (LHS.getOperand(0) == RHS || LHS.getOperand(1) == RHS) {
+          Src1 = LHS.getOperand(LHS.getOperand(0) == RHS ? 1 : 0);
+          Src2 = RHS;
+          return true;
+        }
+        break;
+      }
+      return false;
+    };
+
     // Convert 'identity' patterns (iff X is 0 or 1):
     // SELECT (AND(X,1) == 0), Y, (OR Y, Z) -> (OR Y, (AND NEG(AND(X,1)), Z))
     // SELECT (AND(X,1) == 0), Y, (XOR Y, Z) -> (XOR Y, (AND NEG(AND(X,1)), Z))
     // SELECT (AND(X,1) == 0), Y, (ADD Y, Z) -> (ADD Y, (AND NEG(AND(X,1)), Z))
     // SELECT (AND(X,1) == 0), Y, (SUB Y, Z) -> (SUB Y, (AND NEG(AND(X,1)), Z))
-    if (!Subtarget.canUseCMOV() && isIdentityPattern()) {
-      SDValue Mask = SplatLSB();
-      SDValue And = DAG.getNode(ISD::AND, DL, VT, Mask, Src1); // Mask & z
-      return DAG.getNode(RHS.getOpcode(), DL, VT, Src2, And);  // y Op And
+    // SELECT (AND(X,1) == 0), Y, (SHL Y, Z) -> (SHL Y, (AND NEG(AND(X,1)), Z))
+    // SELECT (AND(X,1) == 0), Y, (SRA Y, Z) -> (SRA Y, (AND NEG(AND(X,1)), Z))
+    // SELECT (AND(X,1) == 0), Y, (SRL Y, Z) -> (SRL Y, (AND NEG(AND(X,1)), Z))
+    if (!Subtarget.canUseCMOV() && isIdentityPatternZero()) {
+      SDValue Mask = SplatLSB(Src1.getValueType());
+      SDValue And = DAG.getNode(ISD::AND, DL, Src1.getValueType(), Mask,
+                                Src1); // Mask & z
+      return DAG.getNode(RHS.getOpcode(), DL, VT, Src2, And); // y Op And
+    }
+    // SELECT (AND(X,1) == 0), (AND Y, Z), Y -> (AND Y, (OR NEG(AND(X, 1)), Z))
+    if (!Subtarget.canUseCMOV() && isIdentityPatternOnes()) {
+      SDValue Mask = SplatLSB(VT);
+      SDValue Or = DAG.getNode(ISD::OR, DL, VT, Mask, Src1); // Mask | z
+      return DAG.getNode(LHS.getOpcode(), DL, VT, Src2, Or); // y Op Or
     }
   }
 
diff --git a/llvm/test/CodeGen/X86/pull-conditional-binop-through-shift.ll b/llvm/test/CodeGen/X86/pull-conditional-binop-through-shift.ll
index 86a8a6a53248b7..8c858e04de2a14 100644
--- a/llvm/test/CodeGen/X86/pull-conditional-binop-through-shift.ll
+++ b/llvm/test/CodeGen/X86/pull-conditional-binop-through-shift.ll
@@ -697,3 +697,185 @@ define i32 @add_nosignbit_select_ashr(i32 %x, i1 %cond, ptr %dst) {
   store i32 %r, ptr %dst
   ret i32 %r
 }
+
+define i32 @shl_signbit_select_add(i32 %x, i1 %cond, ptr %dst) {
+; X64-LABEL: shl_signbit_select_add:
+; X64:       # %bb.0:
+; X64-NEXT:    movl %edi, %eax
+; X64-NEXT:    shll $4, %eax
+; X64-NEXT:    testb $1, %sil
+; X64-NEXT:    cmovel %edi, %eax
+; X64-NEXT:    addl $123456, %eax # imm = 0x1E240
+; X64-NEXT:    movl %eax, (%rdx)
+; X64-NEXT:    retq
+;
+; X86-LABEL: shl_signbit_select_add:
+; X86:       # %bb.0:
+; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %ecx
+; X86-NEXT:    andb $1, %cl
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %edx
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-NEXT:    negb %cl
+; X86-NEXT:    andb $4, %cl
+; X86-NEXT:    shll %cl, %eax
+; X86-NEXT:    addl $123456, %eax # imm = 0x1E240
+; X86-NEXT:    movl %eax, (%edx)
+; X86-NEXT:    retl
+  %t0 = shl i32 %x, 4
+  %t1 = select i1 %cond, i32 %t0, i32 %x
+  %r = add i32 %t1, 123456
+  store i32 %r, ptr %dst
+  ret i32 %r
+}
+
+define i32 @shl_signbit_select_add_fail(i32 %x, i1 %cond, ptr %dst) {
+; X64-LABEL: shl_signbit_select_add_fail:
+; X64:       # %bb.0:
+; X64-NEXT:    movl %edi, %eax
+; X64-NEXT:    shll $4, %eax
+; X64-NEXT:    testb $1, %sil
+; X64-NEXT:    cmovnel %edi, %eax
+; X64-NEXT:    addl $123456, %eax # imm = 0x1E240
+; X64-NEXT:    movl %eax, (%rdx)
+; X64-NEXT:    retq
+;
+; X86-LABEL: shl_signbit_select_add_fail:
+; X86:       # %bb.0:
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %ecx
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-NEXT:    testb $1, {{[0-9]+}}(%esp)
+; X86-NEXT:    jne .LBB25_2
+; X86-NEXT:  # %bb.1:
+; X86-NEXT:    shll $4, %eax
+; X86-NEXT:  .LBB25_2:
+; X86-NEXT:    addl $123456, %eax # imm = 0x1E240
+; X86-NEXT:    movl %eax, (%ecx)
+; X86-NEXT:    retl
+  %t0 = shl i32 %x, 4
+  %t1 = select i1 %cond, i32 %x, i32 %t0
+  %r = add i32 %t1, 123456
+  store i32 %r, ptr %dst
+  ret i32 %r
+}
+
+define i32 @lshr_signbit_select_add(i32 %x, i1 %cond, ptr %dst, i32 %y) {
+; X64-LABEL: lshr_signbit_select_add:
+; X64:       # %bb.0:
+; X64-NEXT:    movl %edi, %eax
+; X64-NEXT:    # kill: def $cl killed $cl killed $ecx
+; X64-NEXT:    shrl %cl, %eax
+; X64-NEXT:    testb $1, %sil
+; X64-NEXT:    cmovel %edi, %eax
+; X64-NEXT:    addl $123456, %eax # imm = 0x1E240
+; X64-NEXT:    movl %eax, (%rdx)
+; X64-NEXT:    retq
+;
+; X86-LABEL: lshr_signbit_select_add:
+; X86:       # %bb.0:
+; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %ecx
+; X86-NEXT:    andb $1, %cl
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %edx
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-NEXT:    negb %cl
+; X86-NEXT:    andb {{[0-9]+}}(%esp), %cl
+; X86-NEXT:    shrl %cl, %eax
+; X86-NEXT:    addl $123456, %eax # imm = 0x1E240
+; X86-NEXT:    movl %eax, (%edx)
+; X86-NEXT:    retl
+  %t0 = lshr i32 %x, %y
+  %t1 = select i1 %cond, i32 %t0, i32 %x
+  %r = add i32 %t1, 123456
+  store i32 %r, ptr %dst
+  ret i32 %r
+}
+
+define i32 @ashr_signbit_select_add(i32 %x, i1 %cond, ptr %dst) {
+; X64-LABEL: ashr_signbit_select_add:
+; X64:       # %bb.0:
+; X64-NEXT:    movl %edi, %eax
+; X64-NEXT:    sarl $4, %eax
+; X64-NEXT:    testb $1, %sil
+; X64-NEXT:    cmovel %edi, %eax
+; X64-NEXT:    addl $123456, %eax # imm = 0x1E240
+; X64-NEXT:    movl %eax, (%rdx)
+; X64-NEXT:    retq
+;
+; X86-LABEL: ashr_signbit_select_add:
+; X86:       # %bb.0:
+; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %ecx
+; X86-NEXT:    andb $1, %cl
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %edx
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-NEXT:    negb %cl
+; X86-NEXT:    andb $4, %cl
+; X86-NEXT:    sarl %cl, %eax
+; X86-NEXT:    addl $123456, %eax # imm = 0x1E240
+; X86-NEXT:    movl %eax, (%edx)
+; X86-NEXT:    retl
+  %t0 = ashr i32 %x, 4
+  %t1 = select i1 %cond, i32 %t0, i32 %x
+  %r = add i32 %t1, 123456
+  store i32 %r, ptr %dst
+  ret i32 %r
+}
+
+define i32 @and_signbit_select_add(i32 %x, i1 %cond, ptr %dst, i32 %y) {
+; X64-LABEL: and_signbit_select_add:
+; X64:       # %bb.0:
+; X64-NEXT:    # kill: def $ecx killed $ecx def $rcx
+; X64-NEXT:    andl %edi, %ecx
+; X64-NEXT:    testb $1, %sil
+; X64-NEXT:    cmovnel %edi, %ecx
+; X64-NEXT:    leal 123456(%rcx), %eax
+; X64-NEXT:    movl %eax, (%rdx)
+; X64-NEXT:    retq
+;
+; X86-LABEL: and_signbit_select_add:
+; X86:       # %bb.0:
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %ecx
+; X86-NEXT:    movzbl {{[0-9]+}}(%esp), %eax
+; X86-NEXT:    andl $1, %eax
+; X86-NEXT:    negl %eax
+; X86-NEXT:    orl {{[0-9]+}}(%esp), %eax
+; X86-NEXT:    andl {{[0-9]+}}(%esp), %eax
+; X86-NEXT:    addl $123456, %eax # imm = 0x1E240
+; X86-NEXT:    movl %eax, (%ecx)
+; X86-NEXT:    retl
+  %t0 = and i32 %x, %y
+  %t1 = select i1 %cond, i32 %x, i32 %t0
+  %r = add i32 %t1, 123456
+  store i32 %r, ptr %dst
+  ret i32 %r
+}
+
+
+define i32 @and_signbit_select_add_fail(i32 %x, i1 %cond, ptr %dst, i32 %y) {
+; X64-LABEL: and_signbit_select_add_fail:
+; X64:       # %bb.0:
+; X64-NEXT:    # kill: def $ecx killed $ecx def $rcx
+; X64-NEXT:    andl %edi, %ecx
+; X64-NEXT:    testb $1, %sil
+; X64-NEXT:    cmovel %edi, %ecx
+; X64-NEXT:    leal 123456(%rcx), %eax
+; X64-NEXT:    movl %eax, (%rdx)
+; X64-NEXT:    retq
+;
+; X86-LABEL: and_signbit_select_add_fail:
+; X86:       # %bb.0:
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %ecx
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-NEXT:    testb $1, {{[0-9]+}}(%esp)
+; X86-NEXT:    je .LBB29_2
+; X86-NEXT:  # %bb.1:
+; X86-NEXT:    andl {{[0-9]+}}(%esp), %eax
+; X86-NEXT:  .LBB29_2:
+; X86-NEXT:    addl $123456, %eax # imm = 0x1E240
+; X86-NEXT:    movl %eax, (%ecx)
+; X86-NEXT:    retl
+  %t0 = and i32 %x, %y
+  %t1 = select i1 %cond, i32 %t0, i32 %x
+  %r = add i32 %t1, 123456
+  store i32 %r, ptr %dst
+  ret i32 %r
+}
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/107910


More information about the llvm-commits mailing list