[llvm] 41759c3 - [RISCV] Add RISCVISD::BR_CC similar to RISCVISD::SELECT_CC.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 15 11:54:38 PDT 2021


Author: Craig Topper
Date: 2021-03-15T11:54:01-07:00
New Revision: 41759c3d92c5445857a89087a877c47bef907f13

URL: https://github.com/llvm/llvm-project/commit/41759c3d92c5445857a89087a877c47bef907f13
DIFF: https://github.com/llvm/llvm-project/commit/41759c3d92c5445857a89087a877c47bef907f13.diff

LOG: [RISCV] Add RISCVISD::BR_CC similar to RISCVISD::SELECT_CC.

This allows me to introduce similar combines for branches as
we have recently added for SELECT_CC. Some of them are less
useful for standalone setccs and only help branch instructions.
By having a BR_CC node it's easier to only affect branches.

I'm using CondCodeSDNode to make isel patterns easier to
write so we can refer to the codes by name. SELECT_CC uses a
constant instead.

I've translated the condition code just like SELECT_CC so
we need fewer patterns for the swapped conditions. This
includes special cases for X < 1 and X > -1 that get translated
to blez and bgez by using a 0 constant.

computeKnownBitsForTargetNode support for SELECT_CC is added
to allow MaskedValueIsZero to work for cases where the true
and false values of the SELECT_CC are setccs and the
result of the SELECT_CC is used by a BR_CC. This was needed
to avoid regressions in some of the overflow tests.

Reviewed By: luismarques

Differential Revision: https://reviews.llvm.org/D98159

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h
    llvm/lib/Target/RISCV/RISCVInstrInfo.td
    llvm/test/CodeGen/RISCV/xaluo.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e93e61549c0f..9e45307f42c0 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -181,6 +181,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
   setOperationAction(ISD::BR_JT, MVT::Other, Expand);
   setOperationAction(ISD::BR_CC, XLenVT, Expand);
+  setOperationAction(ISD::BRCOND, MVT::Other, Custom);
   setOperationAction(ISD::SELECT_CC, XLenVT, Expand);
 
   setOperationAction(ISD::STACKSAVE, MVT::Other, Expand);
@@ -676,7 +677,6 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   // We can use any register for comparisons
   setHasMultipleConditionRegisters();
 
-  setTargetDAGCombine(ISD::SETCC);
   if (Subtarget.hasStdExtZbp()) {
     setTargetDAGCombine(ISD::OR);
   }
@@ -1207,6 +1207,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     return lowerGlobalTLSAddress(Op, DAG);
   case ISD::SELECT:
     return lowerSELECT(Op, DAG);
+  case ISD::BRCOND:
+    return lowerBRCOND(Op, DAG);
   case ISD::VASTART:
     return lowerVASTART(Op, DAG);
   case ISD::FRAMEADDR:
@@ -1906,6 +1908,29 @@ SDValue RISCVTargetLowering::lowerSELECT(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getNode(RISCVISD::SELECT_CC, DL, Op.getValueType(), Ops);
 }
 
+SDValue RISCVTargetLowering::lowerBRCOND(SDValue Op, SelectionDAG &DAG) const {
+  SDValue CondV = Op.getOperand(1);
+  SDLoc DL(Op);
+  MVT XLenVT = Subtarget.getXLenVT();
+
+  if (CondV.getOpcode() == ISD::SETCC &&
+      CondV.getOperand(0).getValueType() == XLenVT) {
+    SDValue LHS = CondV.getOperand(0);
+    SDValue RHS = CondV.getOperand(1);
+    ISD::CondCode CCVal = cast<CondCodeSDNode>(CondV.getOperand(2))->get();
+
+    translateSetCCForBranch(DL, LHS, RHS, CCVal, DAG);
+
+    SDValue TargetCC = DAG.getCondCode(CCVal);
+    return DAG.getNode(RISCVISD::BR_CC, DL, Op.getValueType(), Op.getOperand(0),
+                       LHS, RHS, TargetCC, Op.getOperand(2));
+  }
+
+  return DAG.getNode(RISCVISD::BR_CC, DL, Op.getValueType(), Op.getOperand(0),
+                     CondV, DAG.getConstant(0, DL, XLenVT),
+                     DAG.getCondCode(ISD::SETNE), Op.getOperand(2));
+}
+
 SDValue RISCVTargetLowering::lowerVASTART(SDValue Op, SelectionDAG &DAG) const {
   MachineFunction &MF = DAG.getMachineFunction();
   RISCVMachineFunctionInfo *FuncInfo = MF.getInfo<RISCVMachineFunctionInfo>();
@@ -4234,21 +4259,54 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
 
     break;
   }
-  case ISD::SETCC: {
-    // (setcc X, 1, setne) -> (setcc X, 0, seteq) if we can prove X is 0/1.
-    // Comparing with 0 may allow us to fold into bnez/beqz.
-    SDValue LHS = N->getOperand(0);
-    SDValue RHS = N->getOperand(1);
-    if (LHS.getValueType().isScalableVector())
+  case RISCVISD::BR_CC: {
+    SDValue LHS = N->getOperand(1);
+    SDValue RHS = N->getOperand(2);
+    ISD::CondCode CCVal = cast<CondCodeSDNode>(N->getOperand(3))->get();
+    if (!ISD::isIntEqualitySetCC(CCVal))
       break;
-    auto CC = cast<CondCodeSDNode>(N->getOperand(2))->get();
+
+    // Fold (br_cc (setlt X, Y), 0, ne, dest) ->
+    //      (br_cc X, Y, lt, dest)
+    // Sometimes the setcc is introduced after br_cc has been formed.
+    if (LHS.getOpcode() == ISD::SETCC && isNullConstant(RHS) &&
+        LHS.getOperand(0).getValueType() == Subtarget.getXLenVT()) {
+      // If we're looking for eq 0 instead of ne 0, we need to invert the
+      // condition.
+      bool Invert = CCVal == ISD::SETEQ;
+      CCVal = cast<CondCodeSDNode>(LHS.getOperand(2))->get();
+      if (Invert)
+        CCVal = ISD::getSetCCInverse(CCVal, LHS.getValueType());
+
+      SDLoc DL(N);
+      RHS = LHS.getOperand(1);
+      LHS = LHS.getOperand(0);
+      translateSetCCForBranch(DL, LHS, RHS, CCVal, DAG);
+
+      return DAG.getNode(RISCVISD::BR_CC, DL, N->getValueType(0),
+                         N->getOperand(0), LHS, RHS, DAG.getCondCode(CCVal),
+                         N->getOperand(4));
+    }
+
+    // Fold (br_cc (xor X, Y), 0, eq/ne, dest) ->
+    //      (br_cc X, Y, eq/ne, dest)
+    if (LHS.getOpcode() == ISD::XOR && isNullConstant(RHS))
+      return DAG.getNode(RISCVISD::BR_CC, SDLoc(N), N->getValueType(0),
+                         N->getOperand(0), LHS.getOperand(0), LHS.getOperand(1),
+                         N->getOperand(3), N->getOperand(4));
+
+    // (br_cc X, 1, setne, dest) ->
+    // (br_cc X, 0, seteq, dest) if we can prove X is 0/1.
+    // This can occur when legalizing some floating point comparisons.
     APInt Mask = APInt::getBitsSetFrom(LHS.getValueSizeInBits(), 1);
-    if (isOneConstant(RHS) && ISD::isIntEqualitySetCC(CC) &&
-        DAG.MaskedValueIsZero(LHS, Mask)) {
+    if (isOneConstant(RHS) && DAG.MaskedValueIsZero(LHS, Mask)) {
       SDLoc DL(N);
-      SDValue Zero = DAG.getConstant(0, DL, LHS.getValueType());
-      CC = ISD::getSetCCInverse(CC, LHS.getValueType());
-      return DAG.getSetCC(DL, N->getValueType(0), LHS, Zero, CC);
+      CCVal = ISD::getSetCCInverse(CCVal, LHS.getValueType());
+      SDValue TargetCC = DAG.getCondCode(CCVal);
+      RHS = DAG.getConstant(0, DL, LHS.getValueType());
+      return DAG.getNode(RISCVISD::BR_CC, DL, N->getValueType(0),
+                         N->getOperand(0), LHS, RHS, TargetCC,
+                         N->getOperand(4));
     }
     break;
   }
@@ -4409,6 +4467,17 @@ void RISCVTargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
   Known.resetAll();
   switch (Opc) {
   default: break;
+  case RISCVISD::SELECT_CC: {
+    Known = DAG.computeKnownBits(Op.getOperand(4), Depth + 1);
+    // If we don't know any bits, early out.
+    if (Known.isUnknown())
+      break;
+    KnownBits Known2 = DAG.computeKnownBits(Op.getOperand(3), Depth + 1);
+
+    // Only known if known in both the LHS and RHS.
+    Known = KnownBits::commonBits(Known, Known2);
+    break;
+  }
   case RISCVISD::REMUW: {
     KnownBits Known2;
     Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
@@ -6155,6 +6224,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(MRET_FLAG)
   NODE_NAME_CASE(CALL)
   NODE_NAME_CASE(SELECT_CC)
+  NODE_NAME_CASE(BR_CC)
   NODE_NAME_CASE(BuildPairF64)
   NODE_NAME_CASE(SplitF64)
   NODE_NAME_CASE(TAIL)

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index dd0e72fcf84e..6f9ba9757c47 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -36,6 +36,7 @@ enum NodeType : unsigned {
   /// The lhs and rhs are XLenVT integers. The true and false values can be
   /// integer or floating point.
   SELECT_CC,
+  BR_CC,
   BuildPairF64,
   SplitF64,
   TAIL,
@@ -452,6 +453,7 @@ class RISCVTargetLowering : public TargetLowering {
   SDValue lowerJumpTable(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerGlobalTLSAddress(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerBRCOND(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVASTART(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFRAMEADDR(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerRETURNADDR(SDValue Op, SelectionDAG &DAG) const;

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
index 6e093c1c9022..0f7eb248377b 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
@@ -25,6 +25,9 @@ def SDT_RISCVCall     : SDTypeProfile<0, -1, [SDTCisVT<0, XLenVT>]>;
 def SDT_RISCVSelectCC : SDTypeProfile<1, 5, [SDTCisSameAs<1, 2>,
                                              SDTCisSameAs<0, 4>,
                                              SDTCisSameAs<4, 5>]>;
+def SDT_RISCVBrCC : SDTypeProfile<0, 4, [SDTCisSameAs<0, 1>,
+                                         SDTCisVT<2, OtherVT>,
+                                         SDTCisVT<3, OtherVT>]>;
 def SDT_RISCVReadCycleWide : SDTypeProfile<2, 0, [SDTCisVT<0, i32>,
                                                   SDTCisVT<1, i32>]>;
 def SDT_RISCVIntBinOpW : SDTypeProfile<1, 2, [
@@ -50,6 +53,8 @@ def riscv_sret_flag : SDNode<"RISCVISD::SRET_FLAG", SDTNone,
 def riscv_mret_flag : SDNode<"RISCVISD::MRET_FLAG", SDTNone,
                              [SDNPHasChain, SDNPOptInGlue]>;
 def riscv_selectcc  : SDNode<"RISCVISD::SELECT_CC", SDT_RISCVSelectCC>;
+def riscv_brcc      : SDNode<"RISCVISD::BR_CC", SDT_RISCVBrCC,
+                             [SDNPHasChain]>;
 def riscv_tail      : SDNode<"RISCVISD::TAIL", SDT_RISCVCall,
                              [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
                               SDNPVariadic]>;
@@ -961,44 +966,17 @@ def Select_GPR_Using_CC_GPR : SelectCC_rrirr<GPR, GPR>;
 
 /// Branches and jumps
 
-// Match `(brcond (CondOp ..), ..)` and lower to the appropriate RISC-V branch
-// instruction.
-class BccPat<PatFrag CondOp, RVInstB Inst>
-    : Pat<(brcond (XLenVT (CondOp GPR:$rs1, GPR:$rs2)), bb:$imm12),
+// Match `riscv_brcc` and lower to the appropriate RISC-V branch instruction.
+class BccPat<CondCode Cond, RVInstB Inst>
+    : Pat<(riscv_brcc GPR:$rs1, GPR:$rs2, Cond, bb:$imm12),
           (Inst GPR:$rs1, GPR:$rs2, simm13_lsb0:$imm12)>;
 
-def : BccPat<seteq, BEQ>;
-def : BccPat<setne, BNE>;
-def : BccPat<setlt, BLT>;
-def : BccPat<setge, BGE>;
-def : BccPat<setult, BLTU>;
-def : BccPat<setuge, BGEU>;
-
-class BccSwapPat<PatFrag CondOp, RVInst InstBcc>
-    : Pat<(brcond (XLenVT (CondOp GPR:$rs1, GPR:$rs2)), bb:$imm12),
-          (InstBcc GPR:$rs2, GPR:$rs1, bb:$imm12)>;
-
-// Condition codes that don't have matching RISC-V branch instructions, but
-// are trivially supported by swapping the two input operands
-def : BccSwapPat<setgt, BLT>;
-def : BccSwapPat<setle, BGE>;
-def : BccSwapPat<setugt, BLTU>;
-def : BccSwapPat<setule, BGEU>;
-
-// Extra patterns are needed for a brcond without a setcc (i.e. where the
-// condition was calculated elsewhere).
-def : Pat<(brcond GPR:$cond, bb:$imm12), (BNE GPR:$cond, X0, bb:$imm12)>;
-// In this pattern, the `(xor $cond, 1)` functions like (boolean) `not`, as the
-// `brcond` only uses the lowest bit.
-def : Pat<(brcond (XLenVT (xor GPR:$cond, 1)), bb:$imm12),
-          (BEQ GPR:$cond, X0, bb:$imm12)>;
-
-// Match X > -1, the canonical form of X >= 0, to the bgez pattern.
-def : Pat<(brcond (XLenVT (setgt GPR:$rs1, -1)), bb:$imm12),
-          (BGE GPR:$rs1, X0, bb:$imm12)>;
-// Lower (a < 1) as (0 >= a) into the blez pattern.
-def : Pat<(brcond (XLenVT (setlt GPR:$lhs, 1)), bb:$imm12),
-          (BGE X0, GPR:$lhs, bb:$imm12)>;
+def : BccPat<SETEQ, BEQ>;
+def : BccPat<SETNE, BNE>;
+def : BccPat<SETLT, BLT>;
+def : BccPat<SETGE, BGE>;
+def : BccPat<SETULT, BLTU>;
+def : BccPat<SETUGE, BGEU>;
 
 let isBarrier = 1, isBranch = 1, isTerminator = 1 in
 def PseudoBR : Pseudo<(outs), (ins simm21_lsb0_jal:$imm20), [(br bb:$imm20)]>,

diff  --git a/llvm/test/CodeGen/RISCV/xaluo.ll b/llvm/test/CodeGen/RISCV/xaluo.ll
index 43b27d455015..e29cfc9156bc 100644
--- a/llvm/test/CodeGen/RISCV/xaluo.ll
+++ b/llvm/test/CodeGen/RISCV/xaluo.ll
@@ -1423,8 +1423,7 @@ define zeroext i1 @saddo.br.i32(i32 %v1, i32 %v2) {
 ; RV32-NEXT:    add a2, a0, a1
 ; RV32-NEXT:    slt a0, a2, a0
 ; RV32-NEXT:    slti a1, a1, 0
-; RV32-NEXT:    xor a0, a1, a0
-; RV32-NEXT:    beqz a0, .LBB46_2
+; RV32-NEXT:    beq a1, a0, .LBB46_2
 ; RV32-NEXT:  # %bb.1: # %overflow
 ; RV32-NEXT:    mv a0, zero
 ; RV32-NEXT:    ret
@@ -1482,8 +1481,7 @@ define zeroext i1 @saddo.br.i64(i64 %v1, i64 %v2) {
 ; RV64-NEXT:    add a2, a0, a1
 ; RV64-NEXT:    slt a0, a2, a0
 ; RV64-NEXT:    slti a1, a1, 0
-; RV64-NEXT:    xor a0, a1, a0
-; RV64-NEXT:    beqz a0, .LBB47_2
+; RV64-NEXT:    beq a1, a0, .LBB47_2
 ; RV64-NEXT:  # %bb.1: # %overflow
 ; RV64-NEXT:    mv a0, zero
 ; RV64-NEXT:    ret
@@ -1587,8 +1585,7 @@ define zeroext i1 @ssubo.br.i32(i32 %v1, i32 %v2) {
 ; RV32-NEXT:    sgtz a2, a1
 ; RV32-NEXT:    sub a1, a0, a1
 ; RV32-NEXT:    slt a0, a1, a0
-; RV32-NEXT:    xor a0, a2, a0
-; RV32-NEXT:    beqz a0, .LBB50_2
+; RV32-NEXT:    beq a2, a0, .LBB50_2
 ; RV32-NEXT:  # %bb.1: # %overflow
 ; RV32-NEXT:    mv a0, zero
 ; RV32-NEXT:    ret
@@ -1644,8 +1641,7 @@ define zeroext i1 @ssubo.br.i64(i64 %v1, i64 %v2) {
 ; RV64-NEXT:    sgtz a2, a1
 ; RV64-NEXT:    sub a1, a0, a1
 ; RV64-NEXT:    slt a0, a1, a0
-; RV64-NEXT:    xor a0, a2, a0
-; RV64-NEXT:    beqz a0, .LBB51_2
+; RV64-NEXT:    beq a2, a0, .LBB51_2
 ; RV64-NEXT:  # %bb.1: # %overflow
 ; RV64-NEXT:    mv a0, zero
 ; RV64-NEXT:    ret


        


More information about the llvm-commits mailing list