[llvm] 3f68f0f - [RISCV] Optimize 2x SELECT for floating-point types
via llvm-commits
llvm-commits at lists.llvm.org
Sun Jul 10 23:10:39 PDT 2022
Author: LiaoChunyu
Date: 2022-07-11T14:10:27+08:00
New Revision: 3f68f0f8160e4f17528669978baf1471073b90d3
URL: https://github.com/llvm/llvm-project/commit/3f68f0f8160e4f17528669978baf1471073b90d3
DIFF: https://github.com/llvm/llvm-project/commit/3f68f0f8160e4f17528669978baf1471073b90d3.diff
LOG: [RISCV] Optimize 2x SELECT for floating-point types
This handles the following opcodes:
Select_FPR16_Using_CC_GPR
Select_FPR32_Using_CC_GPR
Select_FPR64_Using_CC_GPR
Reviewed By: craig.topper
Differential Revision: https://reviews.llvm.org/D127871
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/test/CodeGen/RISCV/select-optimize-multiple.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 69b2b0d5ade9..9c16da946e84 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -9822,6 +9822,109 @@ static MachineBasicBlock *emitQuietFCMP(MachineInstr &MI, MachineBasicBlock *BB,
return BB;
}
+// Lower two adjacent select pseudos of the same non-GPR (FP) opcode in one
+// step, where the second select takes the result of the first as its false
+// operand:
+//
+//   First:  X = Select_FPRX_ (rs1, rs2, cc1, freg1, freg2)
+//   Second: Y = Select_FPRX_ (rs3, rs4, cc2, freg3, X)
+//
+// i.e. Y = (rs3 cc2 rs4) ? freg3 : ((rs1 cc1 rs2) ? freg1 : freg2).
+// The two conditions and true operands are read independently; only the
+// X -> Second link matters. That pairing (same opcode, Second is the next
+// non-debug instruction, Second's operand 5 is First's result and is killed
+// there) is established by the caller, emitSelectPseudo, and is not
+// re-checked here.
+//
+// Both pseudos are erased. Returns SinkMBB, which starts with the PHI that
+// defines Y and then holds the instructions that followed First in ThisMBB.
+static MachineBasicBlock *
+EmitLoweredCascadedSelect(MachineInstr &First, MachineInstr &Second,
+ MachineBasicBlock *ThisMBB,
+ const RISCVSubtarget &Subtarget) {
+ // Lowering each pseudo on its own would build two chained diamonds
+ // (A is ThisMBB):
+ //
+ //   A: branch to C if (rs1 cc1 rs2), else fall through to B
+ //   B: empty
+ //   C: X = PHI [freg1, A], [freg2, B]
+ //      branch to E if (rs3 cc2 rs4), else fall through to D
+ //   D: empty
+ //   E: Y = PHI [freg3, C], [X, D]
+ //
+ // Lowering both at once needs only one join block and one PHI, and X is
+ // never materialized:
+ //
+ //   ThisMBB:   branch to SinkMBB if (rs3 cc2 rs4)   -> Y = freg3
+ //   FirstMBB:  branch to SinkMBB if (rs1 cc1 rs2)   -> Y = freg1
+ //   SecondMBB: empty, falls through to SinkMBB      -> Y = freg2
+ //   SinkMBB:   Y = PHI [freg3, ThisMBB], [freg1, FirstMBB],
+ //                     [freg2, SecondMBB]
+ //              (instructions that followed First in ThisMBB)
+
+ const RISCVInstrInfo &TII = *Subtarget.getInstrInfo();
+ // Every instruction created below uses First's debug location.
+ const DebugLoc &DL = First.getDebugLoc();
+ const BasicBlock *LLVM_BB = ThisMBB->getBasicBlock();
+ MachineFunction *F = ThisMBB->getParent();
+ MachineBasicBlock *FirstMBB = F->CreateMachineBasicBlock(LLVM_BB);
+ MachineBasicBlock *SecondMBB = F->CreateMachineBasicBlock(LLVM_BB);
+ MachineBasicBlock *SinkMBB = F->CreateMachineBasicBlock(LLVM_BB);
+ // Each insert goes before the block that originally followed ThisMBB, so the
+ // layout becomes ThisMBB, FirstMBB, SecondMBB, SinkMBB and every block can
+ // fall through to the next one.
+ MachineFunction::iterator It = ++ThisMBB->getIterator();
+ F->insert(It, FirstMBB);
+ F->insert(It, SecondMBB);
+ F->insert(It, SinkMBB);
+
+ // Transfer the remainder of ThisMBB and its successor edges to SinkMBB.
+ // Everything after First moves, including Second itself (erased below).
+ SinkMBB->splice(SinkMBB->begin(), ThisMBB,
+ std::next(MachineBasicBlock::iterator(First)),
+ ThisMBB->end());
+ SinkMBB->transferSuccessorsAndUpdatePHIs(ThisMBB);
+
+ // CFG edges: ThisMBB -> {FirstMBB, SinkMBB}, FirstMBB -> {SecondMBB,
+ // SinkMBB}, SecondMBB -> SinkMBB.
+ // Fallthrough block for ThisMBB.
+ ThisMBB->addSuccessor(FirstMBB);
+ // Fallthrough block for FirstMBB.
+ FirstMBB->addSuccessor(SecondMBB);
+ ThisMBB->addSuccessor(SinkMBB);
+ FirstMBB->addSuccessor(SinkMBB);
+ // This is fallthrough.
+ SecondMBB->addSuccessor(SinkMBB);
+
+ auto FirstCC = static_cast<RISCVCC::CondCode>(First.getOperand(3).getImm());
+ Register FLHS = First.getOperand(1).getReg();
+ Register FRHS = First.getOperand(2).getReg();
+ // FirstMBB: jump to SinkMBB when First's condition holds (result = freg1).
+ // Operands are added without kill flags; the original uses' flags are
+ // dropped together with the erased pseudos.
+ BuildMI(FirstMBB, DL, TII.getBrCond(FirstCC))
+ .addReg(FLHS)
+ .addReg(FRHS)
+ .addMBB(SinkMBB);
+
+ Register SLHS = Second.getOperand(1).getReg();
+ Register SRHS = Second.getOperand(2).getReg();
+ Register Op1Reg4 = First.getOperand(4).getReg();
+ Register Op1Reg5 = First.getOperand(5).getReg();
+
+ auto SecondCC = static_cast<RISCVCC::CondCode>(Second.getOperand(3).getImm());
+ // ThisMBB: Second's condition is tested first, because when it holds the
+ // result is Second's true operand whatever First would have selected.
+ BuildMI(ThisMBB, DL, TII.getBrCond(SecondCC))
+ .addReg(SLHS)
+ .addReg(SRHS)
+ .addMBB(SinkMBB);
+
+ // The PHI defines Second's destination register, so existing users of the
+ // combined result are untouched. First's destination (X) gets no new
+ // definition; this is only sound because the caller required Second to be
+ // the next non-debug instruction and to kill X.
+ Register DestReg = Second.getOperand(0).getReg();
+ Register Op2Reg4 = Second.getOperand(4).getReg();
+ BuildMI(*SinkMBB, SinkMBB->begin(), DL, TII.get(RISCV::PHI), DestReg)
+ .addReg(Op2Reg4)
+ .addMBB(ThisMBB)
+ .addReg(Op1Reg4)
+ .addMBB(FirstMBB)
+ .addReg(Op1Reg5)
+ .addMBB(SecondMBB);
+
+ // Now remove the Select_FPRX_s.
+ First.eraseFromParent();
+ Second.eraseFromParent();
+ return SinkMBB;
+}
+
static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI,
MachineBasicBlock *BB,
const RISCVSubtarget &Subtarget) {
@@ -9849,6 +9952,10 @@ static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI,
// previous selects in the sequence.
// These conditions could be further relaxed. See the X86 target for a
// related approach and more information.
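+ //
+ // One special case is detected below before the sequence scan: a non-GPR
+ // (FP) select immediately followed by a select of the same opcode whose
+ // false operand is the first select's (killed) result, i.e.
+ //   X = Select_FPRX_ (rs1, rs2, cc1, freg1, freg2)
+ //   Y = Select_FPRX_ (rs3, rs4, cc2, freg3, X)
+ // The two conditions need not match. This pair is handed to a separate
+ // function, EmitLoweredCascadedSelect.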
Register LHS = MI.getOperand(1).getReg();
Register RHS = MI.getOperand(2).getReg();
auto CC = static_cast<RISCVCC::CondCode>(MI.getOperand(3).getImm());
@@ -9858,6 +9965,13 @@ static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI,
SelectDests.insert(MI.getOperand(0).getReg());
MachineInstr *LastSelectPseudo = &MI;
+ auto Next = next_nodbg(MI.getIterator(), BB->instr_end());
+ if (MI.getOpcode() != RISCV::Select_GPR_Using_CC_GPR && Next != BB->end() &&
+ Next->getOpcode() == MI.getOpcode() &&
+ Next->getOperand(5).getReg() == MI.getOperand(0).getReg() &&
+ Next->getOperand(5).isKill()) {
+ return EmitLoweredCascadedSelect(MI, *Next, BB, Subtarget);
+ }
for (auto E = BB->end(), SequenceMBBI = MachineBasicBlock::iterator(MI);
SequenceMBBI != E; ++SequenceMBBI) {
diff --git a/llvm/test/CodeGen/RISCV/select-optimize-multiple.ll b/llvm/test/CodeGen/RISCV/select-optimize-multiple.ll
index 491b28ef0851..406ec7f322bd 100644
--- a/llvm/test/CodeGen/RISCV/select-optimize-multiple.ll
+++ b/llvm/test/CodeGen/RISCV/select-optimize-multiple.ll
@@ -533,3 +533,79 @@ entry:
%ret = add i32 %cond1, %cond2
ret i32 %ret
}
+
+; Two chained selects where the outer select's false operand is the inner
+; select's result:
+;   retval = (a ogt 1.0) ? 1.0 : ((a olt 0.0) ? 0.0 : a)
+; The expected code tests the outer condition first. Both conditions branch
+; straight to a single join block (.LBB8_3), and the false/false case
+; (fmv.s ft0, ft1) falls through into it.
+define float @CascadedSelect(float noundef %a) {
+; RV32I-LABEL: CascadedSelect:
+; RV32I: # %bb.0: # %entry
+; RV32I-NEXT: lui a1, %hi(.LCPI8_0)
+; RV32I-NEXT: flw ft0, %lo(.LCPI8_0)(a1)
+; RV32I-NEXT: fmv.w.x ft1, a0
+; RV32I-NEXT: flt.s a0, ft0, ft1
+; RV32I-NEXT: bnez a0, .LBB8_3
+; RV32I-NEXT: # %bb.1: # %entry
+; RV32I-NEXT: fmv.w.x ft0, zero
+; RV32I-NEXT: flt.s a0, ft1, ft0
+; RV32I-NEXT: bnez a0, .LBB8_3
+; RV32I-NEXT: # %bb.2: # %entry
+; RV32I-NEXT: fmv.s ft0, ft1
+; RV32I-NEXT: .LBB8_3: # %entry
+; RV32I-NEXT: fmv.x.w a0, ft0
+; RV32I-NEXT: ret
+;
+; RV32IBT-LABEL: CascadedSelect:
+; RV32IBT: # %bb.0: # %entry
+; RV32IBT-NEXT: lui a1, %hi(.LCPI8_0)
+; RV32IBT-NEXT: flw ft0, %lo(.LCPI8_0)(a1)
+; RV32IBT-NEXT: fmv.w.x ft1, a0
+; RV32IBT-NEXT: flt.s a0, ft0, ft1
+; RV32IBT-NEXT: bnez a0, .LBB8_3
+; RV32IBT-NEXT: # %bb.1: # %entry
+; RV32IBT-NEXT: fmv.w.x ft0, zero
+; RV32IBT-NEXT: flt.s a0, ft1, ft0
+; RV32IBT-NEXT: bnez a0, .LBB8_3
+; RV32IBT-NEXT: # %bb.2: # %entry
+; RV32IBT-NEXT: fmv.s ft0, ft1
+; RV32IBT-NEXT: .LBB8_3: # %entry
+; RV32IBT-NEXT: fmv.x.w a0, ft0
+; RV32IBT-NEXT: ret
+;
+; RV64I-LABEL: CascadedSelect:
+; RV64I: # %bb.0: # %entry
+; RV64I-NEXT: lui a1, %hi(.LCPI8_0)
+; RV64I-NEXT: flw ft0, %lo(.LCPI8_0)(a1)
+; RV64I-NEXT: fmv.w.x ft1, a0
+; RV64I-NEXT: flt.s a0, ft0, ft1
+; RV64I-NEXT: bnez a0, .LBB8_3
+; RV64I-NEXT: # %bb.1: # %entry
+; RV64I-NEXT: fmv.w.x ft0, zero
+; RV64I-NEXT: flt.s a0, ft1, ft0
+; RV64I-NEXT: bnez a0, .LBB8_3
+; RV64I-NEXT: # %bb.2: # %entry
+; RV64I-NEXT: fmv.s ft0, ft1
+; RV64I-NEXT: .LBB8_3: # %entry
+; RV64I-NEXT: fmv.x.w a0, ft0
+; RV64I-NEXT: ret
+;
+; RV64IBT-LABEL: CascadedSelect:
+; RV64IBT: # %bb.0: # %entry
+; RV64IBT-NEXT: lui a1, %hi(.LCPI8_0)
+; RV64IBT-NEXT: flw ft0, %lo(.LCPI8_0)(a1)
+; RV64IBT-NEXT: fmv.w.x ft1, a0
+; RV64IBT-NEXT: flt.s a0, ft0, ft1
+; RV64IBT-NEXT: bnez a0, .LBB8_3
+; RV64IBT-NEXT: # %bb.1: # %entry
+; RV64IBT-NEXT: fmv.w.x ft0, zero
+; RV64IBT-NEXT: flt.s a0, ft1, ft0
+; RV64IBT-NEXT: bnez a0, .LBB8_3
+; RV64IBT-NEXT: # %bb.2: # %entry
+; RV64IBT-NEXT: fmv.s ft0, ft1
+; RV64IBT-NEXT: .LBB8_3: # %entry
+; RV64IBT-NEXT: fmv.x.w a0, ft0
+; RV64IBT-NEXT: ret
+entry:
+ %cmp = fcmp ogt float %a, 1.000000e+00
+ %cmp1 = fcmp olt float %a, 0.000000e+00
+ %.a = select i1 %cmp1, float 0.000000e+00, float %a
+ %retval.0 = select i1 %cmp, float 1.000000e+00, float %.a
+ ret float %retval.0
+}
More information about the llvm-commits
mailing list