[llvm] [RISCV] CSE by swapping conditional branches (PR #71111)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 2 14:17:29 PDT 2023


================
@@ -1159,6 +1159,133 @@ bool RISCVInstrInfo::reverseBranchCondition(
   return false;
 }
 
+bool RISCVInstrInfo::optimizeCondBranch(MachineInstr &MI) const {
+  MachineBasicBlock *MBB = MI.getParent();
+  MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
+
+  MachineBasicBlock *TBB, *FBB;
+  SmallVector<MachineOperand, 3> Cond;
+  if (analyzeBranch(*MBB, TBB, FBB, Cond, /*AllowModify=*/false))
+    return false;
+  (void)FBB;
+
+  RISCVCC::CondCode CC = static_cast<RISCVCC::CondCode>(Cond[0].getImm());
+  assert(CC != RISCVCC::COND_INVALID);
+
+  if (CC == RISCVCC::COND_EQ || CC == RISCVCC::COND_NE)
+    return false;
+
+  // For two constants C0 and C1 from
+  // ```
+  // li Y, C0
+  // li Z, C1
+  // ```
+  // 1. if C1 = C0 + 1
+  // we can turn:
+  //  (a) blt Y, X -> bge X, Z
+  //  (b) bge Y, X -> blt X, Z
+  //
+  // 2. if C1 = C0 - 1
+  // we can turn:
+  //  (a) blt X, Y -> bge Z, X
+  //  (b) bge X, Y -> blt Z, X
+  //
+  // To make sure this optimization is really beneficial, we only
+  // optimize for cases where Y had only one use (i.e. only used by the branch).
+
+  // Right now we only care about LI (i.e. ADDI rs, 0)
+  auto isLoadImm = [](MachineInstr *MI) -> bool {
+    return MI->getOpcode() == RISCV::ADDI && MI->getOperand(1).isReg() &&
+           MI->getOperand(1).getReg() == RISCV::X0;
+  };
+  // Either a load from immediate instruction or X0.
+  auto isFromLoadImm = [&](const MachineOperand &Op) -> bool {
+    if (!Op.isReg())
+      return false;
+    Register Reg = Op.getReg();
+    if (Reg == RISCV::X0)
+      return true;
+    if (!Reg.isVirtual())
+      return false;
+    return isLoadImm(MRI.getVRegDef(Op.getReg()));
+  };
+
+  MachineOperand &LHS = MI.getOperand(0);
+  MachineOperand &RHS = MI.getOperand(1);
+  auto getConst = [&MRI](MachineOperand &Op) -> int64_t {
+    Register Reg = Op.getReg();
+    if (Reg == RISCV::X0)
+      return 0;
+    assert(Reg.isVirtual());
+    MachineInstr *Def = MRI.getVRegDef(Reg);
+    return Def->getOperand(2).getImm();
+  };
+
+  // Try to find the register for constant Z; return
+  // invalid register otherwise.
+  auto searchConst = [&](int64_t C1) -> Register {
+    MachineInstr *DefC1 = nullptr;
+    MachineBasicBlock::reverse_iterator II(&MI), E = MBB->rend();
+    for (++II; II != E; ++II) {
+      if (isLoadImm(&*II))
+        if (II->getOperand(2).getImm() == C1) {
+          DefC1 = &*II;
+          break;
+        }
+    }
+    if (DefC1)
+      return DefC1->getOperand(0).getReg();
+    else
+      return Register();
+  };
+
+  bool Modify = false;
+  if (isFromLoadImm(LHS) && MRI.hasOneUse(LHS.getReg())) {
+    // Might be case 1.
+    int64_t C0 = getConst(LHS);
+    // Signed integer overflow is UB. (UINT_MAX is bigger so we don't need
----------------
topperc wrote:

I think you want INT64_MAX?

https://github.com/llvm/llvm-project/pull/71111


More information about the llvm-commits mailing list