[llvm] [RISCV] CSE by swapping conditional branches (PR #71111)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 2 14:03:23 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v
Author: Min-Yih Hsu (mshockwave)
<details>
<summary>Changes</summary>
This patch performs the following optimization:
For two constants C0 and C1 from
```
li Y, C0
li Z, C1
```
To remove the redundant `li Y, C0`:
1. If C1 = C0 + 1, we can turn: (a) blt Y, X -> bge X, Z; (b) bge Y, X -> blt X, Z.
2. If C1 = C0 - 1, we can turn: (a) blt X, Y -> bge Z, X; (b) bge X, Y -> blt Z, X.
This optimization will be done by PeepholeOptimizer through RISCVInstrInfo::optimizeCondBranch.
---
Full diff: https://github.com/llvm/llvm-project/pull/71111.diff
3 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfo.cpp (+127)
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfo.h (+2)
- (added) llvm/test/CodeGen/RISCV/branch-opt.ll (+175)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 412fb7e7f7fc16c..9a3fb7fe39713c7 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1159,6 +1159,133 @@ bool RISCVInstrInfo::reverseBranchCondition(
return false;
}
+bool RISCVInstrInfo::optimizeCondBranch(MachineInstr &MI) const {
+ MachineBasicBlock *MBB = MI.getParent();
+ MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
+
+ MachineBasicBlock *TBB, *FBB;
+ SmallVector<MachineOperand, 3> Cond;
+ if (analyzeBranch(*MBB, TBB, FBB, Cond, /*AllowModify=*/false))
+ return false;
+ (void)FBB;
+
+ RISCVCC::CondCode CC = static_cast<RISCVCC::CondCode>(Cond[0].getImm());
+ assert(CC != RISCVCC::COND_INVALID);
+
+ if (CC == RISCVCC::COND_EQ || CC == RISCVCC::COND_NE)
+ return false;
+
+ // For two constants C0 and C1 from
+ // ```
+ // li Y, C0
+ // li Z, C1
+ // ```
+ // 1. if C1 = C0 + 1
+ // we can turn:
+ // (a) blt Y, X -> bge X, Z
+ // (b) bge Y, X -> blt X, Z
+ //
+ // 2. if C1 = C0 - 1
+ // we can turn:
+ // (a) blt X, Y -> bge Z, X
+ // (b) bge X, Y -> blt Z, X
+ //
+ // To make sure this optimization is really beneficial, we only
+ // optimize for cases where Y had only one use (i.e. only used by the branch).
+
+ // Right now we only care about LI (i.e. ADDI rs, 0)
+ auto isLoadImm = [](MachineInstr *MI) -> bool {
+ return MI->getOpcode() == RISCV::ADDI && MI->getOperand(1).isReg() &&
+ MI->getOperand(1).getReg() == RISCV::X0;
+ };
+ // Either a load from immediate instruction or X0.
+ auto isFromLoadImm = [&](const MachineOperand &Op) -> bool {
+ if (!Op.isReg())
+ return false;
+ Register Reg = Op.getReg();
+ if (Reg == RISCV::X0)
+ return true;
+ if (!Reg.isVirtual())
+ return false;
+ return isLoadImm(MRI.getVRegDef(Op.getReg()));
+ };
+
+ MachineOperand &LHS = MI.getOperand(0);
+ MachineOperand &RHS = MI.getOperand(1);
+ auto getConst = [&MRI](MachineOperand &Op) -> int64_t {
+ Register Reg = Op.getReg();
+ if (Reg == RISCV::X0)
+ return 0;
+ assert(Reg.isVirtual());
+ MachineInstr *Def = MRI.getVRegDef(Reg);
+ return Def->getOperand(2).getImm();
+ };
+
+ // Try to find the register for constant Z; return
+ // invalid register otherwise.
+ auto searchConst = [&](int64_t C1) -> Register {
+ MachineInstr *DefC1 = nullptr;
+ MachineBasicBlock::reverse_iterator II(&MI), E = MBB->rend();
+ for (++II; II != E; ++II) {
+ if (isLoadImm(&*II))
+ if (II->getOperand(2).getImm() == C1) {
+ DefC1 = &*II;
+ break;
+ }
+ }
+ if (DefC1)
+ return DefC1->getOperand(0).getReg();
+ else
+ return Register();
+ };
+
+ bool Modify = false;
+ if (isFromLoadImm(LHS) && MRI.hasOneUse(LHS.getReg())) {
+ // Might be case 1.
+ int64_t C0 = getConst(LHS);
+ // Signed integer overflow is UB. (UINT_MAX is bigger so we don't need
+ // to worry about unsigned overflow here)
+ if (C0 < INT_MAX)
+ if (Register RegZ = searchConst(C0 + 1)) {
+ reverseBranchCondition(Cond);
+ Cond[1] = MachineOperand::CreateReg(RHS.getReg(), /*isDef=*/false);
+ Cond[2] = MachineOperand::CreateReg(RegZ, /*isDef=*/false);
+ // We might extend the live range of Z, clear its kill flag to
+ // account for this.
+ MRI.clearKillFlags(RegZ);
+ Modify = true;
+ }
+ } else if (isFromLoadImm(RHS) && MRI.hasOneUse(RHS.getReg())) {
+ // Might be case 2.
+ int64_t C0 = getConst(RHS);
+ // For unsigned cases, we don't want C1 to wrap back to UINT_MAX
+ // when C0 is zero.
+ if ((CC == RISCVCC::COND_GE || CC == RISCVCC::COND_LT) || C0)
+ if (Register RegZ = searchConst(C0 - 1)) {
+ reverseBranchCondition(Cond);
+ Cond[1] = MachineOperand::CreateReg(RegZ, /*isDef=*/false);
+ Cond[2] = MachineOperand::CreateReg(LHS.getReg(), /*isDef=*/false);
+ // We might extend the live range of Z, clear its kill flag to
+ // account for this.
+ MRI.clearKillFlags(RegZ);
+ Modify = true;
+ }
+ }
+
+ if (!Modify)
+ return false;
+
+ // Build the new branch and remove the old one.
+ BuildMI(*MBB, MI, MI.getDebugLoc(),
+ getBrCond(static_cast<RISCVCC::CondCode>(Cond[0].getImm())))
+ .add(Cond[1])
+ .add(Cond[2])
+ .addMBB(TBB);
+ MI.eraseFromParent();
+
+ return true;
+}
+
MachineBasicBlock *
RISCVInstrInfo::getBranchDestBlock(const MachineInstr &MI) const {
assert(MI.getDesc().isBranch() && "Unexpected opcode!");
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index d0112a464677ab5..491278c2e017e7c 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -121,6 +121,8 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
bool
reverseBranchCondition(SmallVectorImpl<MachineOperand> &Cond) const override;
+ bool optimizeCondBranch(MachineInstr &MI) const override;
+
MachineBasicBlock *getBranchDestBlock(const MachineInstr &MI) const override;
bool isBranchOffsetInRange(unsigned BranchOpc,
diff --git a/llvm/test/CodeGen/RISCV/branch-opt.ll b/llvm/test/CodeGen/RISCV/branch-opt.ll
new file mode 100644
index 000000000000000..8ac21452301dd00
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/branch-opt.ll
@@ -0,0 +1,175 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 3
+; RUN: llc -mtriple=riscv32 -O2 -verify-machineinstrs < %s | FileCheck %s --check-prefix=RV32
+; RUN: llc -mtriple=riscv64 -O2 -verify-machineinstrs < %s | FileCheck %s --check-prefix=RV64
+
+define void @u_case1_a(ptr %a, i32 %b, ptr %c, ptr %d) {
+; RV32-LABEL: u_case1_a:
+; RV32: # %bb.0:
+; RV32-NEXT: li a4, 32
+; RV32-NEXT: sw a4, 0(a0)
+; RV32-NEXT: bgeu a1, a4, .LBB0_2
+; RV32-NEXT: # %bb.1: # %block1
+; RV32-NEXT: sw a1, 0(a2)
+; RV32-NEXT: ret
+; RV32-NEXT: .LBB0_2: # %block2
+; RV32-NEXT: li a0, 87
+; RV32-NEXT: sw a0, 0(a3)
+; RV32-NEXT: ret
+;
+; RV64-LABEL: u_case1_a:
+; RV64: # %bb.0:
+; RV64-NEXT: sext.w a4, a1
+; RV64-NEXT: li a5, 32
+; RV64-NEXT: sw a5, 0(a0)
+; RV64-NEXT: bgeu a4, a5, .LBB0_2
+; RV64-NEXT: # %bb.1: # %block1
+; RV64-NEXT: sw a1, 0(a2)
+; RV64-NEXT: ret
+; RV64-NEXT: .LBB0_2: # %block2
+; RV64-NEXT: li a0, 87
+; RV64-NEXT: sw a0, 0(a3)
+; RV64-NEXT: ret
+ store i32 32, ptr %a
+ %p = icmp ule i32 %b, 31
+ br i1 %p, label %block1, label %block2
+
+block1: ; preds = %0
+ store i32 %b, ptr %c
+ br label %end_block
+
+block2: ; preds = %0
+ store i32 87, ptr %d
+ br label %end_block
+
+end_block: ; preds = %block2, %block1
+ ret void
+}
+
+define void @case1_a(ptr %a, i32 %b, ptr %c, ptr %d) {
+; RV32-LABEL: case1_a:
+; RV32: # %bb.0:
+; RV32-NEXT: li a4, -1
+; RV32-NEXT: sw a4, 0(a0)
+; RV32-NEXT: bge a1, a4, .LBB1_2
+; RV32-NEXT: # %bb.1: # %block1
+; RV32-NEXT: sw a1, 0(a2)
+; RV32-NEXT: ret
+; RV32-NEXT: .LBB1_2: # %block2
+; RV32-NEXT: li a0, 87
+; RV32-NEXT: sw a0, 0(a3)
+; RV32-NEXT: ret
+;
+; RV64-LABEL: case1_a:
+; RV64: # %bb.0:
+; RV64-NEXT: sext.w a4, a1
+; RV64-NEXT: li a5, -1
+; RV64-NEXT: sw a5, 0(a0)
+; RV64-NEXT: bge a4, a5, .LBB1_2
+; RV64-NEXT: # %bb.1: # %block1
+; RV64-NEXT: sw a1, 0(a2)
+; RV64-NEXT: ret
+; RV64-NEXT: .LBB1_2: # %block2
+; RV64-NEXT: li a0, 87
+; RV64-NEXT: sw a0, 0(a3)
+; RV64-NEXT: ret
+ store i32 -1, ptr %a
+ %p = icmp sle i32 %b, -2
+ br i1 %p, label %block1, label %block2
+
+block1: ; preds = %0
+ store i32 %b, ptr %c
+ br label %end_block
+
+block2: ; preds = %0
+ store i32 87, ptr %d
+ br label %end_block
+
+end_block: ; preds = %block2, %block1
+ ret void
+}
+
+define void @u_case2_a(ptr %a, i32 %b, ptr %c, ptr %d) {
+; RV32-LABEL: u_case2_a:
+; RV32: # %bb.0:
+; RV32-NEXT: li a4, 32
+; RV32-NEXT: sw a4, 0(a0)
+; RV32-NEXT: bgeu a4, a1, .LBB2_2
+; RV32-NEXT: # %bb.1: # %block1
+; RV32-NEXT: sw a1, 0(a2)
+; RV32-NEXT: ret
+; RV32-NEXT: .LBB2_2: # %block2
+; RV32-NEXT: li a0, 87
+; RV32-NEXT: sw a0, 0(a3)
+; RV32-NEXT: ret
+;
+; RV64-LABEL: u_case2_a:
+; RV64: # %bb.0:
+; RV64-NEXT: sext.w a4, a1
+; RV64-NEXT: li a5, 32
+; RV64-NEXT: sw a5, 0(a0)
+; RV64-NEXT: bgeu a5, a4, .LBB2_2
+; RV64-NEXT: # %bb.1: # %block1
+; RV64-NEXT: sw a1, 0(a2)
+; RV64-NEXT: ret
+; RV64-NEXT: .LBB2_2: # %block2
+; RV64-NEXT: li a0, 87
+; RV64-NEXT: sw a0, 0(a3)
+; RV64-NEXT: ret
+ store i32 32, ptr %a
+ %p = icmp uge i32 %b, 33
+ br i1 %p, label %block1, label %block2
+
+block1: ; preds = %0
+ store i32 %b, ptr %c
+ br label %end_block
+
+block2: ; preds = %0
+ store i32 87, ptr %d
+ br label %end_block
+
+end_block: ; preds = %block2, %block1
+ ret void
+}
+
+define void @case2_a(ptr %a, i32 %b, ptr %c, ptr %d) {
+; RV32-LABEL: case2_a:
+; RV32: # %bb.0:
+; RV32-NEXT: li a4, -4
+; RV32-NEXT: sw a4, 0(a0)
+; RV32-NEXT: bge a4, a1, .LBB3_2
+; RV32-NEXT: # %bb.1: # %block1
+; RV32-NEXT: sw a1, 0(a2)
+; RV32-NEXT: ret
+; RV32-NEXT: .LBB3_2: # %block2
+; RV32-NEXT: li a0, 87
+; RV32-NEXT: sw a0, 0(a3)
+; RV32-NEXT: ret
+;
+; RV64-LABEL: case2_a:
+; RV64: # %bb.0:
+; RV64-NEXT: sext.w a4, a1
+; RV64-NEXT: li a5, -4
+; RV64-NEXT: sw a5, 0(a0)
+; RV64-NEXT: bge a5, a4, .LBB3_2
+; RV64-NEXT: # %bb.1: # %block1
+; RV64-NEXT: sw a1, 0(a2)
+; RV64-NEXT: ret
+; RV64-NEXT: .LBB3_2: # %block2
+; RV64-NEXT: li a0, 87
+; RV64-NEXT: sw a0, 0(a3)
+; RV64-NEXT: ret
+ store i32 -4, ptr %a
+ %p = icmp sge i32 %b, -3
+ br i1 %p, label %block1, label %block2
+
+block1: ; preds = %0
+ store i32 %b, ptr %c
+ br label %end_block
+
+block2: ; preds = %0
+ store i32 87, ptr %d
+ br label %end_block
+
+end_block: ; preds = %block2, %block1
+ ret void
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/71111
More information about the llvm-commits
mailing list