[llvm] 6f5edc3 - [RISCV] Fold (add (select lhs, rhs, cc, 0, y), x) -> (select lhs, rhs, cc, x, (add x, y))

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 10 09:03:25 PDT 2021


Author: Craig Topper
Date: 2021-08-10T09:02:56-07:00
New Revision: 6f5edc3487948ca7f25118e693344f41a95a8ca1

URL: https://github.com/llvm/llvm-project/commit/6f5edc3487948ca7f25118e693344f41a95a8ca1
DIFF: https://github.com/llvm/llvm-project/commit/6f5edc3487948ca7f25118e693344f41a95a8ca1.diff

LOG: [RISCV] Fold (add (select lhs, rhs, cc, 0, y), x) -> (select lhs, rhs, cc, x, (add x, y))

Similarly for sub, except that sub isn't commutative, so the fold is only
attempted when the select is the second (subtracted) operand.

Modify the existing and/or/xor folds to also work on ISD::SELECT
and not just RISCVISD::SELECT_CC. This is needed to make sure
we do this transform before type legalization turns i32 add/sub
into add/sub+sign_extend_inreg on RV64. If we don't do this before
that, the sign_extend_inreg will still be after the select.

Reviewed By: frasercrmck

Differential Revision: https://reviews.llvm.org/D107603

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rv32zbs.ll
    llvm/test/CodeGen/RISCV/select-binop-identity.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 1fbbf0b3699fe..0c92a1f54d4d6 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -858,6 +858,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   // We can use any register for comparisons
   setHasMultipleConditionRegisters();
 
+  setTargetDAGCombine(ISD::ADD);
+  setTargetDAGCombine(ISD::SUB);
   setTargetDAGCombine(ISD::AND);
   setTargetDAGCombine(ISD::OR);
   setTargetDAGCombine(ISD::XOR);
@@ -5769,17 +5771,27 @@ static SDValue combineGREVI_GORCI(SDNode *N, SelectionDAG &DAG) {
 
 // Combine a constant select operand into its use:
 //
-// (and (select_cc lhs, rhs, cc, -1, c), x)
-//   -> (select_cc lhs, rhs, cc, x, (and, x, c))  [AllOnes=1]
-// (or  (select_cc lhs, rhs, cc, 0, c), x)
-//   -> (select_cc lhs, rhs, cc, x, (or, x, c))  [AllOnes=0]
-// (xor (select_cc lhs, rhs, cc, 0, c), x)
-//   -> (select_cc lhs, rhs, cc, x, (xor, x, c))  [AllOnes=0]
-static SDValue combineSelectCCAndUse(SDNode *N, SDValue Slct, SDValue OtherOp,
-                                     SelectionDAG &DAG, bool AllOnes) {
+// (and (select cond, -1, c), x)
+//   -> (select cond, x, (and x, c))  [AllOnes=1]
+// (or  (select cond, 0, c), x)
+//   -> (select cond, x, (or x, c))  [AllOnes=0]
+// (xor (select cond, 0, c), x)
+//   -> (select cond, x, (xor x, c))  [AllOnes=0]
+// (add (select cond, 0, c), x)
+//   -> (select cond, x, (add x, c))  [AllOnes=0]
+// (sub x, (select cond, 0, c))
+//   -> (select cond, x, (sub x, c))  [AllOnes=0]
+static SDValue combineSelectAndUse(SDNode *N, SDValue Slct, SDValue OtherOp,
+                                   SelectionDAG &DAG, bool AllOnes) {
   EVT VT = N->getValueType(0);
 
-  if (Slct.getOpcode() != RISCVISD::SELECT_CC || !Slct.hasOneUse())
+  // Skip vectors.
+  if (VT.isVector())
+    return SDValue();
+
+  if ((Slct.getOpcode() != ISD::SELECT &&
+       Slct.getOpcode() != RISCVISD::SELECT_CC) ||
+      !Slct.hasOneUse())
     return SDValue();
 
   auto isZeroOrAllOnes = [](SDValue N, bool AllOnes) {
@@ -5787,8 +5799,9 @@ static SDValue combineSelectCCAndUse(SDNode *N, SDValue Slct, SDValue OtherOp,
   };
 
   bool SwapSelectOps;
-  SDValue TrueVal = Slct.getOperand(3);
-  SDValue FalseVal = Slct.getOperand(4);
+  unsigned OpOffset = Slct.getOpcode() == RISCVISD::SELECT_CC ? 2 : 0;
+  SDValue TrueVal = Slct.getOperand(1 + OpOffset);
+  SDValue FalseVal = Slct.getOperand(2 + OpOffset);
   SDValue NonConstantVal;
   if (isZeroOrAllOnes(TrueVal, AllOnes)) {
     SwapSelectOps = false;
@@ -5802,40 +5815,53 @@ static SDValue combineSelectCCAndUse(SDNode *N, SDValue Slct, SDValue OtherOp,
   // Slct is now know to be the desired identity constant when CC is true.
   TrueVal = OtherOp;
   FalseVal = DAG.getNode(N->getOpcode(), SDLoc(N), VT, OtherOp, NonConstantVal);
-  // Unless SwapSelectOps says CC should be false.
+  // Unless SwapSelectOps says the condition should be false.
   if (SwapSelectOps)
     std::swap(TrueVal, FalseVal);
 
-  return DAG.getNode(RISCVISD::SELECT_CC, SDLoc(N), VT,
-                     {Slct.getOperand(0), Slct.getOperand(1),
-                      Slct.getOperand(2), TrueVal, FalseVal});
+  if (Slct.getOpcode() == RISCVISD::SELECT_CC)
+    return DAG.getNode(RISCVISD::SELECT_CC, SDLoc(N), VT,
+                       {Slct.getOperand(0), Slct.getOperand(1),
+                        Slct.getOperand(2), TrueVal, FalseVal});
+
+  return DAG.getNode(ISD::SELECT, SDLoc(N), VT,
+                     {Slct.getOperand(0), TrueVal, FalseVal});
 }
 
 // Attempt combineSelectAndUse on each operand of a commutative operator N.
-static SDValue combineSelectCCAndUseCommutative(SDNode *N, SelectionDAG &DAG,
-                                                bool AllOnes) {
+static SDValue combineSelectAndUseCommutative(SDNode *N, SelectionDAG &DAG,
+                                              bool AllOnes) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
-  if (SDValue Result = combineSelectCCAndUse(N, N0, N1, DAG, AllOnes))
+  if (SDValue Result = combineSelectAndUse(N, N0, N1, DAG, AllOnes))
     return Result;
-  if (SDValue Result = combineSelectCCAndUse(N, N1, N0, DAG, AllOnes))
+  if (SDValue Result = combineSelectAndUse(N, N1, N0, DAG, AllOnes))
     return Result;
   return SDValue();
 }
 
-static SDValue performANDCombine(SDNode *N,
-                                 TargetLowering::DAGCombinerInfo &DCI,
-                                 const RISCVSubtarget &Subtarget) {
-  SelectionDAG &DAG = DCI.DAG;
+static SDValue performADDCombine(SDNode *N, SelectionDAG &DAG) {
+  // fold (add (select lhs, rhs, cc, 0, y), x) ->
+  //      (select lhs, rhs, cc, x, (add x, y))
+  return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false);
+}
 
-  // fold (and (select_cc lhs, rhs, cc, -1, y), x) ->
+static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG) {
+  // fold (sub x, (select lhs, rhs, cc, 0, y)) ->
+  //      (select lhs, rhs, cc, x, (sub x, y))
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  return combineSelectAndUse(N, N1, N0, DAG, /*AllOnes*/ false);
+}
+
+static SDValue performANDCombine(SDNode *N, SelectionDAG &DAG) {
+  // fold (and (select lhs, rhs, cc, -1, y), x) ->
   //      (select lhs, rhs, cc, x, (and x, y))
-  return combineSelectCCAndUseCommutative(N, DAG, true);
+  return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ true);
 }
 
-static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+static SDValue performORCombine(SDNode *N, SelectionDAG &DAG,
                                 const RISCVSubtarget &Subtarget) {
-  SelectionDAG &DAG = DCI.DAG;
   if (Subtarget.hasStdExtZbp()) {
     if (auto GREV = combineORToGREV(SDValue(N, 0), DAG, Subtarget))
       return GREV;
@@ -5845,19 +5871,15 @@ static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
       return SHFL;
   }
 
-  // fold (or (select_cc lhs, rhs, cc, 0, y), x) ->
-  //      (select lhs, rhs, cc, x, (or x, y))
-  return combineSelectCCAndUseCommutative(N, DAG, false);
+  // fold (or (select cond, 0, y), x) ->
+  //      (select cond, x, (or x, y))
+  return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false);
 }
 
-static SDValue performXORCombine(SDNode *N,
-                                 TargetLowering::DAGCombinerInfo &DCI,
-                                 const RISCVSubtarget &Subtarget) {
-  SelectionDAG &DAG = DCI.DAG;
-
-  // fold (xor (select_cc lhs, rhs, cc, 0, y), x) ->
-  //      (select lhs, rhs, cc, x, (xor x, y))
-  return combineSelectCCAndUseCommutative(N, DAG, false);
+static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG) {
+  // fold (xor (select cond, 0, y), x) ->
+  //      (select cond, x, (xor x, y))
+  return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false);
 }
 
 // Attempt to turn ANY_EXTEND into SIGN_EXTEND if the input to the ANY_EXTEND
@@ -6167,12 +6189,16 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     return DAG.getNode(ISD::AND, DL, VT, NewFMV,
                        DAG.getConstant(~SignBit, DL, VT));
   }
+  case ISD::ADD:
+    return performADDCombine(N, DAG);
+  case ISD::SUB:
+    return performSUBCombine(N, DAG);
   case ISD::AND:
-    return performANDCombine(N, DCI, Subtarget);
+    return performANDCombine(N, DAG);
   case ISD::OR:
-    return performORCombine(N, DCI, Subtarget);
+    return performORCombine(N, DAG, Subtarget);
   case ISD::XOR:
-    return performXORCombine(N, DCI, Subtarget);
+    return performXORCombine(N, DAG);
   case ISD::ANY_EXTEND:
     return performANY_EXTENDCombine(N, DCI, Subtarget);
   case ISD::ZERO_EXTEND:

diff  --git a/llvm/test/CodeGen/RISCV/rv32zbs.ll b/llvm/test/CodeGen/RISCV/rv32zbs.ll
index dce44f8c702ae..1a7c1b4f1a007 100644
--- a/llvm/test/CodeGen/RISCV/rv32zbs.ll
+++ b/llvm/test/CodeGen/RISCV/rv32zbs.ll
@@ -75,16 +75,15 @@ define i64 @sbclr_i64(i64 %a, i64 %b) nounwind {
 ;
 ; RV32IB-LABEL: sbclr_i64:
 ; RV32IB:       # %bb.0:
-; RV32IB-NEXT:    andi a3, a2, 63
-; RV32IB-NEXT:    addi a3, a3, -32
-; RV32IB-NEXT:    bset a4, zero, a3
-; RV32IB-NEXT:    slti a5, a3, 0
-; RV32IB-NEXT:    cmov a4, a5, zero, a4
-; RV32IB-NEXT:    bset a2, zero, a2
-; RV32IB-NEXT:    srai a3, a3, 31
-; RV32IB-NEXT:    and a2, a3, a2
-; RV32IB-NEXT:    andn a1, a1, a4
-; RV32IB-NEXT:    andn a0, a0, a2
+; RV32IB-NEXT:    bset a3, zero, a2
+; RV32IB-NEXT:    andi a2, a2, 63
+; RV32IB-NEXT:    addi a2, a2, -32
+; RV32IB-NEXT:    srai a4, a2, 31
+; RV32IB-NEXT:    and a3, a4, a3
+; RV32IB-NEXT:    slti a4, a2, 0
+; RV32IB-NEXT:    bclr a2, a1, a2
+; RV32IB-NEXT:    cmov a1, a4, a1, a2
+; RV32IB-NEXT:    andn a0, a0, a3
 ; RV32IB-NEXT:    ret
 ;
 ; RV32IBS-LABEL: sbclr_i64:

diff  --git a/llvm/test/CodeGen/RISCV/select-binop-identity.ll b/llvm/test/CodeGen/RISCV/select-binop-identity.ll
index 6e5d6255685b1..3449a33edf0f8 100644
--- a/llvm/test/CodeGen/RISCV/select-binop-identity.ll
+++ b/llvm/test/CodeGen/RISCV/select-binop-identity.ll
@@ -157,22 +157,20 @@ define i64 @xor_select_all_zeros_i64(i1 zeroext %c, i64 %x, i64 %y) {
 define signext i32 @add_select_all_zeros_i32(i1 zeroext %c, i32 signext %x, i32 signext %y) {
 ; RV32I-LABEL: add_select_all_zeros_i32:
 ; RV32I:       # %bb.0:
-; RV32I-NEXT:    mv a3, zero
 ; RV32I-NEXT:    bnez a0, .LBB6_2
 ; RV32I-NEXT:  # %bb.1:
-; RV32I-NEXT:    mv a3, a1
+; RV32I-NEXT:    add a2, a2, a1
 ; RV32I-NEXT:  .LBB6_2:
-; RV32I-NEXT:    add a0, a2, a3
+; RV32I-NEXT:    mv a0, a2
 ; RV32I-NEXT:    ret
 ;
 ; RV64I-LABEL: add_select_all_zeros_i32:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    mv a3, zero
 ; RV64I-NEXT:    bnez a0, .LBB6_2
 ; RV64I-NEXT:  # %bb.1:
-; RV64I-NEXT:    mv a3, a1
+; RV64I-NEXT:    addw a2, a2, a1
 ; RV64I-NEXT:  .LBB6_2:
-; RV64I-NEXT:    addw a0, a2, a3
+; RV64I-NEXT:    mv a0, a2
 ; RV64I-NEXT:    ret
   %a = select i1 %c, i32 0, i32 %x
   %b = add i32 %y, %a
@@ -182,24 +180,25 @@ define signext i32 @add_select_all_zeros_i32(i1 zeroext %c, i32 signext %x, i32
 define i64 @add_select_all_zeros_i64(i1 zeroext %c, i64 %x, i64 %y) {
 ; RV32I-LABEL: add_select_all_zeros_i64:
 ; RV32I:       # %bb.0:
-; RV32I-NEXT:    bnez a0, .LBB7_2
+; RV32I-NEXT:    beqz a0, .LBB7_2
 ; RV32I-NEXT:  # %bb.1:
-; RV32I-NEXT:    mv a2, zero
-; RV32I-NEXT:    mv a1, zero
+; RV32I-NEXT:    add a0, a4, a2
+; RV32I-NEXT:    add a1, a3, a1
+; RV32I-NEXT:    sltu a2, a1, a3
+; RV32I-NEXT:    add a4, a0, a2
+; RV32I-NEXT:    mv a3, a1
 ; RV32I-NEXT:  .LBB7_2:
-; RV32I-NEXT:    add a0, a1, a3
-; RV32I-NEXT:    sltu a1, a0, a1
-; RV32I-NEXT:    add a2, a2, a4
-; RV32I-NEXT:    add a1, a2, a1
+; RV32I-NEXT:    mv a0, a3
+; RV32I-NEXT:    mv a1, a4
 ; RV32I-NEXT:    ret
 ;
 ; RV64I-LABEL: add_select_all_zeros_i64:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    bnez a0, .LBB7_2
+; RV64I-NEXT:    beqz a0, .LBB7_2
 ; RV64I-NEXT:  # %bb.1:
-; RV64I-NEXT:    mv a1, zero
+; RV64I-NEXT:    add a2, a2, a1
 ; RV64I-NEXT:  .LBB7_2:
-; RV64I-NEXT:    add a0, a1, a2
+; RV64I-NEXT:    mv a0, a2
 ; RV64I-NEXT:    ret
   %a = select i1 %c, i64 %x, i64 0
   %b = add i64 %a, %y
@@ -209,22 +208,20 @@ define i64 @add_select_all_zeros_i64(i1 zeroext %c, i64 %x, i64 %y) {
 define signext i32 @sub_select_all_zeros_i32(i1 zeroext %c, i32 signext %x, i32 signext %y) {
 ; RV32I-LABEL: sub_select_all_zeros_i32:
 ; RV32I:       # %bb.0:
-; RV32I-NEXT:    mv a3, zero
 ; RV32I-NEXT:    bnez a0, .LBB8_2
 ; RV32I-NEXT:  # %bb.1:
-; RV32I-NEXT:    mv a3, a1
+; RV32I-NEXT:    sub a2, a2, a1
 ; RV32I-NEXT:  .LBB8_2:
-; RV32I-NEXT:    sub a0, a2, a3
+; RV32I-NEXT:    mv a0, a2
 ; RV32I-NEXT:    ret
 ;
 ; RV64I-LABEL: sub_select_all_zeros_i32:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    mv a3, zero
 ; RV64I-NEXT:    bnez a0, .LBB8_2
 ; RV64I-NEXT:  # %bb.1:
-; RV64I-NEXT:    mv a3, a1
+; RV64I-NEXT:    subw a2, a2, a1
 ; RV64I-NEXT:  .LBB8_2:
-; RV64I-NEXT:    subw a0, a2, a3
+; RV64I-NEXT:    mv a0, a2
 ; RV64I-NEXT:    ret
   %a = select i1 %c, i32 0, i32 %x
   %b = sub i32 %y, %a
@@ -234,25 +231,24 @@ define signext i32 @sub_select_all_zeros_i32(i1 zeroext %c, i32 signext %x, i32
 define i64 @sub_select_all_zeros_i64(i1 zeroext %c, i64 %x, i64 %y) {
 ; RV32I-LABEL: sub_select_all_zeros_i64:
 ; RV32I:       # %bb.0:
-; RV32I-NEXT:    bnez a0, .LBB9_2
+; RV32I-NEXT:    beqz a0, .LBB9_2
 ; RV32I-NEXT:  # %bb.1:
-; RV32I-NEXT:    mv a2, zero
-; RV32I-NEXT:    mv a1, zero
-; RV32I-NEXT:  .LBB9_2:
 ; RV32I-NEXT:    sltu a0, a3, a1
 ; RV32I-NEXT:    sub a2, a4, a2
-; RV32I-NEXT:    sub a2, a2, a0
-; RV32I-NEXT:    sub a0, a3, a1
-; RV32I-NEXT:    mv a1, a2
+; RV32I-NEXT:    sub a4, a2, a0
+; RV32I-NEXT:    sub a3, a3, a1
+; RV32I-NEXT:  .LBB9_2:
+; RV32I-NEXT:    mv a0, a3
+; RV32I-NEXT:    mv a1, a4
 ; RV32I-NEXT:    ret
 ;
 ; RV64I-LABEL: sub_select_all_zeros_i64:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    bnez a0, .LBB9_2
+; RV64I-NEXT:    beqz a0, .LBB9_2
 ; RV64I-NEXT:  # %bb.1:
-; RV64I-NEXT:    mv a1, zero
+; RV64I-NEXT:    sub a2, a2, a1
 ; RV64I-NEXT:  .LBB9_2:
-; RV64I-NEXT:    sub a0, a2, a1
+; RV64I-NEXT:    mv a0, a2
 ; RV64I-NEXT:    ret
   %a = select i1 %c, i64 %x, i64 0
   %b = sub i64 %y, %a


        


More information about the llvm-commits mailing list