[llvm] [RISCV][WIP] Let RA do the CSR saves. (PR #90819)

Mikhail Gudim via llvm-commits llvm-commits at lists.llvm.org
Wed May 1 22:26:53 PDT 2024


https://github.com/mgudim updated https://github.com/llvm/llvm-project/pull/90819

>From de5e1cd3a2b0afbef99d09ef7e2e159cd4dd0933 Mon Sep 17 00:00:00 2001
From: Mikhail Gudim <mgudim at gmail.com>
Date: Thu, 2 May 2024 00:59:08 -0400
Subject: [PATCH] [RISCV][WIP] Let RA do the CSR saves.

We turn the problem of saving and restoring callee-saved registers efficiently into a
register allocation problem. This has the advantage that
the register allocator can essentially do shrink-wrapping on a per-register
basis. Currently, the shrink-wrapping pass saves all CSRs in the same place,
which may be suboptimal. Also, improvements to register allocation /
coalescing will translate to improvements in shrink-wrapping.

In `finalizeLowering()` we copy all callee-saved registers from a
physical register to a virtual one. In all return blocks we do the
reverse.
---
 llvm/lib/Target/RISCV/RISCVFrameLowering.cpp | 63 +++++++++++++++++--
 llvm/lib/Target/RISCV/RISCVFrameLowering.h   |  2 +
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp  | 64 ++++++++++++++++++++
 llvm/lib/Target/RISCV/RISCVISelLowering.h    |  2 +
 llvm/lib/Target/RISCV/RISCVSubtarget.cpp     |  7 +++
 llvm/lib/Target/RISCV/RISCVSubtarget.h       |  2 +
 6 files changed, 134 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
index cb41577c5d9435..437a935c12e5f4 100644
--- a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
@@ -1026,12 +1026,51 @@ RISCVFrameLowering::getFrameIndexReference(const MachineFunction &MF, int FI,
   return Offset;
 }
 
-void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
-                                              BitVector &SavedRegs,
-                                              RegScavenger *RS) const {
-  TargetFrameLowering::determineCalleeSaves(MF, SavedRegs, RS);
-  // Unconditionally spill RA and FP only if the function uses a frame
-  // pointer.
+void RISCVFrameLowering::determineMustCalleeSaves(MachineFunction &MF,
+                                                  BitVector &SavedRegs) const {
+  const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo();
+
+  // Resize before the early returns. Some backends expect that
+  // SavedRegs.size() == TRI.getNumRegs() after this call even if there are no
+  // saved registers.
+  SavedRegs.resize(TRI.getNumRegs());
+
+  // When interprocedural register allocation is enabled caller saved registers
+  // are preferred over callee saved registers.
+  if (MF.getTarget().Options.EnableIPRA &&
+      isSafeForNoCSROpt(MF.getFunction()) &&
+      isProfitableForNoCSROpt(MF.getFunction()))
+    return;
+
+  // Get the callee saved register list...
+  const MCPhysReg *CSRegs = MF.getRegInfo().getCalleeSavedRegs();
+
+  // Early exit if there are no callee saved registers.
+  if (!CSRegs || CSRegs[0] == 0)
+    return;
+
+  // In Naked functions we aren't going to save any registers.
+  if (MF.getFunction().hasFnAttribute(Attribute::Naked))
+    return;
+
+  // Noreturn+nounwind functions never restore CSR, so no saves are needed.
+  // Purely noreturn functions may still return through throws, so those must
+  // save CSR for caller exception handlers.
+  //
+  // If the function uses longjmp to break out of its current path of
+  // execution we do not need the CSR spills either: setjmp stores all CSRs
+  // it was called with into the jmp_buf, which longjmp then restores.
+  if (MF.getFunction().hasFnAttribute(Attribute::NoReturn) &&
+      MF.getFunction().hasFnAttribute(Attribute::NoUnwind) &&
+      !MF.getFunction().hasFnAttribute(Attribute::UWTable) &&
+      enableCalleeSaveSkip(MF))
+    return;
+
+  // Functions which call __builtin_unwind_init get all their registers saved.
+  if (MF.callsUnwindInit()) {
+    SavedRegs.set();
+    return;
+  }
   if (hasFP(MF)) {
     SavedRegs.set(RISCV::X1);
     SavedRegs.set(RISCV::X8);
@@ -1041,6 +1080,18 @@ void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
     SavedRegs.set(RISCVABI::getBPReg());
 }
 
+void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
+                                              BitVector &SavedRegs,
+                                              RegScavenger *RS) const {
+  const auto &ST = MF.getSubtarget<RISCVSubtarget>();
+  const Function &F = MF.getFunction();
+  determineMustCalleeSaves(MF, SavedRegs);
+  if (ST.doCSRSavesInRA() && F.doesNotThrow())
+    return;
+
+  TargetFrameLowering::determineCalleeSaves(MF, SavedRegs, RS);
+}
+
 std::pair<int64_t, Align>
 RISCVFrameLowering::assignRVVStackObjectOffsets(MachineFunction &MF) const {
   MachineFrameInfo &MFI = MF.getFrameInfo();
diff --git a/llvm/lib/Target/RISCV/RISCVFrameLowering.h b/llvm/lib/Target/RISCV/RISCVFrameLowering.h
index 28ab4aff3b9d51..f6977d8092f441 100644
--- a/llvm/lib/Target/RISCV/RISCVFrameLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVFrameLowering.h
@@ -31,6 +31,8 @@ class RISCVFrameLowering : public TargetFrameLowering {
   StackOffset getFrameIndexReference(const MachineFunction &MF, int FI,
                                      Register &FrameReg) const override;
 
+  void determineMustCalleeSaves(MachineFunction &MF,
+                                BitVector &SavedRegs) const;
   void determineCalleeSaves(MachineFunction &MF, BitVector &SavedRegs,
                             RegScavenger *RS) const override;
 
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 19ef1f2f18ec1a..c63e2d375353c2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -21314,6 +21314,70 @@ unsigned RISCVTargetLowering::getCustomCtpopCost(EVT VT,
   return isCtpopFast(VT) ? 0 : 1;
 }
 
+void RISCVTargetLowering::finalizeLowering(MachineFunction &MF) const {
+  const Function &F = MF.getFunction();
+  if (!Subtarget.doCSRSavesInRA() || !F.doesNotThrow()) {
+    TargetLoweringBase::finalizeLowering(MF);
+    return;
+  }
+
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
+  const RISCVRegisterInfo &TRI = *Subtarget.getRegisterInfo();
+  const RISCVFrameLowering &TFI = *Subtarget.getFrameLowering();
+
+  SmallVector<MachineBasicBlock *, 4> RestoreMBBs;
+  SmallVector<MachineBasicBlock *, 4> SaveMBBs;
+  SaveMBBs.push_back(&MF.front());
+  for (MachineBasicBlock &MBB : MF) {
+    if (MBB.isReturnBlock())
+      RestoreMBBs.push_back(&MBB);
+  }
+
+  BitVector MustCalleeSavedRegs;
+  TFI.determineMustCalleeSaves(MF, MustCalleeSavedRegs);
+  const MCPhysReg *CSRegs = MF.getRegInfo().getCalleeSavedRegs();
+  SmallVector<MCPhysReg, 4> EligibleRegs;
+  for (int i = 0; CSRegs[i]; ++i) {
+    if (!MustCalleeSavedRegs[i])
+      EligibleRegs.push_back(CSRegs[i]);
+  }
+
+  dbgs() << "EligibleRegs: " << EligibleRegs.size() << "\n";
+  SmallVector<Register, 4> VRegs;
+  for (MachineBasicBlock *SaveMBB : SaveMBBs) {
+    for (MCPhysReg Reg : EligibleRegs) {
+      SaveMBB->addLiveIn(Reg);
+      // TODO: should we use Maximal register class instead?
+      Register VReg =
+          MRI.createVirtualRegister(TRI.getMinimalPhysRegClass(Reg));
+      VRegs.push_back(VReg);
+      BuildMI(*SaveMBB, SaveMBB->begin(),
+              SaveMBB->findDebugLoc(SaveMBB->begin()),
+              TII.get(TargetOpcode::COPY), VReg)
+          .addReg(Reg);
+    }
+  }
+
+  for (MachineBasicBlock *RestoreMBB : RestoreMBBs) {
+    MachineInstr &ReturnMI = RestoreMBB->back();
+    assert(ReturnMI.isReturn() && "Expected return instruction!");
+    auto VRegI = VRegs.begin();
+    for (MCPhysReg Reg : EligibleRegs) {
+      Register VReg = *VRegI;
+      BuildMI(*RestoreMBB, ReturnMI.getIterator(), ReturnMI.getDebugLoc(),
+              TII.get(TargetOpcode::COPY), Reg)
+          .addReg(VReg);
+      ReturnMI.addOperand(MF, MachineOperand::CreateReg(Reg,
+                                                        /*isDef=*/false,
+                                                        /*isImplicit=*/true));
+      VRegI++;
+    }
+  }
+
+  TargetLoweringBase::finalizeLowering(MF);
+}
+
 bool RISCVTargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
 
   // GISel support is in progress or complete for these opcodes.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 78f99e70c083a7..ea1079af2ead05 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -853,6 +853,8 @@ class RISCVTargetLowering : public TargetLowering {
 
   bool fallBackToDAGISel(const Instruction &Inst) const override;
 
+  void finalizeLowering(MachineFunction &MF) const override;
+
   bool lowerInterleavedLoad(LoadInst *LI,
                             ArrayRef<ShuffleVectorInst *> Shuffles,
                             ArrayRef<unsigned> Indices,
diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
index d3236bb07d56d5..88dab938ab1176 100644
--- a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
+++ b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
@@ -65,6 +65,11 @@ static cl::opt<unsigned> RISCVMinimumJumpTableEntries(
     "riscv-min-jump-table-entries", cl::Hidden,
     cl::desc("Set minimum number of entries to use a jump table on RISCV"));
 
+static cl::opt<bool> RISCVEnableSaveCSRByRA(
+    "riscv-enable-save-csr-in-ra",
+    cl::desc("Let register alloctor do csr saves/restores"), cl::init(false),
+    cl::Hidden);
+
 void RISCVSubtarget::anchor() {}
 
 RISCVSubtarget &
@@ -130,6 +135,8 @@ bool RISCVSubtarget::useConstantPoolForLargeInts() const {
   return !RISCVDisableUsingConstantPoolForLargeInts;
 }
 
+bool RISCVSubtarget::doCSRSavesInRA() const { return RISCVEnableSaveCSRByRA; }
+
 unsigned RISCVSubtarget::getMaxBuildIntsCost() const {
   // Loading integer from constant pool needs two instructions (the reason why
   // the minimum cost is 2): an address calculation instruction and a load
diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.h b/llvm/lib/Target/RISCV/RISCVSubtarget.h
index c880c9e921e0ea..f3d8a70c0df14e 100644
--- a/llvm/lib/Target/RISCV/RISCVSubtarget.h
+++ b/llvm/lib/Target/RISCV/RISCVSubtarget.h
@@ -270,6 +270,8 @@ class RISCVSubtarget : public RISCVGenSubtargetInfo {
 
   bool useConstantPoolForLargeInts() const;
 
+  bool doCSRSavesInRA() const;
+
   // Maximum cost used for building integers, integers will be put into constant
   // pool if exceeded.
   unsigned getMaxBuildIntsCost() const;



More information about the llvm-commits mailing list