[llvm] [RISCV][WIP] Let RA do the CSR saves. (PR #90819)

via llvm-commits llvm-commits at lists.llvm.org
Wed May 1 22:09:33 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Mikhail Gudim (mgudim)

<details>
<summary>Changes</summary>

We turn the problem of saving and restoring callee-saved registers efficiently into a register allocation problem. This has the advantage that the register allocator can essentially do shrink-wrapping on a per-register basis. Currently, the shrink-wrapping pass saves all CSRs in the same place, which may be suboptimal. Also, improvements to register allocation / coalescing will translate to improvements in shrink-wrapping.

In `finalizeLowering()` we copy all callee-saved registers from a physical register to a virtual one. In all return blocks we do the reverse.

---
Full diff: https://github.com/llvm/llvm-project/pull/90819.diff


6 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVFrameLowering.cpp (+57-6) 
- (modified) llvm/lib/Target/RISCV/RISCVFrameLowering.h (+1) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+77) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+2) 
- (modified) llvm/lib/Target/RISCV/RISCVSubtarget.cpp (+9) 
- (modified) llvm/lib/Target/RISCV/RISCVSubtarget.h (+2) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
index cb41577c5d9435..b725bfb56389bc 100644
--- a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
@@ -1026,12 +1026,51 @@ RISCVFrameLowering::getFrameIndexReference(const MachineFunction &MF, int FI,
   return Offset;
 }
 
-void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
-                                              BitVector &SavedRegs,
-                                              RegScavenger *RS) const {
-  TargetFrameLowering::determineCalleeSaves(MF, SavedRegs, RS);
-  // Unconditionally spill RA and FP only if the function uses a frame
-  // pointer.
+void RISCVFrameLowering::determineMustCalleeSaves(MachineFunction &MF,
+                                              BitVector &SavedRegs) const {
+  const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo();
+
+  // Resize before the early returns. Some backends expect that
+  // SavedRegs.size() == TRI.getNumRegs() after this call even if there are no
+  // saved registers.
+  SavedRegs.resize(TRI.getNumRegs());
+
+  // When interprocedural register allocation is enabled caller saved registers
+  // are preferred over callee saved registers.
+  if (MF.getTarget().Options.EnableIPRA &&
+      isSafeForNoCSROpt(MF.getFunction()) &&
+      isProfitableForNoCSROpt(MF.getFunction()))
+    return;
+
+  // Get the callee saved register list...
+  const MCPhysReg *CSRegs = MF.getRegInfo().getCalleeSavedRegs();
+
+  // Early exit if there are no callee saved registers.
+  if (!CSRegs || CSRegs[0] == 0)
+    return;
+
+  // In Naked functions we aren't going to save any registers.
+  if (MF.getFunction().hasFnAttribute(Attribute::Naked))
+    return;
+
+  // Noreturn+nounwind functions never restore CSR, so no saves are needed.
+  // Purely noreturn functions may still return through throws, so those must
+  // save CSR for caller exception handlers.
+  //
+  // If the function uses longjmp to break out of its current path of
+  // execution we do not need the CSR spills either: setjmp stores all CSRs
+  // it was called with into the jmp_buf, which longjmp then restores.
+  if (MF.getFunction().hasFnAttribute(Attribute::NoReturn) &&
+        MF.getFunction().hasFnAttribute(Attribute::NoUnwind) &&
+        !MF.getFunction().hasFnAttribute(Attribute::UWTable) &&
+        enableCalleeSaveSkip(MF))
+    return;
+
+  // Functions which call __builtin_unwind_init get all their registers saved.
+  if (MF.callsUnwindInit()) {
+    SavedRegs.set();
+    return;
+  }
   if (hasFP(MF)) {
     SavedRegs.set(RISCV::X1);
     SavedRegs.set(RISCV::X8);
@@ -1041,6 +1080,18 @@ void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
     SavedRegs.set(RISCVABI::getBPReg());
 }
 
+void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
+                                              BitVector &SavedRegs,
+                                              RegScavenger *RS) const {
+  const auto &ST = MF.getSubtarget<RISCVSubtarget>();
+  const Function &F = MF.getFunction();
+  determineMustCalleeSaves(MF, SavedRegs);
+  if (ST.doCSRSavesInRA() && F.doesNotThrow())
+    return;
+
+  TargetFrameLowering::determineCalleeSaves(MF, SavedRegs, RS);
+}
+
 std::pair<int64_t, Align>
 RISCVFrameLowering::assignRVVStackObjectOffsets(MachineFunction &MF) const {
   MachineFrameInfo &MFI = MF.getFrameInfo();
diff --git a/llvm/lib/Target/RISCV/RISCVFrameLowering.h b/llvm/lib/Target/RISCV/RISCVFrameLowering.h
index 28ab4aff3b9d51..d7b9df8bd68515 100644
--- a/llvm/lib/Target/RISCV/RISCVFrameLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVFrameLowering.h
@@ -31,6 +31,7 @@ class RISCVFrameLowering : public TargetFrameLowering {
   StackOffset getFrameIndexReference(const MachineFunction &MF, int FI,
                                      Register &FrameReg) const override;
 
+  void determineMustCalleeSaves(MachineFunction &MF, BitVector &SavedRegs) const;
   void determineCalleeSaves(MachineFunction &MF, BitVector &SavedRegs,
                             RegScavenger *RS) const override;
 
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 19ef1f2f18ec1a..7978dac4aa7944 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -21314,6 +21314,83 @@ unsigned RISCVTargetLowering::getCustomCtpopCost(EVT VT,
   return isCtpopFast(VT) ? 0 : 1;
 }
 
+void RISCVTargetLowering::finalizeLowering(MachineFunction &MF) const {
+  const Function &F = MF.getFunction();
+  if (!Subtarget.doCSRSavesInRA() || !F.doesNotThrow()) { 
+    TargetLoweringBase::finalizeLowering(MF);
+    return;
+  }
+
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
+  const RISCVRegisterInfo &TRI = *Subtarget.getRegisterInfo();
+  const RISCVFrameLowering &TFI = *Subtarget.getFrameLowering();
+
+  SmallVector<MachineBasicBlock *, 4> RestoreMBBs;
+  SmallVector<MachineBasicBlock *, 4> SaveMBBs;
+  SaveMBBs.push_back(&MF.front());
+  for (MachineBasicBlock &MBB : MF) {
+    if (MBB.isReturnBlock())
+      RestoreMBBs.push_back(&MBB);
+  }
+
+  BitVector MustCalleeSavedRegs;
+  TFI.determineMustCalleeSaves(MF, MustCalleeSavedRegs);
+  const MCPhysReg * CSRegs = MF.getRegInfo().getCalleeSavedRegs();
+  SmallVector<MCPhysReg, 4> EligibleRegs;
+  for (int i = 0; CSRegs[i]; ++i) {
+    if (!MustCalleeSavedRegs[i])
+      EligibleRegs.push_back(CSRegs[i]);
+  }
+
+  dbgs() << "EligibleRegs: " << EligibleRegs.size() << "\n";
+  SmallVector<Register, 4> VRegs;
+  for (MachineBasicBlock *SaveMBB : SaveMBBs) {
+    for (MCPhysReg Reg : EligibleRegs) {
+      SaveMBB->addLiveIn(Reg);
+      // TODO: should we use Maximal register class instead?
+      Register VReg = MRI.createVirtualRegister(TRI.getMinimalPhysRegClass(Reg));
+      VRegs.push_back(VReg);
+      BuildMI(
+        *SaveMBB,
+        SaveMBB->begin(),
+        SaveMBB->findDebugLoc(SaveMBB->begin()),
+        TII.get(TargetOpcode::COPY),
+        VReg
+      )
+      .addReg(Reg);
+    }
+  }
+
+  for (MachineBasicBlock *RestoreMBB : RestoreMBBs) {
+    MachineInstr &ReturnMI = RestoreMBB->back();
+    assert(ReturnMI.isReturn() && "Expected return instruction!");
+    auto VRegI = VRegs.begin();
+    for (MCPhysReg Reg : EligibleRegs) {
+      Register VReg = *VRegI;
+      BuildMI(
+        *RestoreMBB,
+        ReturnMI.getIterator(),
+        ReturnMI.getDebugLoc(),
+        TII.get(TargetOpcode::COPY),
+        Reg
+      )
+      .addReg(VReg);
+      ReturnMI.addOperand(
+        MF,
+        MachineOperand::CreateReg(
+          Reg,
+          /*isDef=*/false,
+          /*isImplicit=*/true
+        )
+      );
+      VRegI++;
+    }
+  }
+
+  TargetLoweringBase::finalizeLowering(MF);
+}
+
 bool RISCVTargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
 
   // GISel support is in progress or complete for these opcodes.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 78f99e70c083a7..ea1079af2ead05 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -853,6 +853,8 @@ class RISCVTargetLowering : public TargetLowering {
 
   bool fallBackToDAGISel(const Instruction &Inst) const override;
 
+  void finalizeLowering(MachineFunction &MF) const override;
+
   bool lowerInterleavedLoad(LoadInst *LI,
                             ArrayRef<ShuffleVectorInst *> Shuffles,
                             ArrayRef<unsigned> Indices,
diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
index d3236bb07d56d5..15476fc2d3c583 100644
--- a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
+++ b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
@@ -65,6 +65,11 @@ static cl::opt<unsigned> RISCVMinimumJumpTableEntries(
     "riscv-min-jump-table-entries", cl::Hidden,
     cl::desc("Set minimum number of entries to use a jump table on RISCV"));
 
+static cl::opt<bool> RISCVEnableSaveCSRByRA(
+    "riscv-enable-save-csr-in-ra",
+    cl::desc("Let register alloctor do csr saves/restores"),
+    cl::init(false), cl::Hidden);
+
 void RISCVSubtarget::anchor() {}
 
 RISCVSubtarget &
@@ -130,6 +135,10 @@ bool RISCVSubtarget::useConstantPoolForLargeInts() const {
   return !RISCVDisableUsingConstantPoolForLargeInts;
 }
 
+bool RISCVSubtarget::doCSRSavesInRA() const {
+  return RISCVEnableSaveCSRByRA;
+}
+
 unsigned RISCVSubtarget::getMaxBuildIntsCost() const {
   // Loading integer from constant pool needs two instructions (the reason why
   // the minimum cost is 2): an address calculation instruction and a load
diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.h b/llvm/lib/Target/RISCV/RISCVSubtarget.h
index c880c9e921e0ea..f3d8a70c0df14e 100644
--- a/llvm/lib/Target/RISCV/RISCVSubtarget.h
+++ b/llvm/lib/Target/RISCV/RISCVSubtarget.h
@@ -270,6 +270,8 @@ class RISCVSubtarget : public RISCVGenSubtargetInfo {
 
   bool useConstantPoolForLargeInts() const;
 
+  bool doCSRSavesInRA() const;
+
   // Maximum cost used for building integers, integers will be put into constant
   // pool if exceeded.
   unsigned getMaxBuildIntsCost() const;

``````````

</details>


https://github.com/llvm/llvm-project/pull/90819


More information about the llvm-commits mailing list