[llvm] [RISCV][WIP] Let RA do the CSR saves. (PR #90819)
Mikhail Gudim via llvm-commits
llvm-commits at lists.llvm.org
Wed May 1 22:26:53 PDT 2024
https://github.com/mgudim updated https://github.com/llvm/llvm-project/pull/90819
>From de5e1cd3a2b0afbef99d09ef7e2e159cd4dd0933 Mon Sep 17 00:00:00 2001
From: Mikhail Gudim <mgudim at gmail.com>
Date: Thu, 2 May 2024 00:59:08 -0400
Subject: [PATCH] [RISCV][WIP] Let RA do the CSR saves.
We turn the problem of saving and restoring callee-saved registers efficiently into a
register allocation problem. This has the advantage that
the register allocator can essentially do shrink-wrapping on a per-register
basis. Currently, the shrink-wrapping pass saves all CSRs in the same place,
which may be suboptimal. Also, improvements to register allocation /
coalescing will translate to improvements in shrink-wrapping.
In `finalizeLowering()` we copy all callee-saved registers from a
physical register to a virtual one. In all return blocks we do the
reverse.
---
llvm/lib/Target/RISCV/RISCVFrameLowering.cpp | 63 +++++++++++++++++--
llvm/lib/Target/RISCV/RISCVFrameLowering.h | 2 +
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 64 ++++++++++++++++++++
llvm/lib/Target/RISCV/RISCVISelLowering.h | 2 +
llvm/lib/Target/RISCV/RISCVSubtarget.cpp | 7 +++
llvm/lib/Target/RISCV/RISCVSubtarget.h | 2 +
6 files changed, 134 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
index cb41577c5d9435..437a935c12e5f4 100644
--- a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
@@ -1026,12 +1026,51 @@ RISCVFrameLowering::getFrameIndexReference(const MachineFunction &MF, int FI,
return Offset;
}
-void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
- BitVector &SavedRegs,
- RegScavenger *RS) const {
- TargetFrameLowering::determineCalleeSaves(MF, SavedRegs, RS);
- // Unconditionally spill RA and FP only if the function uses a frame
- // pointer.
+void RISCVFrameLowering::determineMustCalleeSaves(MachineFunction &MF,
+ BitVector &SavedRegs) const {
+ const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo();
+
+ // Resize before the early returns. Some backends expect that
+ // SavedRegs.size() == TRI.getNumRegs() after this call even if there are no
+ // saved registers.
+ SavedRegs.resize(TRI.getNumRegs());
+
+ // When interprocedural register allocation is enabled caller saved registers
+ // are preferred over callee saved registers.
+ if (MF.getTarget().Options.EnableIPRA &&
+ isSafeForNoCSROpt(MF.getFunction()) &&
+ isProfitableForNoCSROpt(MF.getFunction()))
+ return;
+
+ // Get the callee saved register list...
+ const MCPhysReg *CSRegs = MF.getRegInfo().getCalleeSavedRegs();
+
+ // Early exit if there are no callee saved registers.
+ if (!CSRegs || CSRegs[0] == 0)
+ return;
+
+ // In Naked functions we aren't going to save any registers.
+ if (MF.getFunction().hasFnAttribute(Attribute::Naked))
+ return;
+
+ // Noreturn+nounwind functions never restore CSR, so no saves are needed.
+ // Purely noreturn functions may still return through throws, so those must
+ // save CSR for caller exception handlers.
+ //
+ // If the function uses longjmp to break out of its current path of
+ // execution we do not need the CSR spills either: setjmp stores all CSRs
+ // it was called with into the jmp_buf, which longjmp then restores.
+ if (MF.getFunction().hasFnAttribute(Attribute::NoReturn) &&
+ MF.getFunction().hasFnAttribute(Attribute::NoUnwind) &&
+ !MF.getFunction().hasFnAttribute(Attribute::UWTable) &&
+ enableCalleeSaveSkip(MF))
+ return;
+
+ // Functions which call __builtin_unwind_init get all their registers saved.
+ if (MF.callsUnwindInit()) {
+ SavedRegs.set();
+ return;
+ }
if (hasFP(MF)) {
SavedRegs.set(RISCV::X1);
SavedRegs.set(RISCV::X8);
@@ -1041,6 +1080,18 @@ void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
SavedRegs.set(RISCVABI::getBPReg());
}
+void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
+                                              BitVector &SavedRegs,
+                                              RegScavenger *RS) const {
+  // Start from the registers selected by determineMustCalleeSaves().
+  determineMustCalleeSaves(MF, SavedRegs);
+  // Only run the generic analysis when CSR saves are not left to the register
+  // allocator, i.e. the option is off or the function may throw.
+  const auto &Subtarget = MF.getSubtarget<RISCVSubtarget>();
+  if (!Subtarget.doCSRSavesInRA() || !MF.getFunction().doesNotThrow())
+    TargetFrameLowering::determineCalleeSaves(MF, SavedRegs, RS);
+}
+
std::pair<int64_t, Align>
RISCVFrameLowering::assignRVVStackObjectOffsets(MachineFunction &MF) const {
MachineFrameInfo &MFI = MF.getFrameInfo();
diff --git a/llvm/lib/Target/RISCV/RISCVFrameLowering.h b/llvm/lib/Target/RISCV/RISCVFrameLowering.h
index 28ab4aff3b9d51..f6977d8092f441 100644
--- a/llvm/lib/Target/RISCV/RISCVFrameLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVFrameLowering.h
@@ -31,6 +31,8 @@ class RISCVFrameLowering : public TargetFrameLowering {
StackOffset getFrameIndexReference(const MachineFunction &MF, int FI,
Register &FrameReg) const override;
+ void determineMustCalleeSaves(MachineFunction &MF,
+ BitVector &SavedRegs) const;
void determineCalleeSaves(MachineFunction &MF, BitVector &SavedRegs,
RegScavenger *RS) const override;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 19ef1f2f18ec1a..c63e2d375353c2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -21314,6 +21314,70 @@ unsigned RISCVTargetLowering::getCustomCtpopCost(EVT VT,
return isCtpopFast(VT) ? 0 : 1;
}
+void RISCVTargetLowering::finalizeLowering(MachineFunction &MF) const {
+  // When CSR saving is left to the register allocator, copy every eligible CSR
+  // into a virtual register on entry and copy it back before each return.
+  const Function &F = MF.getFunction();
+  if (!Subtarget.doCSRSavesInRA() || !F.doesNotThrow()) {
+    TargetLoweringBase::finalizeLowering(MF);
+    return;
+  }
+
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
+  const RISCVRegisterInfo &TRI = *Subtarget.getRegisterInfo();
+  const RISCVFrameLowering &TFI = *Subtarget.getFrameLowering();
+
+  SmallVector<MachineBasicBlock *, 4> RestoreMBBs;
+  SmallVector<MachineBasicBlock *, 4> SaveMBBs;
+  SaveMBBs.push_back(&MF.front());
+  for (MachineBasicBlock &MBB : MF)
+    if (MBB.isReturnBlock())
+      RestoreMBBs.push_back(&MBB);
+
+  // MustCalleeSavedRegs is indexed by physical register number (it is sized to
+  // TRI.getNumRegs()), so test CSRegs[I] itself rather than its list position.
+  BitVector MustCalleeSavedRegs;
+  TFI.determineMustCalleeSaves(MF, MustCalleeSavedRegs);
+  const MCPhysReg *CSRegs = MRI.getCalleeSavedRegs();
+  SmallVector<MCPhysReg, 4> EligibleRegs;
+  for (unsigned I = 0; CSRegs && CSRegs[I]; ++I)
+    if (!MustCalleeSavedRegs[CSRegs[I]])
+      EligibleRegs.push_back(CSRegs[I]);
+
+  LLVM_DEBUG(dbgs() << "EligibleRegs: " << EligibleRegs.size() << "\n");
+  SmallVector<Register, 4> VRegs;
+  for (MachineBasicBlock *SaveMBB : SaveMBBs) {
+    for (MCPhysReg Reg : EligibleRegs) {
+      SaveMBB->addLiveIn(Reg);
+      // TODO: should we use Maximal register class instead?
+      Register VReg =
+          MRI.createVirtualRegister(TRI.getMinimalPhysRegClass(Reg));
+      VRegs.push_back(VReg);
+      BuildMI(*SaveMBB, SaveMBB->begin(),
+              SaveMBB->findDebugLoc(SaveMBB->begin()),
+              TII.get(TargetOpcode::COPY), VReg)
+          .addReg(Reg);
+    }
+  }
+
+  for (MachineBasicBlock *RestoreMBB : RestoreMBBs) {
+    MachineInstr &ReturnMI = RestoreMBB->back();
+    assert(ReturnMI.isReturn() && "Expected return instruction!");
+    for (unsigned I = 0, E = EligibleRegs.size(); I != E; ++I) {
+      BuildMI(*RestoreMBB, ReturnMI.getIterator(), ReturnMI.getDebugLoc(),
+              TII.get(TargetOpcode::COPY), EligibleRegs[I])
+          .addReg(VRegs[I]);
+      // Implicit use on the return, presumably to keep the restore copy alive.
+      ReturnMI.addOperand(MF, MachineOperand::CreateReg(EligibleRegs[I],
+                                                        /*isDef=*/false,
+                                                        /*isImplicit=*/true));
+    }
+  }
+
+  TargetLoweringBase::finalizeLowering(MF);
+}
+
bool RISCVTargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
// GISel support is in progress or complete for these opcodes.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 78f99e70c083a7..ea1079af2ead05 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -853,6 +853,8 @@ class RISCVTargetLowering : public TargetLowering {
bool fallBackToDAGISel(const Instruction &Inst) const override;
+ void finalizeLowering(MachineFunction &MF) const override;
+
bool lowerInterleavedLoad(LoadInst *LI,
ArrayRef<ShuffleVectorInst *> Shuffles,
ArrayRef<unsigned> Indices,
diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
index d3236bb07d56d5..88dab938ab1176 100644
--- a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
+++ b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
@@ -65,6 +65,11 @@ static cl::opt<unsigned> RISCVMinimumJumpTableEntries(
"riscv-min-jump-table-entries", cl::Hidden,
cl::desc("Set minimum number of entries to use a jump table on RISCV"));
+// Opt-in switch (off by default) to let the register allocator do CSR saves.
+static cl::opt<bool> RISCVEnableSaveCSRByRA(
+    "riscv-enable-save-csr-in-ra", cl::Hidden, cl::init(false),
+    cl::desc("Let register allocator do csr saves/restores"));
+
void RISCVSubtarget::anchor() {}
RISCVSubtarget &
@@ -130,6 +135,8 @@ bool RISCVSubtarget::useConstantPoolForLargeInts() const {
return !RISCVDisableUsingConstantPoolForLargeInts;
}
+// Returns the current value of the riscv-enable-save-csr-in-ra option.
+bool RISCVSubtarget::doCSRSavesInRA() const { return RISCVEnableSaveCSRByRA; }
+
unsigned RISCVSubtarget::getMaxBuildIntsCost() const {
// Loading integer from constant pool needs two instructions (the reason why
// the minimum cost is 2): an address calculation instruction and a load
diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.h b/llvm/lib/Target/RISCV/RISCVSubtarget.h
index c880c9e921e0ea..f3d8a70c0df14e 100644
--- a/llvm/lib/Target/RISCV/RISCVSubtarget.h
+++ b/llvm/lib/Target/RISCV/RISCVSubtarget.h
@@ -270,6 +270,8 @@ class RISCVSubtarget : public RISCVGenSubtargetInfo {
bool useConstantPoolForLargeInts() const;
+ bool doCSRSavesInRA() const;
+
// Maximum cost used for building integers, integers will be put into constant
// pool if exceeded.
unsigned getMaxBuildIntsCost() const;
More information about the llvm-commits
mailing list