[llvm] e0e0891 - [RISCV][GISel] Select G_BRCOND and G_ICMP together when possible.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Sun Nov 12 15:54:02 PST 2023


Author: Craig Topper
Date: 2023-11-12T15:53:23-08:00
New Revision: e0e0891d741588684b0803d7724e5080f9c75537

URL: https://github.com/llvm/llvm-project/commit/e0e0891d741588684b0803d7724e5080f9c75537
DIFF: https://github.com/llvm/llvm-project/commit/e0e0891d741588684b0803d7724e5080f9c75537.diff

LOG: [RISCV][GISel] Select G_BRCOND and G_ICMP together when possible.

This allows us to fold the G_ICMP operands into the conditional branch.

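This reuses the existing getOperandsForBranch helper, which already folds
a G_ICMP into a G_SELECT. The helper is moved earlier in the file and its
unused MachineIRBuilder parameter is dropped.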

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
    llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
    llvm/lib/Target/RISCV/RISCVInstrInfo.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
index d97776265ce6418..25f4a217c070349 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
@@ -296,6 +296,92 @@ RISCVInstructionSelector::selectAddrRegImm(MachineOperand &Root) const {
            [=](MachineInstrBuilder &MIB) { MIB.addImm(0); }}};
 }
 
+/// Returns the RISCVCC::CondCode that corresponds to the CmpInst::Predicate CC.
+/// CC must be ICMP_EQ, ICMP_NE, ICMP_ULT, ICMP_SLT, ICMP_UGE or ICMP_SGE.
+static RISCVCC::CondCode getRISCVCCFromICmp(CmpInst::Predicate CC) {
+  switch (CC) {
+  default:
+    llvm_unreachable("Expected ICMP CmpInst::Predicate.");
+  case CmpInst::Predicate::ICMP_EQ:
+    return RISCVCC::COND_EQ;
+  case CmpInst::Predicate::ICMP_NE:
+    return RISCVCC::COND_NE;
+  case CmpInst::Predicate::ICMP_ULT:
+    return RISCVCC::COND_LTU;
+  case CmpInst::Predicate::ICMP_SLT:
+    return RISCVCC::COND_LT;
+  case CmpInst::Predicate::ICMP_UGE:
+    return RISCVCC::COND_GEU;
+  case CmpInst::Predicate::ICMP_SGE:
+    return RISCVCC::COND_GE;
+  }
+}
+
+static void getOperandsForBranch(Register CondReg, MachineRegisterInfo &MRI,
+                                 RISCVCC::CondCode &CC, Register &LHS,
+                                 Register &RHS) {
+  // Fold a G_ICMP defining CondReg if possible; else test CondReg != X0.
+  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
+  if (!mi_match(CondReg, MRI, m_GICmp(m_Pred(Pred), m_Reg(LHS), m_Reg(RHS)))) {
+    LHS = CondReg;
+    RHS = RISCV::X0;
+    CC = RISCVCC::COND_NE;
+    return;
+  }
+
+  // We found a G_ICMP; canonicalize it into a form RISC-V branches support.
+
+  // Rewrite X > -1 and X < 1 as comparisons against X0 when possible.
+  if (auto Constant = getIConstantVRegSExtVal(RHS, MRI)) {
+    switch (Pred) {
+    case CmpInst::Predicate::ICMP_SGT:
+      // Convert signed X > -1 to X >= 0
+      if (*Constant == -1) {
+        CC = RISCVCC::COND_GE;
+        RHS = RISCV::X0;
+        return;
+      }
+      break;
+    case CmpInst::Predicate::ICMP_SLT:
+      // Convert signed X < 1 (X <= 0) to 0 >= X
+      if (*Constant == 1) {
+        CC = RISCVCC::COND_GE;
+        RHS = LHS;
+        LHS = RISCV::X0;
+        return;
+      }
+      break;
+    default:
+      break;
+    }
+  }
+
+  switch (Pred) {
+  default:
+    llvm_unreachable("Expected ICMP CmpInst::Predicate.");
+  case CmpInst::Predicate::ICMP_EQ:
+  case CmpInst::Predicate::ICMP_NE:
+  case CmpInst::Predicate::ICMP_ULT:
+  case CmpInst::Predicate::ICMP_SLT:
+  case CmpInst::Predicate::ICMP_UGE:
+  case CmpInst::Predicate::ICMP_SGE:
+    // These CCs are supported directly by RISC-V branches.
+    break;
+  case CmpInst::Predicate::ICMP_SGT:
+  case CmpInst::Predicate::ICMP_SLE:
+  case CmpInst::Predicate::ICMP_UGT:
+  case CmpInst::Predicate::ICMP_ULE:
+    // These CCs are not supported directly by RISC-V branches, but changing the
+    // direction of the CC and swapping LHS and RHS are.
+    Pred = CmpInst::getSwappedPredicate(Pred);
+    std::swap(LHS, RHS);
+    break;
+  }
+  // Pred is now one of the predicates RISC-V branches support directly.
+  CC = getRISCVCCFromICmp(Pred);
+  return;
+}
+
 bool RISCVInstructionSelector::select(MachineInstr &MI) {
   MachineBasicBlock &MBB = *MI.getParent();
   MachineFunction &MF = *MBB.getParent();
@@ -398,10 +484,12 @@ bool RISCVInstructionSelector::select(MachineInstr &MI) {
   case TargetOpcode::G_GLOBAL_VALUE:
     return selectGlobalValue(MI, MIB, MRI);
   case TargetOpcode::G_BRCOND: {
-    // TODO: Fold with G_ICMP.
-    auto Bcc =
-        MIB.buildInstr(RISCV::BNE, {}, {MI.getOperand(0), Register(RISCV::X0)})
-            .addMBB(MI.getOperand(1).getMBB());
+    Register LHS, RHS;
+    RISCVCC::CondCode CC;
+    getOperandsForBranch(MI.getOperand(0).getReg(), MRI, CC, LHS, RHS);
+
+    auto Bcc = MIB.buildInstr(RISCVCC::getBrCond(CC), {}, {LHS, RHS})
+                   .addMBB(MI.getOperand(1).getMBB());
     MI.eraseFromParent();
     return constrainSelectedInstRegOperands(*Bcc, TII, TRI, RBI);
   }
@@ -719,93 +807,6 @@ bool RISCVInstructionSelector::selectSExtInreg(MachineInstr &MI,
   return true;
 }
 
-/// Returns the RISCVCC::CondCode that corresponds to the CmpInst::Predicate CC.
-/// CC Must be an ICMP Predicate.
-static RISCVCC::CondCode getRISCVCCFromICMP(CmpInst::Predicate CC) {
-  switch (CC) {
-  default:
-    llvm_unreachable("Expected ICMP CmpInst::Predicate.");
-  case CmpInst::Predicate::ICMP_EQ:
-    return RISCVCC::COND_EQ;
-  case CmpInst::Predicate::ICMP_NE:
-    return RISCVCC::COND_NE;
-  case CmpInst::Predicate::ICMP_ULT:
-    return RISCVCC::COND_LTU;
-  case CmpInst::Predicate::ICMP_SLT:
-    return RISCVCC::COND_LT;
-  case CmpInst::Predicate::ICMP_UGE:
-    return RISCVCC::COND_GEU;
-  case CmpInst::Predicate::ICMP_SGE:
-    return RISCVCC::COND_GE;
-  }
-}
-
-static void getOperandsForBranch(Register CondReg, MachineIRBuilder &MIB,
-                                 MachineRegisterInfo &MRI,
-                                 RISCVCC::CondCode &CC, Register &LHS,
-                                 Register &RHS) {
-  // Try to fold an ICmp. If that fails, use a NE compare with X0.
-  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
-  if (!mi_match(CondReg, MRI, m_GICmp(m_Pred(Pred), m_Reg(LHS), m_Reg(RHS)))) {
-    LHS = CondReg;
-    RHS = RISCV::X0;
-    CC = RISCVCC::COND_NE;
-    return;
-  }
-
-  // We found an ICmp, do some canonicalizations.
-
-  // Adjust comparisons to use comparison with 0 if possible.
-  if (auto Constant = getIConstantVRegSExtVal(RHS, MRI)) {
-    switch (Pred) {
-    case CmpInst::Predicate::ICMP_SGT:
-      // Convert X > -1 to X >= 0
-      if (*Constant == -1) {
-        CC = RISCVCC::COND_GE;
-        RHS = RISCV::X0;
-        return;
-      }
-      break;
-    case CmpInst::Predicate::ICMP_SLT:
-      // Convert X < 1 to 0 >= X
-      if (*Constant == 1) {
-        CC = RISCVCC::COND_GE;
-        RHS = LHS;
-        LHS = RISCV::X0;
-        return;
-      }
-      break;
-    default:
-      break;
-    }
-  }
-
-  switch (Pred) {
-  default:
-    llvm_unreachable("Expected ICMP CmpInst::Predicate.");
-  case CmpInst::Predicate::ICMP_EQ:
-  case CmpInst::Predicate::ICMP_NE:
-  case CmpInst::Predicate::ICMP_ULT:
-  case CmpInst::Predicate::ICMP_SLT:
-  case CmpInst::Predicate::ICMP_UGE:
-  case CmpInst::Predicate::ICMP_SGE:
-    // These CCs are supported directly by RISC-V branches.
-    break;
-  case CmpInst::Predicate::ICMP_SGT:
-  case CmpInst::Predicate::ICMP_SLE:
-  case CmpInst::Predicate::ICMP_UGT:
-  case CmpInst::Predicate::ICMP_ULE:
-    // These CCs are not supported directly by RISC-V branches, but changing the
-    // direction of the CC and swapping LHS and RHS are.
-    Pred = CmpInst::getSwappedPredicate(Pred);
-    std::swap(LHS, RHS);
-    break;
-  }
-
-  CC = getRISCVCCFromICMP(Pred);
-  return;
-}
-
 bool RISCVInstructionSelector::selectSelect(MachineInstr &MI,
                                             MachineIRBuilder &MIB,
                                             MachineRegisterInfo &MRI) const {
@@ -813,7 +814,7 @@ bool RISCVInstructionSelector::selectSelect(MachineInstr &MI,
 
   Register LHS, RHS;
   RISCVCC::CondCode CC;
-  getOperandsForBranch(SelectMI.getCondReg(), MIB, MRI, CC, LHS, RHS);
+  getOperandsForBranch(SelectMI.getCondReg(), MRI, CC, LHS, RHS);
 
   MachineInstr *Result = MIB.buildInstr(RISCV::Select_GPR_Using_CC_GPR)
                              .addDef(SelectMI.getReg(0))

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 01f2bb9d730375e..9271f807a84838b 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -904,25 +904,29 @@ static void parseCondBranch(MachineInstr &LastInst, MachineBasicBlock *&Target,
   Cond.push_back(LastInst.getOperand(1));
 }
 
-const MCInstrDesc &RISCVInstrInfo::getBrCond(RISCVCC::CondCode CC) const {
+unsigned RISCVCC::getBrCond(RISCVCC::CondCode CC) {
   switch (CC) {
   default:
     llvm_unreachable("Unknown condition code!");
   case RISCVCC::COND_EQ:
-    return get(RISCV::BEQ);
+    return RISCV::BEQ;
   case RISCVCC::COND_NE:
-    return get(RISCV::BNE);
+    return RISCV::BNE;
   case RISCVCC::COND_LT:
-    return get(RISCV::BLT);
+    return RISCV::BLT;
   case RISCVCC::COND_GE:
-    return get(RISCV::BGE);
+    return RISCV::BGE;
   case RISCVCC::COND_LTU:
-    return get(RISCV::BLTU);
+    return RISCV::BLTU;
   case RISCVCC::COND_GEU:
-    return get(RISCV::BGEU);
+    return RISCV::BGEU;
   }
 }
 
+const MCInstrDesc &RISCVInstrInfo::getBrCond(RISCVCC::CondCode CC) const {
+  return get(RISCVCC::getBrCond(CC));
+}
+
 RISCVCC::CondCode RISCVCC::getOppositeBranchCondition(RISCVCC::CondCode CC) {
   switch (CC) {
   default:

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 491278c2e017e7c..b33d8c28561596b 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -43,6 +43,7 @@ enum CondCode {
 };
 
 CondCode getOppositeBranchCondition(CondCode);
+unsigned getBrCond(CondCode CC);
 
 } // end of namespace RISCVCC
 


        


More information about the llvm-commits mailing list