[llvm] 211147c - [AArch64] Convert CMP/SELECT sign patterns to OR & ASR.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 16 09:17:55 PST 2021


Author: Florian Hahn
Date: 2021-02-16T17:17:34Z
New Revision: 211147c5ba49a17c8624186f50519885d89ca33d

URL: https://github.com/llvm/llvm-project/commit/211147c5ba49a17c8624186f50519885d89ca33d
DIFF: https://github.com/llvm/llvm-project/commit/211147c5ba49a17c8624186f50519885d89ca33d.diff

LOG: [AArch64] Convert CMP/SELECT sign patterns to OR & ASR.

ICMP & SELECT patterns extracting the sign of a value can be simplified
to OR & ASR (see https://alive2.llvm.org/ce/z/Xx4iZ0).
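
A quick way to see the equivalence: an arithmetic shift right by N-1 bits
yields 0 for non-negative values and -1 (all ones) for negative values, so
OR-ing in 1 produces exactly the 1/-1 the select would pick. A minimal C++
check of the 32-bit case (illustrative only, not part of the patch; it
assumes arithmetic right shift of negative values, which C++20 guarantees):

    #include <cassert>
    #include <cstdint>

    // Sign pattern as written in source: (x > -1) ? 1 : -1.
    static int32_t signSelect(int32_t X) { return X > -1 ? 1 : -1; }

    // OR/ASR form: the shift produces 0 or -1; OR-ing in the low bit
    // maps those to 1 and -1 respectively.
    static int32_t signOrAsr(int32_t X) { return (X >> 31) | 1; }

    int main() {
      const int32_t Vals[] = {INT32_MIN, -42, -1, 0, 1, 42, INT32_MAX};
      for (int32_t X : Vals)
        assert(signSelect(X) == signOrAsr(X));
      return 0;
    }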

This does not save any instructions in IR, but it is profitable on
AArch64, because we need at least 2 extra instructions to materialize 1
and -1 for the SELECT.

The improvements result in ~5% speedups on loops of the form

    static int sign_of(int x) {
      if (x < 0) return -1;
      return 1;
    }

    void foo(const int *x, int *res, int cnt) {
      for (int i=0;i<cnt;i++)
        res[i] = sign_of(x[i]);
    }

Reviewed By: dmgreen

Differential Revision: https://reviews.llvm.org/D96596

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/cmp-select-sign.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e866fc527a35..ea6ec8a258cf 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6707,13 +6707,26 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
     assert((LHS.getValueType() == RHS.getValueType()) &&
            (LHS.getValueType() == MVT::i32 || LHS.getValueType() == MVT::i64));
 
+    ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(FVal);
+    ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(TVal);
+    ConstantSDNode *RHSC = dyn_cast<ConstantSDNode>(RHS);
+    // Check for sign pattern (SELECT_CC setgt, iN lhs, -1, 1, -1) and transform
+    // into (OR (ASR lhs, N-1), 1), which requires fewer instructions for the
+    // supported types.
+    if (CC == ISD::SETGT && RHSC && RHSC->isAllOnesValue() && CTVal && CFVal &&
+        CTVal->isOne() && CFVal->isAllOnesValue() &&
+        LHS.getValueType() == TVal.getValueType()) {
+      EVT VT = LHS.getValueType();
+      SDValue Shift =
+          DAG.getNode(ISD::SRA, dl, VT, LHS,
+                      DAG.getConstant(VT.getSizeInBits() - 1, dl, VT));
+      return DAG.getNode(ISD::OR, dl, VT, Shift, DAG.getConstant(1, dl, VT));
+    }
+
     unsigned Opcode = AArch64ISD::CSEL;
 
     // If both the TVal and the FVal are constants, see if we can swap them in
     // order to form a CSINV or CSINC out of them.
-    ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(FVal);
-    ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(TVal);
-
     if (CTVal && CFVal && CTVal->isAllOnesValue() && CFVal->isNullValue()) {
       std::swap(TVal, FVal);
       std::swap(CTVal, CFVal);
@@ -6916,7 +6929,7 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
   if (CCVal.getOpcode() == ISD::SETCC) {
     LHS = CCVal.getOperand(0);
     RHS = CCVal.getOperand(1);
-    CC = cast<CondCodeSDNode>(CCVal->getOperand(2))->get();
+    CC = cast<CondCodeSDNode>(CCVal.getOperand(2))->get();
   } else {
     LHS = CCVal;
     RHS = DAG.getConstant(0, DL, CCVal.getValueType());
@@ -14970,6 +14983,39 @@ static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
   SDValue N0 = N->getOperand(0);
   EVT CCVT = N0.getValueType();
 
+  // Check for sign pattern (VSELECT setgt, iN lhs, -1, 1, -1) and transform
+  // into (OR (ASR lhs, N-1), 1), which requires fewer instructions for the
+  // supported types.
+  SDValue SetCC = N->getOperand(0);
+  if (SetCC.getOpcode() == ISD::SETCC &&
+      SetCC.getOperand(2) == DAG.getCondCode(ISD::SETGT)) {
+    SDValue CmpLHS = SetCC.getOperand(0);
+    EVT VT = CmpLHS.getValueType();
+    SDNode *CmpRHS = SetCC.getOperand(1).getNode();
+    SDNode *SplatLHS = N->getOperand(1).getNode();
+    SDNode *SplatRHS = N->getOperand(2).getNode();
+    APInt SplatLHSVal;
+    if (CmpLHS.getValueType() == N->getOperand(1).getValueType() &&
+        VT.isSimple() &&
+        is_contained(
+            makeArrayRef({MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16,
+                          MVT::v2i32, MVT::v4i32, MVT::v2i64}),
+            VT.getSimpleVT().SimpleTy) &&
+        ISD::isConstantSplatVector(SplatLHS, SplatLHSVal) &&
+        SplatLHSVal.isOneValue() && ISD::isConstantSplatVectorAllOnes(CmpRHS) &&
+        ISD::isConstantSplatVectorAllOnes(SplatRHS)) {
+      unsigned NumElts = VT.getVectorNumElements();
+      SmallVector<SDValue, 8> Ops(
+          NumElts, DAG.getConstant(VT.getScalarSizeInBits() - 1, SDLoc(N),
+                                   VT.getScalarType()));
+      SDValue Val = DAG.getBuildVector(VT, SDLoc(N), Ops);
+
+      auto Shift = DAG.getNode(ISD::SRA, SDLoc(N), VT, CmpLHS, Val);
+      auto Or = DAG.getNode(ISD::OR, SDLoc(N), VT, Shift, N->getOperand(1));
+      return Or;
+    }
+  }
+
   if (N0.getOpcode() != ISD::SETCC || CCVT.getVectorNumElements() != 1 ||
       CCVT.getVectorElementType() != MVT::i1)
     return SDValue();
@@ -14983,10 +15029,9 @@ static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
 
   SDValue IfTrue = N->getOperand(1);
   SDValue IfFalse = N->getOperand(2);
-  SDValue SetCC =
-      DAG.getSetCC(SDLoc(N), CmpVT.changeVectorElementTypeToInteger(),
-                   N0.getOperand(0), N0.getOperand(1),
-                   cast<CondCodeSDNode>(N0.getOperand(2))->get());
+  SetCC = DAG.getSetCC(SDLoc(N), CmpVT.changeVectorElementTypeToInteger(),
+                       N0.getOperand(0), N0.getOperand(1),
+                       cast<CondCodeSDNode>(N0.getOperand(2))->get());
   return DAG.getNode(ISD::VSELECT, SDLoc(N), ResVT, SetCC,
                      IfTrue, IfFalse);
 }
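
The vector combine relies on the same per-lane identity as the scalar case,
for each legal element width handled above (i8 through i64 lanes). A scalar
C++ model of a single lane at each width (an illustrative sketch, not patch
code; again assuming arithmetic right shift of negative values):

    #include <cassert>
    #include <climits>
    #include <cstdint>

    // Per-lane model of the VSELECT combine: for an N-bit lane,
    // select(a > -1, 1, -1) == (a >> (N-1)) | 1 with an arithmetic shift.
    template <typename T> static void checkLane(T A) {
      T Sel = A > T(-1) ? T(1) : T(-1);
      T OrAsr = T((A >> (sizeof(T) * CHAR_BIT - 1)) | 1);
      assert(Sel == OrAsr);
    }

    int main() {
      const int Vals[] = {-128, -42, -1, 0, 1, 42, 127};
      for (int V : Vals) {
        checkLane<int8_t>((int8_t)V);
        checkLane<int16_t>((int16_t)V);
        checkLane<int32_t>(V);
        checkLane<int64_t>(V);
      }
      return 0;
    }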

diff --git a/llvm/test/CodeGen/AArch64/cmp-select-sign.ll b/llvm/test/CodeGen/AArch64/cmp-select-sign.ll
index 55a11307487d..f58fb55c4405 100644
--- a/llvm/test/CodeGen/AArch64/cmp-select-sign.ll
+++ b/llvm/test/CodeGen/AArch64/cmp-select-sign.ll
@@ -4,10 +4,8 @@
 define i3 @sign_i3(i3 %a) {
 ; CHECK-LABEL: sign_i3:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    sbfx w8, w0, #0, #3
-; CHECK-NEXT:    cmp w8, #0 // =0
-; CHECK-NEXT:    mov w8, #1
-; CHECK-NEXT:    cneg w0, w8, lt
+; CHECK-NEXT:    sbfx w8, w0, #2, #1
+; CHECK-NEXT:    orr w0, w8, #0x1
 ; CHECK-NEXT:    ret
   %c = icmp sgt i3 %a, -1
   %res = select i1 %c, i3 1, i3 -1
@@ -17,10 +15,8 @@ define i3 @sign_i3(i3 %a) {
 define i4 @sign_i4(i4 %a) {
 ; CHECK-LABEL: sign_i4:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    sbfx w8, w0, #0, #4
-; CHECK-NEXT:    cmp w8, #0 // =0
-; CHECK-NEXT:    mov w8, #1
-; CHECK-NEXT:    cneg w0, w8, lt
+; CHECK-NEXT:    sbfx w8, w0, #3, #1
+; CHECK-NEXT:    orr w0, w8, #0x1
 ; CHECK-NEXT:    ret
   %c = icmp sgt i4 %a, -1
   %res = select i1 %c, i4 1, i4 -1
@@ -30,10 +26,8 @@ define i4 @sign_i4(i4 %a) {
 define i8 @sign_i8(i8 %a) {
 ; CHECK-LABEL: sign_i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    sxtb w8, w0
-; CHECK-NEXT:    cmp w8, #0 // =0
-; CHECK-NEXT:    mov w8, #1
-; CHECK-NEXT:    cneg w0, w8, lt
+; CHECK-NEXT:    sbfx w8, w0, #7, #1
+; CHECK-NEXT:    orr w0, w8, #0x1
 ; CHECK-NEXT:    ret
   %c = icmp sgt i8 %a, -1
   %res = select i1 %c, i8 1, i8 -1
@@ -43,10 +37,8 @@ define i8 @sign_i8(i8 %a) {
 define i16 @sign_i16(i16 %a) {
 ; CHECK-LABEL: sign_i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    sxth w8, w0
-; CHECK-NEXT:    cmp w8, #0 // =0
-; CHECK-NEXT:    mov w8, #1
-; CHECK-NEXT:    cneg w0, w8, lt
+; CHECK-NEXT:    sbfx w8, w0, #15, #1
+; CHECK-NEXT:    orr w0, w8, #0x1
 ; CHECK-NEXT:    ret
   %c = icmp sgt i16 %a, -1
   %res = select i1 %c, i16 1, i16 -1
@@ -56,9 +48,8 @@ define i16 @sign_i16(i16 %a) {
 define i32 @sign_i32(i32 %a) {
 ; CHECK-LABEL: sign_i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    cmp w0, #0 // =0
-; CHECK-NEXT:    mov w8, #1
-; CHECK-NEXT:    cneg w0, w8, lt
+; CHECK-NEXT:    asr w8, w0, #31
+; CHECK-NEXT:    orr w0, w8, #0x1
 ; CHECK-NEXT:    ret
   %c = icmp sgt i32 %a, -1
   %res = select i1 %c, i32 1, i32 -1
@@ -68,9 +59,8 @@ define i32 @sign_i32(i32 %a) {
 define i64 @sign_i64(i64 %a) {
 ; CHECK-LABEL: sign_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    cmp x0, #0 // =0
-; CHECK-NEXT:    mov w8, #1
-; CHECK-NEXT:    cneg x0, x8, lt
+; CHECK-NEXT:    asr x8, x0, #63
+; CHECK-NEXT:    orr x0, x8, #0x1
 ; CHECK-NEXT:    ret
   %c = icmp sgt i64 %a, -1
   %res = select i1 %c, i64 1, i64 -1
@@ -124,11 +114,9 @@ define i64 @not_sign_i64_4(i64 %a) {
 define <7 x i8> @sign_7xi8(<7 x i8> %a) {
 ; CHECK-LABEL: sign_7xi8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    movi v1.2d, #0xffffffffffffffff
-; CHECK-NEXT:    cmgt v0.8b, v0.8b, v1.8b
+; CHECK-NEXT:    sshr v0.8b, v0.8b, #7
 ; CHECK-NEXT:    movi v1.8b, #1
-; CHECK-NEXT:    and v1.8b, v0.8b, v1.8b
-; CHECK-NEXT:    orn v0.8b, v1.8b, v0.8b
+; CHECK-NEXT:    orr v0.8b, v0.8b, v1.8b
 ; CHECK-NEXT:    ret
   %c = icmp sgt <7 x i8> %a, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
   %res = select <7 x i1> %c, <7 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>, <7 x i8> <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
@@ -138,11 +126,9 @@ define <7 x i8> @sign_7xi8(<7 x i8> %a) {
 define <8 x i8> @sign_8xi8(<8 x i8> %a) {
 ; CHECK-LABEL: sign_8xi8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    movi v1.2d, #0xffffffffffffffff
-; CHECK-NEXT:    cmgt v0.8b, v0.8b, v1.8b
+; CHECK-NEXT:    sshr v0.8b, v0.8b, #7
 ; CHECK-NEXT:    movi v1.8b, #1
-; CHECK-NEXT:    and v1.8b, v0.8b, v1.8b
-; CHECK-NEXT:    orn v0.8b, v1.8b, v0.8b
+; CHECK-NEXT:    orr v0.8b, v0.8b, v1.8b
 ; CHECK-NEXT:    ret
   %c = icmp sgt <8 x i8> %a, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
   %res = select <8 x i1> %c, <8 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>, <8 x i8> <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
@@ -152,11 +138,9 @@ define <8 x i8> @sign_8xi8(<8 x i8> %a) {
 define <16 x i8> @sign_16xi8(<16 x i8> %a) {
 ; CHECK-LABEL: sign_16xi8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    movi v1.2d, #0xffffffffffffffff
-; CHECK-NEXT:    cmgt v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    sshr v0.16b, v0.16b, #7
 ; CHECK-NEXT:    movi v1.16b, #1
-; CHECK-NEXT:    and v1.16b, v0.16b, v1.16b
-; CHECK-NEXT:    orn v0.16b, v1.16b, v0.16b
+; CHECK-NEXT:    orr v0.16b, v0.16b, v1.16b
 ; CHECK-NEXT:    ret
   %c = icmp sgt <16 x i8> %a, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
   %res = select <16 x i1> %c, <16 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>, <16 x i8> <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
@@ -166,11 +150,8 @@ define <16 x i8> @sign_16xi8(<16 x i8> %a) {
 define <3 x i32> @sign_3xi32(<3 x i32> %a) {
 ; CHECK-LABEL: sign_3xi32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    movi v1.2d, #0xffffffffffffffff
-; CHECK-NEXT:    cmgt v0.4s, v0.4s, v1.4s
-; CHECK-NEXT:    movi v1.4s, #1
-; CHECK-NEXT:    and v1.16b, v0.16b, v1.16b
-; CHECK-NEXT:    orn v0.16b, v1.16b, v0.16b
+; CHECK-NEXT:    sshr v0.4s, v0.4s, #31
+; CHECK-NEXT:    orr v0.4s, #1
 ; CHECK-NEXT:    ret
   %c = icmp sgt <3 x i32> %a, <i32 -1, i32 -1, i32 -1>
   %res = select <3 x i1> %c, <3 x i32> <i32 1, i32 1, i32 1>, <3 x i32> <i32 -1, i32 -1, i32 -1>
@@ -180,11 +161,8 @@ define <3 x i32> @sign_3xi32(<3 x i32> %a) {
 define <4 x i32> @sign_4xi32(<4 x i32> %a) {
 ; CHECK-LABEL: sign_4xi32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    movi v1.2d, #0xffffffffffffffff
-; CHECK-NEXT:    cmgt v0.4s, v0.4s, v1.4s
-; CHECK-NEXT:    movi v1.4s, #1
-; CHECK-NEXT:    and v1.16b, v0.16b, v1.16b
-; CHECK-NEXT:    orn v0.16b, v1.16b, v0.16b
+; CHECK-NEXT:    sshr v0.4s, v0.4s, #31
+; CHECK-NEXT:    orr v0.4s, #1
 ; CHECK-NEXT:    ret
   %c = icmp sgt <4 x i32> %a, <i32 -1, i32 -1, i32 -1, i32 -1>
   %res = select <4 x i1> %c, <4 x i32> <i32 1, i32 1, i32 1, i32 1>, <4 x i32> <i32 -1, i32 -1, i32 -1, i32 -1>
@@ -199,12 +177,11 @@ define <4 x i32> @sign_4xi32_multi_use(<4 x i32> %a) {
 ; CHECK-NEXT:    .cfi_def_cfa_offset 32
 ; CHECK-NEXT:    .cfi_offset w30, -16
 ; CHECK-NEXT:    movi v1.2d, #0xffffffffffffffff
-; CHECK-NEXT:    movi v2.4s, #1
+; CHECK-NEXT:    sshr v2.4s, v0.4s, #31
 ; CHECK-NEXT:    cmgt v0.4s, v0.4s, v1.4s
-; CHECK-NEXT:    and v1.16b, v0.16b, v2.16b
-; CHECK-NEXT:    orn v1.16b, v1.16b, v0.16b
+; CHECK-NEXT:    orr v2.4s, #1
 ; CHECK-NEXT:    xtn v0.4h, v0.4s
-; CHECK-NEXT:    str q1, [sp] // 16-byte Folded Spill
+; CHECK-NEXT:    str q2, [sp] // 16-byte Folded Spill
 ; CHECK-NEXT:    bl use_4xi1
 ; CHECK-NEXT:    ldr q0, [sp] // 16-byte Folded Reload
 ; CHECK-NEXT:    ldr x30, [sp, #16] // 8-byte Folded Reload
@@ -268,25 +245,20 @@ define <4 x i32> @not_sign_4xi32_3(<4 x i32> %a) {
 define <4 x i65> @sign_4xi65(<4 x i65> %a) {
 ; CHECK-LABEL: sign_4xi65:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    sbfx x11, x3, #0, #1
-; CHECK-NEXT:    sbfx x10, x5, #0, #1
-; CHECK-NEXT:    mov w12, #1
-; CHECK-NEXT:    cmp x11, #0 // =0
-; CHECK-NEXT:    sbfx x9, x7, #0, #1
-; CHECK-NEXT:    cneg x2, x12, lt
-; CHECK-NEXT:    cmp x10, #0 // =0
 ; CHECK-NEXT:    sbfx x8, x1, #0, #1
-; CHECK-NEXT:    cneg x4, x12, lt
-; CHECK-NEXT:    cmp x9, #0 // =0
-; CHECK-NEXT:    cneg x6, x12, lt
-; CHECK-NEXT:    cmp x8, #0 // =0
-; CHECK-NEXT:    lsr x5, x10, #63
-; CHECK-NEXT:    cneg x10, x12, lt
+; CHECK-NEXT:    sbfx x9, x7, #0, #1
+; CHECK-NEXT:    orr x6, x9, #0x1
+; CHECK-NEXT:    lsr x7, x9, #63
+; CHECK-NEXT:    orr x9, x8, #0x1
 ; CHECK-NEXT:    lsr x1, x8, #63
-; CHECK-NEXT:    fmov d0, x10
+; CHECK-NEXT:    fmov d0, x9
+; CHECK-NEXT:    sbfx x10, x5, #0, #1
+; CHECK-NEXT:    sbfx x11, x3, #0, #1
 ; CHECK-NEXT:    mov v0.d[1], x1
+; CHECK-NEXT:    orr x2, x11, #0x1
 ; CHECK-NEXT:    lsr x3, x11, #63
-; CHECK-NEXT:    lsr x7, x9, #63
+; CHECK-NEXT:    orr x4, x10, #0x1
+; CHECK-NEXT:    lsr x5, x10, #63
 ; CHECK-NEXT:    fmov x0, d0
 ; CHECK-NEXT:    ret
   %c = icmp sgt <4 x i65> %a, <i65 -1, i65 -1, i65 -1, i65 -1>


        

