[llvm] 1178992 - [RISCV] Optimize 2x SELECT for floating-point types

via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 27 21:02:12 PDT 2022


Author: LiaoChunyu
Date: 2022-06-28T12:02:05+08:00
New Revision: 1178992c72b002c3b2c87203252c566eeb273cc1

URL: https://github.com/llvm/llvm-project/commit/1178992c72b002c3b2c87203252c566eeb273cc1
DIFF: https://github.com/llvm/llvm-project/commit/1178992c72b002c3b2c87203252c566eeb273cc1.diff

LOG: [RISCV] Optimize 2x SELECT for floating-point types

This covers the following opcodes:
 Select_FPR16_Using_CC_GPR
 Select_FPR32_Using_CC_GPR
 Select_FPR64_Using_CC_GPR

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D127871

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/select-optimize-multiple.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 4fa3c7603400..72b1d4ce82d7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -9708,6 +9708,109 @@ static MachineBasicBlock *emitQuietFCMP(MachineInstr &MI, MachineBasicBlock *BB,
   return BB;
 }
 
+static MachineBasicBlock *
+EmitLoweredCascadedSelect(MachineInstr &First, MachineInstr &Second,
+                          MachineBasicBlock *ThisMBB,
+                          const RISCVSubtarget &Subtarget) {
+  // Lowers Second = Select(cc2, T2, First) where First = Select(cc1, T1, F1).
+  // Lowering the two pseudos one at a time, the custom inserter would emit:
+  //
+  //   A
+  //   | \
+  //   |  B
+  //   | /
+  //   C
+  //   | \
+  //   |  D
+  //   | /
+  //   E
+  //
+  // A: branch to C if cc1
+  // B: empty
+  // C: First = PHI [T1, A], [F1, B]; branch to E if cc2
+  // D: empty
+  // E: Second = PHI [T2, C], [First, D]
+  //
+  // If we lower both Select_FPRX_ in a single step, we can instead generate:
+  //
+  //   A
+  //   | \
+  //   |  C
+  //   | /|
+  //   |/ |
+  //   |  |
+  //   |  D
+  //   | /
+  //   E
+  //
+  // A: branch to E if cc2 (outer condition first: it wins even if cc1 holds)
+  // C: branch to E if cc1; D: empty
+  // E: Second = PHI [T2, A], [T1, C], [F1, D]
+
+  const RISCVInstrInfo &TII = *Subtarget.getInstrInfo();
+  const DebugLoc &DL = First.getDebugLoc();
+  const BasicBlock *LLVM_BB = ThisMBB->getBasicBlock();
+  MachineFunction *F = ThisMBB->getParent();
+  MachineBasicBlock *FirstMBB = F->CreateMachineBasicBlock(LLVM_BB);
+  MachineBasicBlock *SecondMBB = F->CreateMachineBasicBlock(LLVM_BB);
+  MachineBasicBlock *SinkMBB = F->CreateMachineBasicBlock(LLVM_BB);
+  MachineFunction::iterator It = ++ThisMBB->getIterator();
+  F->insert(It, FirstMBB);
+  F->insert(It, SecondMBB);
+  F->insert(It, SinkMBB);
+
+  // Transfer the remainder of ThisMBB and its successor edges to SinkMBB.
+  SinkMBB->splice(SinkMBB->begin(), ThisMBB,
+                  std::next(MachineBasicBlock::iterator(First)),
+                  ThisMBB->end());
+  SinkMBB->transferSuccessorsAndUpdatePHIs(ThisMBB);
+
+  // ThisMBB (tests cc2) falls through to FirstMBB or branches to SinkMBB.
+  ThisMBB->addSuccessor(FirstMBB);
+  // FirstMBB (tests cc1) falls through to SecondMBB or branches to SinkMBB.
+  FirstMBB->addSuccessor(SecondMBB);
+  ThisMBB->addSuccessor(SinkMBB);
+  FirstMBB->addSuccessor(SinkMBB);
+  // SecondMBB is empty and always falls through to SinkMBB.
+  SecondMBB->addSuccessor(SinkMBB);
+
+  auto SecondCC = static_cast<RISCVCC::CondCode>(Second.getOperand(3).getImm());
+  Register SLHS = Second.getOperand(1).getReg();
+  Register SRHS = Second.getOperand(2).getReg();
+  // Test the outer select's condition first: when cc2 holds, T2 is the result.
+  BuildMI(ThisMBB, DL, TII.getBrCond(SecondCC))
+      .addReg(SLHS)
+      .addReg(SRHS)
+      .addMBB(SinkMBB);
+
+  auto FirstCC = static_cast<RISCVCC::CondCode>(First.getOperand(3).getImm());
+  Register FLHS = First.getOperand(1).getReg();
+  Register FRHS = First.getOperand(2).getReg();
+  Register Op1Reg4 = First.getOperand(4).getReg();
+  Register Op1Reg5 = First.getOperand(5).getReg();
+  Register Op2Reg4 = Second.getOperand(4).getReg();
+  // Otherwise test the inner select's condition: when cc1 holds, pick T1.
+  BuildMI(FirstMBB, DL, TII.getBrCond(FirstCC))
+      .addReg(FLHS)
+      .addReg(FRHS)
+      .addMBB(SinkMBB);
+
+  Register DestReg = Second.getOperand(0).getReg();
+  // Merge: T2 from ThisMBB, T1 from FirstMBB, F1 from SecondMBB.
+  BuildMI(*SinkMBB, SinkMBB->begin(), DL, TII.get(RISCV::PHI), DestReg)
+      .addReg(Op2Reg4)
+      .addMBB(ThisMBB)
+      .addReg(Op1Reg4)
+      .addMBB(FirstMBB)
+      .addReg(Op1Reg5)
+      .addMBB(SecondMBB);
+
+  // Now remove the Select_FPRX_s.
+  First.eraseFromParent();
+  Second.eraseFromParent();
+  return SinkMBB;
+}
+
 static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI,
                                            MachineBasicBlock *BB,
                                            const RISCVSubtarget &Subtarget) {
@@ -9735,6 +9838,10 @@ static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI,
   // previous selects in the sequence.
   // These conditions could be further relaxed. See the X86 target for a
   // related approach and more information.
+  //
+  // Select_FPRX_ (rs1, rs2, imm, rs4, (Select_FPRX_ rs1, rs2, imm, rs4, rs5))
+  // is checked here and handled by a separate function -
+  // EmitLoweredCascadedSelect.
   Register LHS = MI.getOperand(1).getReg();
   Register RHS = MI.getOperand(2).getReg();
   auto CC = static_cast<RISCVCC::CondCode>(MI.getOperand(3).getImm());
@@ -9744,6 +9851,13 @@ static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI,
   SelectDests.insert(MI.getOperand(0).getReg());
 
   MachineInstr *LastSelectPseudo = &MI;
+  auto Next = next_nodbg(MI.getIterator(), BB->instr_end());
+  if (MI.getOpcode() != RISCV::Select_GPR_Using_CC_GPR && Next != BB->end() &&
+      Next->getOpcode() == MI.getOpcode() &&
+      Next->getOperand(5).getReg() == MI.getOperand(0).getReg() &&
+      Next->getOperand(5).isKill()) {
+    return EmitLoweredCascadedSelect(MI, *Next, BB, Subtarget);
+  }
 
   for (auto E = BB->end(), SequenceMBBI = MachineBasicBlock::iterator(MI);
        SequenceMBBI != E; ++SequenceMBBI) {

diff  --git a/llvm/test/CodeGen/RISCV/select-optimize-multiple.ll b/llvm/test/CodeGen/RISCV/select-optimize-multiple.ll
index 491b28ef0851..72c668448405 100644
--- a/llvm/test/CodeGen/RISCV/select-optimize-multiple.ll
+++ b/llvm/test/CodeGen/RISCV/select-optimize-multiple.ll
@@ -533,3 +533,79 @@ entry:
   %ret = add i32 %cond1, %cond2
   ret i32 %ret
 }
+
+define float @CascadedSelect(float noundef %a) {
+; RV32I-LABEL: CascadedSelect:
+; RV32I:       # %bb.0: # %entry
+; RV32I-NEXT:    fmv.w.x ft0, a0
+; RV32I-NEXT:    fmv.w.x ft1, zero
+; RV32I-NEXT:    flt.s a0, ft0, ft1
+; RV32I-NEXT:    bnez a0, .LBB8_3
+; RV32I-NEXT:  # %bb.1: # %entry
+; RV32I-NEXT:    lui a0, %hi(.LCPI8_0)
+; RV32I-NEXT:    flw ft1, %lo(.LCPI8_0)(a0)
+; RV32I-NEXT:    flt.s a0, ft1, ft0
+; RV32I-NEXT:    bnez a0, .LBB8_3
+; RV32I-NEXT:  # %bb.2: # %entry
+; RV32I-NEXT:    fmv.s ft1, ft0
+; RV32I-NEXT:  .LBB8_3: # %entry
+; RV32I-NEXT:    fmv.x.w a0, ft1
+; RV32I-NEXT:    ret
+;
+; RV32IBT-LABEL: CascadedSelect:
+; RV32IBT:       # %bb.0: # %entry
+; RV32IBT-NEXT:    fmv.w.x ft0, a0
+; RV32IBT-NEXT:    fmv.w.x ft1, zero
+; RV32IBT-NEXT:    flt.s a0, ft0, ft1
+; RV32IBT-NEXT:    bnez a0, .LBB8_3
+; RV32IBT-NEXT:  # %bb.1: # %entry
+; RV32IBT-NEXT:    lui a0, %hi(.LCPI8_0)
+; RV32IBT-NEXT:    flw ft1, %lo(.LCPI8_0)(a0)
+; RV32IBT-NEXT:    flt.s a0, ft1, ft0
+; RV32IBT-NEXT:    bnez a0, .LBB8_3
+; RV32IBT-NEXT:  # %bb.2: # %entry
+; RV32IBT-NEXT:    fmv.s ft1, ft0
+; RV32IBT-NEXT:  .LBB8_3: # %entry
+; RV32IBT-NEXT:    fmv.x.w a0, ft1
+; RV32IBT-NEXT:    ret
+;
+; RV64I-LABEL: CascadedSelect:
+; RV64I:       # %bb.0: # %entry
+; RV64I-NEXT:    fmv.w.x ft0, a0
+; RV64I-NEXT:    fmv.w.x ft1, zero
+; RV64I-NEXT:    flt.s a0, ft0, ft1
+; RV64I-NEXT:    bnez a0, .LBB8_3
+; RV64I-NEXT:  # %bb.1: # %entry
+; RV64I-NEXT:    lui a0, %hi(.LCPI8_0)
+; RV64I-NEXT:    flw ft1, %lo(.LCPI8_0)(a0)
+; RV64I-NEXT:    flt.s a0, ft1, ft0
+; RV64I-NEXT:    bnez a0, .LBB8_3
+; RV64I-NEXT:  # %bb.2: # %entry
+; RV64I-NEXT:    fmv.s ft1, ft0
+; RV64I-NEXT:  .LBB8_3: # %entry
+; RV64I-NEXT:    fmv.x.w a0, ft1
+; RV64I-NEXT:    ret
+;
+; RV64IBT-LABEL: CascadedSelect:
+; RV64IBT:       # %bb.0: # %entry
+; RV64IBT-NEXT:    fmv.w.x ft0, a0
+; RV64IBT-NEXT:    fmv.w.x ft1, zero
+; RV64IBT-NEXT:    flt.s a0, ft0, ft1
+; RV64IBT-NEXT:    bnez a0, .LBB8_3
+; RV64IBT-NEXT:  # %bb.1: # %entry
+; RV64IBT-NEXT:    lui a0, %hi(.LCPI8_0)
+; RV64IBT-NEXT:    flw ft1, %lo(.LCPI8_0)(a0)
+; RV64IBT-NEXT:    flt.s a0, ft1, ft0
+; RV64IBT-NEXT:    bnez a0, .LBB8_3
+; RV64IBT-NEXT:  # %bb.2: # %entry
+; RV64IBT-NEXT:    fmv.s ft1, ft0
+; RV64IBT-NEXT:  .LBB8_3: # %entry
+; RV64IBT-NEXT:    fmv.x.w a0, ft1
+; RV64IBT-NEXT:    ret
+entry: ; NOTE(review): a > 1 and a < 0 never hold together, so branch order is untested
+  %cmp = fcmp ogt float %a, 1.000000e+00 ; outer select's condition (a > 1)
+  %cmp1 = fcmp olt float %a, 0.000000e+00 ; inner select's condition (a < 0)
+  %.a = select i1 %cmp1, float 0.000000e+00, float %a ; inner select
+  %retval.0 = select i1 %cmp, float 1.000000e+00, float %.a ; outer select, false operand is %.a
+  ret float %retval.0
+}


        


More information about the llvm-commits mailing list