[llvm] [RISCV] CSE by swapping conditional branches (PR #71111)

Min-Yih Hsu via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 2 14:36:34 PDT 2023


https://github.com/mshockwave updated https://github.com/llvm/llvm-project/pull/71111

>From 9cb13ef65240fcf0c5bb34f6c90ae9a6901aba45 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Thu, 2 Nov 2023 13:57:51 -0700
Subject: [PATCH 1/2] [RISCV] CSE by swapping conditional branches

This patch performs the following optimization:

For two constants C0 and C1 from
```
li Y, C0
li Z, C1
```
To remove the redundant `li Y, C0`,
 1. if C1 = C0 + 1
 we can turn:
  (a) blt Y, X -> bge X, Z
  (b) bge Y, X -> blt X, Z

 2. if C1 = C0 - 1
 we can turn:
  (a) blt X, Y -> bge Z, X
  (b) bge X, Y -> blt Z, X

This optimization will be done by PeepholeOptimizer through
RISCVInstrInfo::optimizeCondBranch.
---
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp | 127 ++++++++++++++++
 llvm/lib/Target/RISCV/RISCVInstrInfo.h   |   2 +
 llvm/test/CodeGen/RISCV/branch-opt.ll    | 175 +++++++++++++++++++++++
 3 files changed, 304 insertions(+)
 create mode 100644 llvm/test/CodeGen/RISCV/branch-opt.ll

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 412fb7e7f7fc16c..9a3fb7fe39713c7 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1159,6 +1159,133 @@ bool RISCVInstrInfo::reverseBranchCondition(
   return false;
 }
 
+bool RISCVInstrInfo::optimizeCondBranch(MachineInstr &MI) const {
+  MachineBasicBlock *MBB = MI.getParent();
+  MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
+
+  MachineBasicBlock *TBB, *FBB;
+  SmallVector<MachineOperand, 3> Cond;
+  if (analyzeBranch(*MBB, TBB, FBB, Cond, /*AllowModify=*/false))
+    return false;
+  (void)FBB;
+
+  RISCVCC::CondCode CC = static_cast<RISCVCC::CondCode>(Cond[0].getImm());
+  assert(CC != RISCVCC::COND_INVALID);
+
+  if (CC == RISCVCC::COND_EQ || CC == RISCVCC::COND_NE)
+    return false;
+
+  // For two constants C0 and C1 from
+  // ```
+  // li Y, C0
+  // li Z, C1
+  // ```
+  // 1. if C1 = C0 + 1
+  // we can turn:
+  //  (a) blt Y, X -> bge X, Z
+  //  (b) bge Y, X -> blt X, Z
+  //
+  // 2. if C1 = C0 - 1
+  // we can turn:
+  //  (a) blt X, Y -> bge Z, X
+  //  (b) bge X, Y -> blt Z, X
+  //
+  // To make sure this optimization is really beneficial, we only
+  // optimize for cases where Y had only one use (i.e. only used by the branch).
+
+  // Right now we only care about LI (i.e. ADDI rs, 0)
+  auto isLoadImm = [](MachineInstr *MI) -> bool {
+    return MI->getOpcode() == RISCV::ADDI && MI->getOperand(1).isReg() &&
+           MI->getOperand(1).getReg() == RISCV::X0;
+  };
+  // Either a load from immediate instruction or X0.
+  auto isFromLoadImm = [&](const MachineOperand &Op) -> bool {
+    if (!Op.isReg())
+      return false;
+    Register Reg = Op.getReg();
+    if (Reg == RISCV::X0)
+      return true;
+    if (!Reg.isVirtual())
+      return false;
+    return isLoadImm(MRI.getVRegDef(Op.getReg()));
+  };
+
+  MachineOperand &LHS = MI.getOperand(0);
+  MachineOperand &RHS = MI.getOperand(1);
+  auto getConst = [&MRI](MachineOperand &Op) -> int64_t {
+    Register Reg = Op.getReg();
+    if (Reg == RISCV::X0)
+      return 0;
+    assert(Reg.isVirtual());
+    MachineInstr *Def = MRI.getVRegDef(Reg);
+    return Def->getOperand(2).getImm();
+  };
+
+  // Try to find the register for constant Z; return
+  // invalid register otherwise.
+  auto searchConst = [&](int64_t C1) -> Register {
+    MachineInstr *DefC1 = nullptr;
+    MachineBasicBlock::reverse_iterator II(&MI), E = MBB->rend();
+    for (++II; II != E; ++II) {
+      if (isLoadImm(&*II))
+        if (II->getOperand(2).getImm() == C1) {
+          DefC1 = &*II;
+          break;
+        }
+    }
+    if (DefC1)
+      return DefC1->getOperand(0).getReg();
+    else
+      return Register();
+  };
+
+  bool Modify = false;
+  if (isFromLoadImm(LHS) && MRI.hasOneUse(LHS.getReg())) {
+    // Might be case 1.
+    int64_t C0 = getConst(LHS);
+    // Signed integer overflow is UB. (UINT_MAX is bigger so we don't need
+    // to worry about unsigned overflow here)
+    if (C0 < INT_MAX)
+      if (Register RegZ = searchConst(C0 + 1)) {
+        reverseBranchCondition(Cond);
+        Cond[1] = MachineOperand::CreateReg(RHS.getReg(), /*isDef=*/false);
+        Cond[2] = MachineOperand::CreateReg(RegZ, /*isDef=*/false);
+        // We might extend the live range of Z, clear its kill flag to
+        // account for this.
+        MRI.clearKillFlags(RegZ);
+        Modify = true;
+      }
+  } else if (isFromLoadImm(RHS) && MRI.hasOneUse(RHS.getReg())) {
+    // Might be case 2.
+    int64_t C0 = getConst(RHS);
+    // For unsigned cases, we don't want C1 to wrap back to UINT_MAX
+    // when C0 is zero.
+    if ((CC == RISCVCC::COND_GE || CC == RISCVCC::COND_LT) || C0)
+      if (Register RegZ = searchConst(C0 - 1)) {
+        reverseBranchCondition(Cond);
+        Cond[1] = MachineOperand::CreateReg(RegZ, /*isDef=*/false);
+        Cond[2] = MachineOperand::CreateReg(LHS.getReg(), /*isDef=*/false);
+        // We might extend the live range of Z, clear its kill flag to
+        // account for this.
+        MRI.clearKillFlags(RegZ);
+        Modify = true;
+      }
+  }
+
+  if (!Modify)
+    return false;
+
+  // Build the new branch and remove the old one.
+  BuildMI(*MBB, MI, MI.getDebugLoc(),
+          getBrCond(static_cast<RISCVCC::CondCode>(Cond[0].getImm())))
+      .add(Cond[1])
+      .add(Cond[2])
+      .addMBB(TBB);
+  MI.eraseFromParent();
+
+  return true;
+}
+
 MachineBasicBlock *
 RISCVInstrInfo::getBranchDestBlock(const MachineInstr &MI) const {
   assert(MI.getDesc().isBranch() && "Unexpected opcode!");
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index d0112a464677ab5..491278c2e017e7c 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -121,6 +121,8 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
   bool
   reverseBranchCondition(SmallVectorImpl<MachineOperand> &Cond) const override;
 
+  bool optimizeCondBranch(MachineInstr &MI) const override;
+
   MachineBasicBlock *getBranchDestBlock(const MachineInstr &MI) const override;
 
   bool isBranchOffsetInRange(unsigned BranchOpc,
diff --git a/llvm/test/CodeGen/RISCV/branch-opt.ll b/llvm/test/CodeGen/RISCV/branch-opt.ll
new file mode 100644
index 000000000000000..8ac21452301dd00
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/branch-opt.ll
@@ -0,0 +1,175 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 3
+; RUN: llc -mtriple=riscv32 -O2 -verify-machineinstrs < %s | FileCheck %s --check-prefix=RV32
+; RUN: llc -mtriple=riscv64 -O2 -verify-machineinstrs < %s | FileCheck %s --check-prefix=RV64
+
+define void @u_case1_a(ptr %a, i32 %b, ptr %c, ptr %d) {
+; RV32-LABEL: u_case1_a:
+; RV32:       # %bb.0:
+; RV32-NEXT:    li a4, 32
+; RV32-NEXT:    sw a4, 0(a0)
+; RV32-NEXT:    bgeu a1, a4, .LBB0_2
+; RV32-NEXT:  # %bb.1: # %block1
+; RV32-NEXT:    sw a1, 0(a2)
+; RV32-NEXT:    ret
+; RV32-NEXT:  .LBB0_2: # %block2
+; RV32-NEXT:    li a0, 87
+; RV32-NEXT:    sw a0, 0(a3)
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: u_case1_a:
+; RV64:       # %bb.0:
+; RV64-NEXT:    sext.w a4, a1
+; RV64-NEXT:    li a5, 32
+; RV64-NEXT:    sw a5, 0(a0)
+; RV64-NEXT:    bgeu a4, a5, .LBB0_2
+; RV64-NEXT:  # %bb.1: # %block1
+; RV64-NEXT:    sw a1, 0(a2)
+; RV64-NEXT:    ret
+; RV64-NEXT:  .LBB0_2: # %block2
+; RV64-NEXT:    li a0, 87
+; RV64-NEXT:    sw a0, 0(a3)
+; RV64-NEXT:    ret
+  store i32 32, ptr %a
+  %p = icmp ule i32 %b, 31
+  br i1 %p, label %block1, label %block2
+
+block1:                                           ; preds = %0
+  store i32 %b, ptr %c
+  br label %end_block
+
+block2:                                           ; preds = %0
+  store i32 87, ptr %d
+  br label %end_block
+
+end_block:                                        ; preds = %block2, %block1
+  ret void
+}
+
+define void @case1_a(ptr %a, i32 %b, ptr %c, ptr %d) {
+; RV32-LABEL: case1_a:
+; RV32:       # %bb.0:
+; RV32-NEXT:    li a4, -1
+; RV32-NEXT:    sw a4, 0(a0)
+; RV32-NEXT:    bge a1, a4, .LBB1_2
+; RV32-NEXT:  # %bb.1: # %block1
+; RV32-NEXT:    sw a1, 0(a2)
+; RV32-NEXT:    ret
+; RV32-NEXT:  .LBB1_2: # %block2
+; RV32-NEXT:    li a0, 87
+; RV32-NEXT:    sw a0, 0(a3)
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: case1_a:
+; RV64:       # %bb.0:
+; RV64-NEXT:    sext.w a4, a1
+; RV64-NEXT:    li a5, -1
+; RV64-NEXT:    sw a5, 0(a0)
+; RV64-NEXT:    bge a4, a5, .LBB1_2
+; RV64-NEXT:  # %bb.1: # %block1
+; RV64-NEXT:    sw a1, 0(a2)
+; RV64-NEXT:    ret
+; RV64-NEXT:  .LBB1_2: # %block2
+; RV64-NEXT:    li a0, 87
+; RV64-NEXT:    sw a0, 0(a3)
+; RV64-NEXT:    ret
+  store i32 -1, ptr %a
+  %p = icmp sle i32 %b, -2
+  br i1 %p, label %block1, label %block2
+
+block1:                                           ; preds = %0
+  store i32 %b, ptr %c
+  br label %end_block
+
+block2:                                           ; preds = %0
+  store i32 87, ptr %d
+  br label %end_block
+
+end_block:                                        ; preds = %block2, %block1
+  ret void
+}
+
+define void @u_case2_a(ptr %a, i32 %b, ptr %c, ptr %d) {
+; RV32-LABEL: u_case2_a:
+; RV32:       # %bb.0:
+; RV32-NEXT:    li a4, 32
+; RV32-NEXT:    sw a4, 0(a0)
+; RV32-NEXT:    bgeu a4, a1, .LBB2_2
+; RV32-NEXT:  # %bb.1: # %block1
+; RV32-NEXT:    sw a1, 0(a2)
+; RV32-NEXT:    ret
+; RV32-NEXT:  .LBB2_2: # %block2
+; RV32-NEXT:    li a0, 87
+; RV32-NEXT:    sw a0, 0(a3)
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: u_case2_a:
+; RV64:       # %bb.0:
+; RV64-NEXT:    sext.w a4, a1
+; RV64-NEXT:    li a5, 32
+; RV64-NEXT:    sw a5, 0(a0)
+; RV64-NEXT:    bgeu a5, a4, .LBB2_2
+; RV64-NEXT:  # %bb.1: # %block1
+; RV64-NEXT:    sw a1, 0(a2)
+; RV64-NEXT:    ret
+; RV64-NEXT:  .LBB2_2: # %block2
+; RV64-NEXT:    li a0, 87
+; RV64-NEXT:    sw a0, 0(a3)
+; RV64-NEXT:    ret
+  store i32 32, ptr %a
+  %p = icmp uge i32 %b, 33
+  br i1 %p, label %block1, label %block2
+
+block1:                                           ; preds = %0
+  store i32 %b, ptr %c
+  br label %end_block
+
+block2:                                           ; preds = %0
+  store i32 87, ptr %d
+  br label %end_block
+
+end_block:                                        ; preds = %block2, %block1
+  ret void
+}
+
+define void @case2_a(ptr %a, i32 %b, ptr %c, ptr %d) {
+; RV32-LABEL: case2_a:
+; RV32:       # %bb.0:
+; RV32-NEXT:    li a4, -4
+; RV32-NEXT:    sw a4, 0(a0)
+; RV32-NEXT:    bge a4, a1, .LBB3_2
+; RV32-NEXT:  # %bb.1: # %block1
+; RV32-NEXT:    sw a1, 0(a2)
+; RV32-NEXT:    ret
+; RV32-NEXT:  .LBB3_2: # %block2
+; RV32-NEXT:    li a0, 87
+; RV32-NEXT:    sw a0, 0(a3)
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: case2_a:
+; RV64:       # %bb.0:
+; RV64-NEXT:    sext.w a4, a1
+; RV64-NEXT:    li a5, -4
+; RV64-NEXT:    sw a5, 0(a0)
+; RV64-NEXT:    bge a5, a4, .LBB3_2
+; RV64-NEXT:  # %bb.1: # %block1
+; RV64-NEXT:    sw a1, 0(a2)
+; RV64-NEXT:    ret
+; RV64-NEXT:  .LBB3_2: # %block2
+; RV64-NEXT:    li a0, 87
+; RV64-NEXT:    sw a0, 0(a3)
+; RV64-NEXT:    ret
+  store i32 -4, ptr %a
+  %p = icmp sge i32 %b, -3
+  br i1 %p, label %block1, label %block2
+
+block1:                                           ; preds = %0
+  store i32 %b, ptr %c
+  br label %end_block
+
+block2:                                           ; preds = %0
+  store i32 87, ptr %d
+  br label %end_block
+
+end_block:                                        ; preds = %block2, %block1
+  ret void
+}

>From 214fb857d7bc1b25e9427d93813ada69c5d33231 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Thu, 2 Nov 2023 14:35:36 -0700
Subject: [PATCH 2/2] fixup! [RISCV] CSE by swapping conditional branches

---
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp | 21 ++++++++-------------
 1 file changed, 8 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 9a3fb7fe39713c7..affc13e9a0684eb 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1194,7 +1194,7 @@ bool RISCVInstrInfo::optimizeCondBranch(MachineInstr &MI) const {
   // optimize for cases where Y had only one use (i.e. only used by the branch).
 
   // Right now we only care about LI (i.e. ADDI rs, 0)
-  auto isLoadImm = [](MachineInstr *MI) -> bool {
+  auto isLoadImm = [](const MachineInstr *MI) -> bool {
     return MI->getOpcode() == RISCV::ADDI && MI->getOperand(1).isReg() &&
            MI->getOperand(1).getReg() == RISCV::X0;
   };
@@ -1224,16 +1224,11 @@ bool RISCVInstrInfo::optimizeCondBranch(MachineInstr &MI) const {
   // Try to find the register for constant Z; return
   // invalid register otherwise.
   auto searchConst = [&](int64_t C1) -> Register {
-    MachineInstr *DefC1 = nullptr;
     MachineBasicBlock::reverse_iterator II(&MI), E = MBB->rend();
-    for (++II; II != E; ++II) {
-      if (isLoadImm(&*II))
-        if (II->getOperand(2).getImm() == C1) {
-          DefC1 = &*II;
-          break;
-        }
-    }
-    if (DefC1)
+    auto DefC1 = std::find_if(++II, E, [&](const MachineInstr &I) -> bool {
+      return isLoadImm(&I) && I.getOperand(2).getImm() == C1;
+    });
+    if (DefC1 != E)
       return DefC1->getOperand(0).getReg();
     else
       return Register();
@@ -1243,9 +1238,9 @@ bool RISCVInstrInfo::optimizeCondBranch(MachineInstr &MI) const {
   if (isFromLoadImm(LHS) && MRI.hasOneUse(LHS.getReg())) {
     // Might be case 1.
     int64_t C0 = getConst(LHS);
-    // Signed integer overflow is UB. (UINT_MAX is bigger so we don't need
+    // Signed integer overflow is UB. (UINT64_MAX is bigger so we don't need
     // to worry about unsigned overflow here)
-    if (C0 < INT_MAX)
+    if (C0 < INT64_MAX)
       if (Register RegZ = searchConst(C0 + 1)) {
         reverseBranchCondition(Cond);
         Cond[1] = MachineOperand::CreateReg(RHS.getReg(), /*isDef=*/false);
@@ -1258,7 +1253,7 @@ bool RISCVInstrInfo::optimizeCondBranch(MachineInstr &MI) const {
   } else if (isFromLoadImm(RHS) && MRI.hasOneUse(RHS.getReg())) {
     // Might be case 2.
     int64_t C0 = getConst(RHS);
-    // For unsigned cases, we don't want C1 to wrap back to UINT_MAX
+    // For unsigned cases, we don't want C1 to wrap back to UINT64_MAX
     // when C0 is zero.
     if ((CC == RISCVCC::COND_GE || CC == RISCVCC::COND_LT) || C0)
       if (Register RegZ = searchConst(C0 - 1)) {



More information about the llvm-commits mailing list