[llvm] [RISCV][WIP] Let RA do the CSR saves. (PR #90819)

Mikhail Gudim via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 6 09:58:18 PST 2024


https://github.com/mgudim updated https://github.com/llvm/llvm-project/pull/90819

>From 5cbd4b0949dffe959a03c1ca2963b4df8b23d9b5 Mon Sep 17 00:00:00 2001
From: Mikhail Gudim <mgudim at ventanamicro.com>
Date: Wed, 30 Oct 2024 10:49:13 -0700
Subject: [PATCH 1/2] WIP


>From ae5d2b4ba2a4d6c82f40a03cd344a228eeda3683 Mon Sep 17 00:00:00 2001
From: Mikhail Gudim <mgudim at ventanamicro.com>
Date: Fri, 1 Nov 2024 01:13:59 -0700
Subject: [PATCH 2/2] [RISCV][WIP] Let RA do the CSR saves.

We turn the problem of saving and restoring callee-saved registers efficiently into a register allocation problem. This has the advantage that the register allocator can essentially do shrink-wrapping on a per-register basis. Currently, the shrink-wrapping pass saves all CSRs in the same place, which may be suboptimal. Also, improvements to register allocation / coalescing will translate to improvements in shrink-wrapping.

In finalizeLowering() we copy each callee-saved register from its physical register into a virtual one. In all return blocks we do the reverse.
---
 .../llvm/CodeGen/ReachingDefAnalysis.h        |   5 +
 .../llvm/CodeGen/TargetFrameLowering.h        |   6 +
 .../llvm/CodeGen/TargetSubtargetInfo.h        |   2 +
 llvm/lib/CodeGen/MachineLICM.cpp              |  44 +-
 llvm/lib/CodeGen/PrologEpilogInserter.cpp     |   7 +
 llvm/lib/CodeGen/ReachingDefAnalysis.cpp      |  60 +-
 llvm/lib/CodeGen/RegAllocEvictionAdvisor.cpp  |  15 +-
 llvm/lib/CodeGen/TargetSubtargetInfo.cpp      |   4 +
 llvm/lib/Target/RISCV/CMakeLists.txt          |   1 +
 llvm/lib/Target/RISCV/RISCV.h                 |   3 +
 llvm/lib/Target/RISCV/RISCVCFIInserter.cpp    | 569 ++++++++++++++++++
 llvm/lib/Target/RISCV/RISCVFrameLowering.cpp  | 244 +++++++-
 llvm/lib/Target/RISCV/RISCVFrameLowering.h    |   6 +
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 119 ++++
 llvm/lib/Target/RISCV/RISCVISelLowering.h     |   2 +
 llvm/lib/Target/RISCV/RISCVInstrInfo.h        |   8 +
 llvm/lib/Target/RISCV/RISCVInstrInfo.td       |   4 +
 .../Target/RISCV/RISCVMachineFunctionInfo.cpp |  28 +
 .../Target/RISCV/RISCVMachineFunctionInfo.h   |  17 +
 llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp   |   8 +
 llvm/lib/Target/RISCV/RISCVSubtarget.cpp      |   9 +
 llvm/lib/Target/RISCV/RISCVSubtarget.h        |   2 +
 llvm/lib/Target/RISCV/RISCVTargetMachine.cpp  |   2 +
 23 files changed, 1138 insertions(+), 27 deletions(-)
 create mode 100644 llvm/lib/Target/RISCV/RISCVCFIInserter.cpp

diff --git a/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h b/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h
index d6a1f064ec0a58e..6e6684ae53e0c59 100644
--- a/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h
+++ b/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h
@@ -114,8 +114,11 @@ class ReachingDefAnalysis : public MachineFunctionPass {
 private:
   MachineFunction *MF = nullptr;
   const TargetRegisterInfo *TRI = nullptr;
+  const TargetInstrInfo *TII = nullptr;
   LoopTraversal::TraversalOrder TraversedMBBOrder;
   unsigned NumRegUnits = 0;
+  unsigned NumStackObjects = 0;
+  int ObjectIndexBegin = 0;
   /// Instruction that defined each register, relative to the beginning of the
   /// current basic block.  When a LiveRegsDefInfo is used to represent a
   /// live-out register, this value is relative to the end of the basic block,
@@ -138,6 +141,8 @@ class ReachingDefAnalysis : public MachineFunctionPass {
   DenseMap<MachineInstr *, int> InstIds;
 
   MBBReachingDefsInfo MBBReachingDefs;
+  using MBBFrameObjsReachingDefsInfo = std::vector<std::vector<std::vector<int>>>;
+  MBBFrameObjsReachingDefsInfo MBBFrameObjsReachingDefs;
 
   /// Default values are 'nothing happened a long time ago'.
   const int ReachingDefDefaultVal = -(1 << 21);
diff --git a/llvm/include/llvm/CodeGen/TargetFrameLowering.h b/llvm/include/llvm/CodeGen/TargetFrameLowering.h
index 97de0197da9b400..db7c9f3fce43980 100644
--- a/llvm/include/llvm/CodeGen/TargetFrameLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetFrameLowering.h
@@ -24,6 +24,7 @@ namespace llvm {
   class CalleeSavedInfo;
   class MachineFunction;
   class RegScavenger;
+  class ReachingDefAnalysis;
 
 namespace TargetStackID {
 enum Value {
@@ -210,6 +211,11 @@ class TargetFrameLowering {
   /// for noreturn nounwind functions.
   virtual bool enableCalleeSaveSkip(const MachineFunction &MF) const;
 
+  /// Called by PEI before emitting prologues; emits CFI for CSRs handled by RA.
+  virtual void emitCFIsForCSRsHandledByRA(MachineFunction &MF,
+                                          ReachingDefAnalysis *RDA) const {
+  }
+
   /// emitProlog/emitEpilog - These methods insert prolog and epilog code into
   /// the function.
   virtual void emitPrologue(MachineFunction &MF,
diff --git a/llvm/include/llvm/CodeGen/TargetSubtargetInfo.h b/llvm/include/llvm/CodeGen/TargetSubtargetInfo.h
index bfaa6450779ae09..ea8237cdbac7d05 100644
--- a/llvm/include/llvm/CodeGen/TargetSubtargetInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetSubtargetInfo.h
@@ -317,6 +317,8 @@ class TargetSubtargetInfo : public MCSubtargetInfo {
     return false;
   }
 
+  virtual bool doCSRSavesInRA() const;
+
   /// Classify a global function reference. This mainly used to fetch target
   /// special flags for lowering a function address. For example mark a function
   /// call should be plt or pc-related addressing.
diff --git a/llvm/lib/CodeGen/MachineLICM.cpp b/llvm/lib/CodeGen/MachineLICM.cpp
index 7ea07862b839d02..01fdc102961895a 100644
--- a/llvm/lib/CodeGen/MachineLICM.cpp
+++ b/llvm/lib/CodeGen/MachineLICM.cpp
@@ -262,15 +262,19 @@ namespace {
     void HoistOutOfLoop(MachineDomTreeNode *HeaderN, MachineLoop *CurLoop,
                         MachineBasicBlock *CurPreheader);
 
-    void InitRegPressure(MachineBasicBlock *BB);
+    void InitRegPressure(MachineBasicBlock *BB, const MachineLoop* Loop);
 
     SmallDenseMap<unsigned, int> calcRegisterCost(const MachineInstr *MI,
                                                   bool ConsiderSeen,
-                                                  bool ConsiderUnseenAsDef);
+                                                  bool ConsiderUnseenAsDef,
+                                                  bool IgnoreDefs = false);
 
+    bool allDefsAreOnlyUsedOutsideOfTheLoop(const MachineInstr &MI, const MachineLoop *Loop);
     void UpdateRegPressure(const MachineInstr *MI,
-                           bool ConsiderUnseenAsDef = false);
+                           bool ConsiderUnseenAsDef = false, bool IgnoreDefs = false);
 
+    void UpdateRegPressureForUsesOnly(const MachineInstr *MI,
+                           bool ConsiderUnseenAsDef = false);
     MachineInstr *ExtractHoistableLoad(MachineInstr *MI, MachineLoop *CurLoop);
 
     MachineInstr *LookForDuplicate(const MachineInstr *MI,
@@ -884,7 +888,7 @@ void MachineLICMImpl::HoistOutOfLoop(MachineDomTreeNode *HeaderN,
   // Compute registers which are livein into the loop headers.
   RegSeen.clear();
   BackTrace.clear();
-  InitRegPressure(Preheader);
+  InitRegPressure(Preheader, CurLoop);
 
   // Now perform LICM.
   for (MachineDomTreeNode *Node : Scopes) {
@@ -934,7 +938,7 @@ static bool isOperandKill(const MachineOperand &MO, MachineRegisterInfo *MRI) {
 /// Find all virtual register references that are liveout of the preheader to
 /// initialize the starting "register pressure". Note this does not count live
 /// through (livein but not used) registers.
-void MachineLICMImpl::InitRegPressure(MachineBasicBlock *BB) {
+void MachineLICMImpl::InitRegPressure(MachineBasicBlock *BB, const MachineLoop *Loop) {
   std::fill(RegPressure.begin(), RegPressure.end(), 0);
 
   // If the preheader has only a single predecessor and it ends with a
@@ -945,17 +949,32 @@ void MachineLICMImpl::InitRegPressure(MachineBasicBlock *BB) {
     MachineBasicBlock *TBB = nullptr, *FBB = nullptr;
     SmallVector<MachineOperand, 4> Cond;
     if (!TII->analyzeBranch(*BB, TBB, FBB, Cond, false) && Cond.empty())
-      InitRegPressure(*BB->pred_begin());
+      InitRegPressure(*BB->pred_begin(), Loop);
   }
 
-  for (const MachineInstr &MI : *BB)
-    UpdateRegPressure(&MI, /*ConsiderUnseenAsDef=*/true);
+  for (const MachineInstr &MI : *BB) {
+    bool IgnoreDefs = allDefsAreOnlyUsedOutsideOfTheLoop(MI, Loop);
+    UpdateRegPressure(&MI, /*ConsiderUnseenAsDef=*/true, IgnoreDefs);
+  }
+}
+
+/// Return true if no virtual register defined by \p MI is used inside \p Loop.
+bool MachineLICMImpl::allDefsAreOnlyUsedOutsideOfTheLoop(const MachineInstr &MI, const MachineLoop *Loop) {
+  for (const MachineOperand &DefMO : MI.all_defs()) {
+    if (!DefMO.getReg().isVirtual()) // Pressure is only tracked for vregs.
+      continue;
+    for (const MachineInstr &UseMI : MRI->use_instructions(DefMO.getReg()))
+      if (Loop->contains(UseMI.getParent()))
+        return false;
+  }
+  return true;
 }
 
 /// Update estimate of register pressure after the specified instruction.
 void MachineLICMImpl::UpdateRegPressure(const MachineInstr *MI,
-                                        bool ConsiderUnseenAsDef) {
-  auto Cost = calcRegisterCost(MI, /*ConsiderSeen=*/true, ConsiderUnseenAsDef);
+                                        bool ConsiderUnseenAsDef,
+                                        bool IgnoreDefs) {
+  auto Cost = calcRegisterCost(MI, /*ConsiderSeen=*/true, ConsiderUnseenAsDef, IgnoreDefs);
   for (const auto &RPIdAndCost : Cost) {
     unsigned Class = RPIdAndCost.first;
     if (static_cast<int>(RegPressure[Class]) < -RPIdAndCost.second)
@@ -973,7 +992,8 @@ void MachineLICMImpl::UpdateRegPressure(const MachineInstr *MI,
 /// FIXME: Figure out a way to consider 'RegSeen' from all code paths.
 SmallDenseMap<unsigned, int>
 MachineLICMImpl::calcRegisterCost(const MachineInstr *MI, bool ConsiderSeen,
-                                  bool ConsiderUnseenAsDef) {
+                                  bool ConsiderUnseenAsDef,
+                                  bool IgnoreDefs) {
   SmallDenseMap<unsigned, int> Cost;
   if (MI->isImplicitDef())
     return Cost;
@@ -991,7 +1011,7 @@ MachineLICMImpl::calcRegisterCost(const MachineInstr *MI, bool ConsiderSeen,
 
     RegClassWeight W = TRI->getRegClassWeight(RC);
     int RCCost = 0;
-    if (MO.isDef())
+    if (MO.isDef() && !IgnoreDefs)
       RCCost = W.RegWeight;
     else {
       bool isKill = isOperandKill(MO, MRI);
diff --git a/llvm/lib/CodeGen/PrologEpilogInserter.cpp b/llvm/lib/CodeGen/PrologEpilogInserter.cpp
index ee03eaa8ae527c6..78b10d665a0157a 100644
--- a/llvm/lib/CodeGen/PrologEpilogInserter.cpp
+++ b/llvm/lib/CodeGen/PrologEpilogInserter.cpp
@@ -36,6 +36,7 @@
 #include "llvm/CodeGen/MachineOperand.h"
 #include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/ReachingDefAnalysis.h"
 #include "llvm/CodeGen/RegisterScavenging.h"
 #include "llvm/CodeGen/TargetFrameLowering.h"
 #include "llvm/CodeGen/TargetInstrInfo.h"
@@ -95,6 +96,7 @@ class PEI : public MachineFunctionPass {
   bool runOnMachineFunction(MachineFunction &MF) override;
 
 private:
+  ReachingDefAnalysis *RDA = nullptr;
   RegScavenger *RS = nullptr;
 
   // MinCSFrameIndex, MaxCSFrameIndex - Keeps the range of callee saved
@@ -153,6 +155,7 @@ INITIALIZE_PASS_BEGIN(PEI, DEBUG_TYPE, "Prologue/Epilogue Insertion", false,
 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass)
+INITIALIZE_PASS_DEPENDENCY(ReachingDefAnalysis)
 INITIALIZE_PASS_END(PEI, DEBUG_TYPE,
                     "Prologue/Epilogue Insertion & Frame Finalization", false,
                     false)
@@ -169,6 +172,7 @@ void PEI::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.addPreserved<MachineLoopInfoWrapperPass>();
   AU.addPreserved<MachineDominatorTreeWrapperPass>();
   AU.addRequired<MachineOptimizationRemarkEmitterPass>();
+  AU.addRequired<ReachingDefAnalysis>();
   MachineFunctionPass::getAnalysisUsage(AU);
 }
 
@@ -227,6 +231,7 @@ bool PEI::runOnMachineFunction(MachineFunction &MF) {
   RS = TRI->requiresRegisterScavenging(MF) ? new RegScavenger() : nullptr;
   FrameIndexVirtualScavenging = TRI->requiresFrameIndexScavenging(MF);
   ORE = &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE();
+  RDA = &getAnalysis<ReachingDefAnalysis>();
 
   // Spill frame pointer and/or base pointer registers if they are clobbered.
   // It is placed before call frame instruction elimination so it will not mess
@@ -262,6 +267,7 @@ bool PEI::runOnMachineFunction(MachineFunction &MF) {
   // called functions.  Because of this, calculateCalleeSavedRegisters()
   // must be called before this function in order to set the AdjustsStack
   // and MaxCallFrameSize variables.
+  RDA->reset();
   if (!F.hasFnAttribute(Attribute::Naked))
     insertPrologEpilogCode(MF);
 
@@ -1164,6 +1170,7 @@ void PEI::calculateFrameObjectOffsets(MachineFunction &MF) {
 void PEI::insertPrologEpilogCode(MachineFunction &MF) {
   const TargetFrameLowering &TFI = *MF.getSubtarget().getFrameLowering();
 
+  TFI.emitCFIsForCSRsHandledByRA(MF, RDA);
   // Add prologue to the function...
   for (MachineBasicBlock *SaveBlock : SaveBlocks)
     TFI.emitPrologue(MF, *SaveBlock);
diff --git a/llvm/lib/CodeGen/ReachingDefAnalysis.cpp b/llvm/lib/CodeGen/ReachingDefAnalysis.cpp
index 0e8220ec6251cb7..2120d15465ff9a8 100644
--- a/llvm/lib/CodeGen/ReachingDefAnalysis.cpp
+++ b/llvm/lib/CodeGen/ReachingDefAnalysis.cpp
@@ -10,6 +10,8 @@
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/CodeGen/LiveRegUnits.h"
+#include "llvm/CodeGen/MachineFrameInfo.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
 #include "llvm/CodeGen/TargetRegisterInfo.h"
 #include "llvm/CodeGen/TargetSubtargetInfo.h"
 #include "llvm/Support/Debug.h"
@@ -48,12 +50,31 @@ static bool isValidRegDefOf(const MachineOperand &MO, MCRegister PhysReg,
   return TRI->regsOverlap(MO.getReg(), PhysReg);
 }
 
+/// Return true if \p MI writes stack slot \p FrameIndex, as reported by
+/// TargetInstrInfo::isStoreToStackSlot() or isStackSlotCopy().
+static bool isFIDef(const MachineInstr &MI, int FrameIndex, const TargetInstrInfo *TII) {
+  int StoredFI = 0;
+  int UnusedSrcFI = 0;
+  bool IsSlotWrite = TII->isStoreToStackSlot(MI, StoredFI) ||
+                     TII->isStackSlotCopy(MI, StoredFI, UnusedSrcFI);
+  if (!IsSlotWrite)
+    return false;
+  return StoredFI == FrameIndex;
+}
+
+
 void ReachingDefAnalysis::enterBasicBlock(MachineBasicBlock *MBB) {
   unsigned MBBNumber = MBB->getNumber();
   assert(MBBNumber < MBBReachingDefs.numBlockIDs() &&
          "Unexpected basic block number.");
   MBBReachingDefs.startBasicBlock(MBBNumber, NumRegUnits);
 
+  MBBFrameObjsReachingDefs[MBBNumber].resize(NumStackObjects);
+  for (unsigned FOIdx = 0; FOIdx < NumStackObjects; ++FOIdx) {
+    MBBFrameObjsReachingDefs[MBBNumber][FOIdx].push_back(-1);
+  }
+
+
   // Reset instruction counter in each basic block.
   CurInstr = 0;
 
@@ -126,6 +147,12 @@ void ReachingDefAnalysis::processDefs(MachineInstr *MI) {
          "Unexpected basic block number.");
 
   for (auto &MO : MI->operands()) {
+    if (MO.isFI()) {
+      int FrameIndex = MO.getIndex();
+      if (!isFIDef(*MI, FrameIndex, TII))
+        continue;
+      MBBFrameObjsReachingDefs[MBBNumber][FrameIndex - ObjectIndexBegin].push_back(CurInstr);
+    }
     if (!isValidRegDef(MO))
       continue;
     for (MCRegUnit Unit : TRI->regunits(MO.getReg().asMCReg())) {
@@ -211,7 +238,9 @@ void ReachingDefAnalysis::processBasicBlock(
 
 bool ReachingDefAnalysis::runOnMachineFunction(MachineFunction &mf) {
   MF = &mf;
-  TRI = MF->getSubtarget().getRegisterInfo();
+  const TargetSubtargetInfo &STI = MF->getSubtarget();
+  TRI = STI.getRegisterInfo();
+  TII = STI.getInstrInfo();
   LLVM_DEBUG(dbgs() << "********** REACHING DEFINITION ANALYSIS **********\n");
   init();
   traverse();
@@ -222,6 +251,7 @@ void ReachingDefAnalysis::releaseMemory() {
   // Clear the internal vectors.
   MBBOutRegsInfos.clear();
   MBBReachingDefs.clear();
+  MBBFrameObjsReachingDefs.clear();
   InstIds.clear();
   LiveRegs.clear();
 }
@@ -234,7 +264,10 @@ void ReachingDefAnalysis::reset() {
 
 void ReachingDefAnalysis::init() {
   NumRegUnits = TRI->getNumRegUnits();
+  NumStackObjects = MF->getFrameInfo().getNumObjects();
+  ObjectIndexBegin = MF->getFrameInfo().getObjectIndexBegin();
   MBBReachingDefs.init(MF->getNumBlockIDs());
+  MBBFrameObjsReachingDefs.resize(MF->getNumBlockIDs());
   // Initialize the MBBOutRegsInfos
   MBBOutRegsInfos.resize(MF->getNumBlockIDs());
   LoopTraversal Traversal;
@@ -269,6 +302,18 @@ int ReachingDefAnalysis::getReachingDef(MachineInstr *MI,
   assert(MBBNumber < MBBReachingDefs.numBlockIDs() &&
          "Unexpected basic block number.");
   int LatestDef = ReachingDefDefaultVal;
+
+  if (Register::isStackSlot(PhysReg)) {
+    int FrameIndex = Register::stackSlot2Index(PhysReg);
+    for (int Def : MBBFrameObjsReachingDefs[MBBNumber][FrameIndex - ObjectIndexBegin]) {
+      if (Def >= InstId)
+        break;
+      DefRes = Def;
+    }
+    LatestDef = std::max(LatestDef, DefRes);
+    return LatestDef;
+  }
+
   for (MCRegUnit Unit : TRI->regunits(PhysReg)) {
     for (int Def : MBBReachingDefs.defs(MBBNumber, Unit)) {
       if (Def >= InstId)
@@ -425,7 +470,7 @@ void ReachingDefAnalysis::getLiveOuts(MachineBasicBlock *MBB,
   VisitedBBs.insert(MBB);
   LiveRegUnits LiveRegs(*TRI);
   LiveRegs.addLiveOuts(*MBB);
-  if (LiveRegs.available(PhysReg))
+  if (Register::isPhysicalRegister(PhysReg) && LiveRegs.available(PhysReg))
     return;
 
   if (auto *Def = getLocalLiveOutMIDef(MBB, PhysReg))
@@ -508,7 +553,7 @@ bool ReachingDefAnalysis::isReachingDefLiveOut(MachineInstr *MI,
   MachineBasicBlock *MBB = MI->getParent();
   LiveRegUnits LiveRegs(*TRI);
   LiveRegs.addLiveOuts(*MBB);
-  if (LiveRegs.available(PhysReg))
+  if (Register::isPhysicalRegister(PhysReg) && LiveRegs.available(PhysReg))
     return false;
 
   auto Last = MBB->getLastNonDebugInstr();
@@ -529,7 +574,7 @@ ReachingDefAnalysis::getLocalLiveOutMIDef(MachineBasicBlock *MBB,
                                           MCRegister PhysReg) const {
   LiveRegUnits LiveRegs(*TRI);
   LiveRegs.addLiveOuts(*MBB);
-  if (LiveRegs.available(PhysReg))
+  if (Register::isPhysicalRegister(PhysReg) && LiveRegs.available(PhysReg))
     return nullptr;
 
   auto Last = MBB->getLastNonDebugInstr();
@@ -537,6 +582,13 @@ ReachingDefAnalysis::getLocalLiveOutMIDef(MachineBasicBlock *MBB,
     return nullptr;
 
   int Def = getReachingDef(&*Last, PhysReg);
+
+  if (Register::isStackSlot(PhysReg)) {
+    int FrameIndex = Register::stackSlot2Index(PhysReg);
+    if (isFIDef(*Last, FrameIndex, TII))
+      return &*Last;
+  }
+
   for (auto &MO : Last->operands())
     if (isValidRegDefOf(MO, PhysReg, TRI))
       return &*Last;
diff --git a/llvm/lib/CodeGen/RegAllocEvictionAdvisor.cpp b/llvm/lib/CodeGen/RegAllocEvictionAdvisor.cpp
index a1dccc4d59723bb..023c03c5a2a9228 100644
--- a/llvm/lib/CodeGen/RegAllocEvictionAdvisor.cpp
+++ b/llvm/lib/CodeGen/RegAllocEvictionAdvisor.cpp
@@ -44,6 +44,11 @@ static cl::opt<bool> EnableLocalReassignment(
              "may be compile time intensive"),
     cl::init(false));
 
+static cl::opt<float> MinWeightRatioNeededToEvictHint(
+    "min-weight-ratio-needed-to-evict-hint", cl::Hidden,
+    cl::desc("The minimum weight ratio needed for a live range with a bigger weight to evict another live range which satisfies a hint"),
+    cl::init(1.0f));
+
 namespace llvm {
 cl::opt<unsigned> EvictInterferenceCutoff(
     "regalloc-eviction-max-interference-cutoff", cl::Hidden,
@@ -156,8 +161,14 @@ bool DefaultEvictionAdvisor::shouldEvict(const LiveInterval &A, bool IsHint,
   if (CanSplit && IsHint && !BreaksHint)
     return true;
 
-  if (A.weight() > B.weight()) {
-    LLVM_DEBUG(dbgs() << "should evict: " << B << " w= " << B.weight() << '\n');
+  float AWeight = A.weight();
+  float BWeight = B.weight();
+  if (AWeight > BWeight) {
+    float WeightRatio = BWeight == 0.0 ? std::numeric_limits<float>::infinity() : AWeight / BWeight;
+    if (CanSplit && !IsHint && BreaksHint && (WeightRatio < MinWeightRatioNeededToEvictHint)) {
+      return false;
+    }
+    LLVM_DEBUG(dbgs() << "should evict: " << B << " w= " << BWeight << '\n');
     return true;
   }
   return false;
diff --git a/llvm/lib/CodeGen/TargetSubtargetInfo.cpp b/llvm/lib/CodeGen/TargetSubtargetInfo.cpp
index 6c97bc0568bdeee..566d5420c638efb 100644
--- a/llvm/lib/CodeGen/TargetSubtargetInfo.cpp
+++ b/llvm/lib/CodeGen/TargetSubtargetInfo.cpp
@@ -45,6 +45,10 @@ bool TargetSubtargetInfo::enableRALocalReassignment(
   return true;
 }
 
+bool TargetSubtargetInfo::doCSRSavesInRA() const {
+  return false; // Opt-in: targets that let RA save/restore CSRs override this.
+}
+
 bool TargetSubtargetInfo::enablePostRAScheduler() const {
   return getSchedModel().PostRAScheduler;
 }
diff --git a/llvm/lib/Target/RISCV/CMakeLists.txt b/llvm/lib/Target/RISCV/CMakeLists.txt
index fd049d1a57860ef..e8897ed14dcea15 100644
--- a/llvm/lib/Target/RISCV/CMakeLists.txt
+++ b/llvm/lib/Target/RISCV/CMakeLists.txt
@@ -29,6 +29,7 @@ add_public_tablegen_target(RISCVCommonTableGen)
 
 add_llvm_target(RISCVCodeGen
   RISCVAsmPrinter.cpp
+  RISCVCFIInserter.cpp
   RISCVCallingConv.cpp
   RISCVCodeGenPrepare.cpp
   RISCVConstantPoolValue.cpp
diff --git a/llvm/lib/Target/RISCV/RISCV.h b/llvm/lib/Target/RISCV/RISCV.h
index d7bab601d545ccb..65d6f7725726f8f 100644
--- a/llvm/lib/Target/RISCV/RISCV.h
+++ b/llvm/lib/Target/RISCV/RISCV.h
@@ -105,6 +105,9 @@ void initializeRISCVPreLegalizerCombinerPass(PassRegistry &);
 
 FunctionPass *createRISCVVLOptimizerPass();
 void initializeRISCVVLOptimizerPass(PassRegistry &);
+
+FunctionPass *createRISCVCFIInstrInserter();
+void initializeRISCVCFIInstrInserterPass(PassRegistry &);
 } // namespace llvm
 
 #endif
diff --git a/llvm/lib/Target/RISCV/RISCVCFIInserter.cpp b/llvm/lib/Target/RISCV/RISCVCFIInserter.cpp
new file mode 100644
index 000000000000000..9f6109524b314ec
--- /dev/null
+++ b/llvm/lib/Target/RISCV/RISCVCFIInserter.cpp
@@ -0,0 +1,569 @@
+//===-- RISCVCFIInserter.cpp - Insert additional CFI instructions ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file This pass computes incoming and outgoing CFA information of basic
+/// blocks. CFA information is the offset and register set by CFI directives,
+/// together with the locations of callee-saved registers, valid at the start
+/// and end of a basic block. Outgoing information of predecessors is
+/// propagated to their successors. Then, for blocks whose incoming
+/// information differs from what the preceding block in the layout leaves
+/// behind, additional CFI instructions are inserted at their beginnings. This
+/// is needed because of a non-linear layout of blocks in a function. Checking
+/// the computed information (verify()) is currently disabled.
+//===----------------------------------------------------------------------===//
+
+#include "RISCV.h"
+#include "RISCVMachineFunctionInfo.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/BinaryFormat/Dwarf.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/TargetFrameLowering.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
+#include "llvm/CodeGen/TargetSubtargetInfo.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/MC/MCDwarf.h"
+#include "llvm/Support/LEB128.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "riscv-cfi-inserter"
+
+//static cl::opt<bool> VerifyCFI("verify-cfiinstrs",
+//    cl::desc("Verify Call Frame Information instructions"),
+//    cl::init(false),
+//    cl::Hidden);
+
+namespace {
+class RISCVCFIInstrInserter : public MachineFunctionPass {
+ public:
+  static char ID;
+
+  RISCVCFIInstrInserter() : MachineFunctionPass(ID) {
+    initializeRISCVCFIInstrInserterPass(*PassRegistry::getPassRegistry());
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesAll();
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
+
+  bool runOnMachineFunction(MachineFunction &MF) override {
+    if (!MF.needsFrameMoves())
+      return false;
+
+    if (!MF.getSubtarget().doCSRSavesInRA())
+      return false;
+
+    RVFI = MF.getInfo<RISCVMachineFunctionInfo>();
+    MBBVector.resize(MF.getNumBlockIDs());
+    calculateCFAInfo(MF);
+
+    //if (VerifyCFI) {
+    //  if (unsigned ErrorNum = verify(MF))
+    //    report_fatal_error("Found " + Twine(ErrorNum) +
+    //                       " in/out CFI information errors.");
+    //}
+    bool insertedCFI = insertCFIInstrs(MF);
+    MBBVector.clear();
+    return insertedCFI;
+  }
+
+ private:
+#define INVALID_REG UINT_MAX
+#define INVALID_OFFSET INT_MAX
+  /// Where a callee-saved register's value lives: in register Reg (IsReg), or
+  /// at FrameReg + Offset. Registers are recorded by their Dwarf numbers.
+  struct CSRLocation {
+    bool IsReg = true;
+    int Reg = 0;
+    int FrameReg = 0;
+    int Offset = 0;
+    bool isEqual(const CSRLocation &Other) const {
+      if (IsReg)
+        return Other.IsReg ? (Reg == Other.Reg) : false;
+      return !Other.IsReg ? ((Offset == Other.Offset) && FrameReg == Other.FrameReg) : false;
+    }
+  };
+
+  struct MBBCFAInfo {
+    MachineBasicBlock *MBB = nullptr;
+    /// Value of cfa offset valid at basic block entry.
+    int IncomingCFAOffset = -1;
+    /// Value of cfa offset valid at basic block exit.
+    int OutgoingCFAOffset = -1;
+    /// Value of cfa register valid at basic block entry.
+    int IncomingCFARegister = 0;
+    /// Value of cfa register valid at basic block exit.
+    int OutgoingCFARegister = 0;
+    /// Location of each CSR at block entry, indexed by Dwarf register number.
+    SmallVector<CSRLocation> IncomingCSRLocations;
+    /// Location of each CSR at block exit, indexed by Dwarf register number.
+    SmallVector<CSRLocation> OutgoingCSRLocations;
+    /// If in/out cfa offset and register values for this block have already
+    /// been set or not.
+    bool Processed = false;
+  };
+
+  RISCVMachineFunctionInfo *RVFI = nullptr;
+  /// Contains cfa offset and register values valid at entry and exit of basic
+  /// blocks.
+  std::vector<MBBCFAInfo> MBBVector;
+
+  /// Calculate cfa offset and register values valid at entry and exit for all
+  /// basic blocks in a function.
+  void calculateCFAInfo(MachineFunction &MF);
+  /// Calculate cfa offset and register values valid at basic block exit by
+  /// checking the block for CFI instructions. Block's incoming CFA info remains
+  /// the same.
+  void calculateOutgoingCFAInfo(MBBCFAInfo &MBBInfo);
+  /// Update in/out cfa offset and register values for successors of the basic
+  /// block.
+  void updateSuccCFAInfo(MBBCFAInfo &MBBInfo);
+
+  /// Check if incoming CFA information of a basic block matches outgoing CFA
+  /// information of the previous block. If it doesn't, insert CFI instruction
+  /// at the beginning of the block that corrects the CFA calculation rule for
+  /// that block.
+  bool insertCFIInstrs(MachineFunction &MF);
+  /// Return the cfa offset value that should be set at the beginning of a MBB
+  /// if needed. The negated value is needed when creating CFI instructions that
+  /// set absolute offset.
+  int getCorrectCFAOffset(MachineBasicBlock *MBB) {
+    return MBBVector[MBB->getNumber()].IncomingCFAOffset;
+  }
+
+  void reportCFAError(const MBBCFAInfo &Pred, const MBBCFAInfo &Succ);
+  void reportCSRError(const MBBCFAInfo &Pred, const MBBCFAInfo &Succ);
+  /// Go through each MBB in a function and check that outgoing offset and
+  /// register of its predecessors match incoming offset and register of that
+  /// MBB, as well as that incoming offset and register of its successors match
+  /// outgoing offset and register of the MBB.
+  unsigned verify(MachineFunction &MF);
+};
+}  // namespace
+
+char RISCVCFIInstrInserter::ID = 0;
+INITIALIZE_PASS(RISCVCFIInstrInserter, DEBUG_TYPE, // "cfi-instr-inserter" is taken by the generic pass.
+                "RISC-V check CFA info and insert CFI instructions if needed",
+                false, false)
+FunctionPass *llvm::createRISCVCFIInstrInserter() { return new RISCVCFIInstrInserter(); }
+
+void RISCVCFIInstrInserter::calculateCFAInfo(MachineFunction &MF) {
+  const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo();
+  const TargetFrameLowering &TFL = *MF.getSubtarget().getFrameLowering();
+  // Initial CFA offset and register (as a Dwarf number), i.e. the values
+  // valid at the beginning of the function.
+  int InitialOffset = TFL.getInitialCFAOffset(MF);
+  int InitialRegister = TRI.getDwarfRegNum(TFL.getInitialCFARegister(MF), true);
+  // CSR location tables are indexed by Dwarf register number.
+  unsigned NumRegs = TRI.getNumSupportedRegs(MF);
+
+  // Initialize the per-block CFA info.
+  for (MachineBasicBlock &MBB : MF) {
+    MBBCFAInfo &MBBInfo = MBBVector[MBB.getNumber()];
+    MBBInfo.MBB = &MBB;
+    MBBInfo.IncomingCFAOffset = InitialOffset;
+    MBBInfo.OutgoingCFAOffset = InitialOffset;
+    MBBInfo.IncomingCFARegister = InitialRegister;
+    MBBInfo.OutgoingCFARegister = InitialRegister;
+    MBBInfo.IncomingCSRLocations.resize(NumRegs);
+    MBBInfo.OutgoingCSRLocations.resize(NumRegs);
+  }
+
+  // On function entry every callee-saved register still lives in itself.
+  MBBCFAInfo &EntryMBBInfo = MBBVector[MF.front().getNumber()];
+  const MCPhysReg *CSRegs = MF.getRegInfo().getCalleeSavedRegs();
+  for (int i = 0; CSRegs[i]; ++i) {
+    int Reg = TRI.getDwarfRegNum(CSRegs[i], true);
+    // getDwarfRegNum() returns -1 for registers without a Dwarf mapping.
+    if (Reg < 0 || static_cast<unsigned>(Reg) >= NumRegs)
+      continue;
+    CSRLocation &CSRLoc = EntryMBBInfo.IncomingCSRLocations[Reg];
+    CSRLoc.IsReg = true;
+    CSRLoc.Reg = Reg;
+  }
+  // Set in/out cfa info for all blocks in the function. This traversal is based
+  // on the assumption that the first block in the function is the entry block
+  // i.e. that it has initial cfa offset and register values as incoming CFA
+  // information.
+  updateSuccCFAInfo(EntryMBBInfo);
+
+  LLVM_DEBUG({
+    // Print the recorded location of every CSR in Locs.
+    auto DumpLocations = [&](const SmallVector<CSRLocation> &Locs) {
+      for (int i = 0; CSRegs[i]; ++i) {
+        int Reg = TRI.getDwarfRegNum(CSRegs[i], true);
+        if (Reg < 0 || static_cast<unsigned>(Reg) >= Locs.size())
+          continue;
+        const CSRLocation &CSRLoc = Locs[Reg];
+        dbgs() << "CSR: " << Reg << ", Location: {";
+        dbgs() << "IsReg: " << CSRLoc.IsReg << ", ";
+        dbgs() << "Reg: " << CSRLoc.Reg << ", ";
+        dbgs() << "FrameReg: " << CSRLoc.FrameReg << ", ";
+        dbgs() << "Offset: " << CSRLoc.Offset << "}\n";
+      }
+    };
+    dbgs() << "Calculated CFI info for " << MF.getName() << "\n";
+    for (MachineBasicBlock &MBB : MF) {
+      dbgs() << "BasicBlock: " << MBB.getNumber() << " " << MBB.getName() << "\n";
+      dbgs() << "IncomingCSRLocations:\n";
+      DumpLocations(MBBVector[MBB.getNumber()].IncomingCSRLocations);
+      dbgs() << "OutgoingCSRLocations:\n";
+      DumpLocations(MBBVector[MBB.getNumber()].OutgoingCSRLocations);
+    }
+  });
+}
+
+/// Replay the CFI instructions of MBBInfo.MBB to compute the frame state at
+/// the end of the block.
+///
+/// The outgoing CFA register/offset and CSR locations are seeded from the
+/// block's incoming values and then updated by every CFI_INSTRUCTION of the
+/// block, in program order. Finally the block is marked as processed.
+void RISCVCFIInstrInserter::calculateOutgoingCFAInfo(MBBCFAInfo &MBBInfo) {
+  MachineFunction *MF = MBBInfo.MBB->getParent();
+  const std::vector<MCCFIInstruction> &Instrs = MF->getFrameInstructions();
+
+  int &OutgoingCFAOffset = MBBInfo.OutgoingCFAOffset;
+  int &OutgoingCFARegister = MBBInfo.OutgoingCFARegister;
+  SmallVector<CSRLocation> &OutgoingCSRLocations = MBBInfo.OutgoingCSRLocations;
+
+  // Seed the whole outgoing state from the incoming state; the loop below only
+  // applies what this block changes. Without this, a block that contains no
+  // CFA-defining CFI would report whatever CFA values these fields held before
+  // this call instead of the ones it inherited from its predecessor.
+  OutgoingCFAOffset = MBBInfo.IncomingCFAOffset;
+  OutgoingCFARegister = MBBInfo.IncomingCFARegister;
+  OutgoingCSRLocations = MBBInfo.IncomingCSRLocations;
+
+  for (MachineInstr &MI : *MBBInfo.MBB) {
+    if (!MI.isCFIInstruction())
+      continue;
+    unsigned CFIIndex = MI.getOperand(0).getCFIIndex();
+    const MCCFIInstruction &CFI = Instrs[CFIIndex];
+    switch (CFI.getOperation()) {
+    case MCCFIInstruction::OpDefCfaRegister: {
+      int Reg = CFI.getRegister();
+      assert(Reg >= 0 && "Negative dwarf register number!");
+      OutgoingCFARegister = Reg;
+      break;
+    }
+    case MCCFIInstruction::OpDefCfaOffset:
+      OutgoingCFAOffset = CFI.getOffset();
+      break;
+    case MCCFIInstruction::OpAdjustCfaOffset:
+      OutgoingCFAOffset += CFI.getOffset();
+      break;
+    case MCCFIInstruction::OpDefCfa: {
+      int Reg = CFI.getRegister();
+      assert(Reg >= 0 && "Negative dwarf register number!");
+      OutgoingCFARegister = Reg;
+      OutgoingCFAOffset = CFI.getOffset();
+      break;
+    }
+    case MCCFIInstruction::OpOffset: {
+      // .cfi_offset saves Reg at CFA + getOffset(). CSRLocation describes a
+      // stack location as FrameReg + Offset (see how insertCFIInstrs encodes
+      // it), so rebase the CFA-relative offset onto the CFA rule in effect at
+      // this point: CFA = OutgoingCFARegister + OutgoingCFAOffset.
+      int Reg = CFI.getRegister();
+      assert(Reg >= 0 && "Negative dwarf register number!");
+      OutgoingCSRLocations[Reg].Offset = OutgoingCFAOffset + CFI.getOffset();
+      OutgoingCSRLocations[Reg].FrameReg = OutgoingCFARegister;
+      OutgoingCSRLocations[Reg].IsReg = false;
+      break;
+    }
+    case MCCFIInstruction::OpEscape: {
+      // Escapes describing a CSR stack slot have their decoded
+      // (Reg, FrameReg, Offset) recorded in RISCVMachineFunctionInfo; any
+      // other escape is ignored here.
+      int Reg;
+      int FrameReg;
+      int64_t Offset;
+      if (!RVFI->getCFIInfo(&MI, Reg, FrameReg, Offset))
+        break;
+      assert(Reg >= 0 && "Negative dwarf register number!");
+      assert(FrameReg >= 0 && "Negative dwarf register number!");
+      OutgoingCSRLocations[Reg].IsReg = false;
+      OutgoingCSRLocations[Reg].Offset = Offset;
+      OutgoingCSRLocations[Reg].FrameReg = FrameReg;
+      break;
+    }
+    case MCCFIInstruction::OpRegister: {
+      // .cfi_register Reg, Reg2: the value of Reg now lives in Reg2.
+      int Reg = CFI.getRegister();
+      assert(Reg >= 0 && "Negative dwarf register number!");
+      int Reg2 = CFI.getRegister2();
+      assert(Reg2 >= 0 && "Negative dwarf register number!");
+      OutgoingCSRLocations[Reg].Reg = Reg2;
+      OutgoingCSRLocations[Reg].IsReg = true;
+      break;
+    }
+    case MCCFIInstruction::OpRelOffset:
+      report_fatal_error(
+          "Support for .cfi_rel_offset not implemented! Value of CFA "
+          "may be incorrect!\n");
+      break;
+    case MCCFIInstruction::OpRestore:
+      report_fatal_error(
+          "Support for .cfi_restore not implemented! Value of CFA "
+          "may be incorrect!\n");
+      break;
+    case MCCFIInstruction::OpLLVMDefAspaceCfa:
+      // TODO: Add support for handling cfi_def_aspace_cfa.
+#ifndef NDEBUG
+      report_fatal_error(
+          "Support for cfi_llvm_def_aspace_cfa not implemented! Value of CFA "
+          "may be incorrect!\n");
+#endif
+      break;
+    case MCCFIInstruction::OpRememberState:
+      // TODO: Add support for handling cfi_remember_state.
+#ifndef NDEBUG
+      report_fatal_error(
+          "Support for cfi_remember_state not implemented! Value of CFA "
+          "may be incorrect!\n");
+#endif
+      break;
+    case MCCFIInstruction::OpRestoreState:
+      // TODO: Add support for handling cfi_restore_state.
+#ifndef NDEBUG
+      report_fatal_error(
+          "Support for cfi_restore_state not implemented! Value of CFA may "
+          "be incorrect!\n");
+#endif
+      break;
+    case MCCFIInstruction::OpUndefined:
+    case MCCFIInstruction::OpSameValue:
+    case MCCFIInstruction::OpWindowSave:
+    case MCCFIInstruction::OpNegateRAState:
+    case MCCFIInstruction::OpGnuArgsSize:
+      break;
+    }
+  }
+
+  MBBInfo.Processed = true;
+}
+
+/// Propagate frame state from MBBInfo.MBB to every block reachable from it.
+///
+/// Uses a LIFO worklist: each popped block has its outgoing state recomputed,
+/// and every successor that has not been processed yet inherits that outgoing
+/// state as its incoming state and is queued.
+void RISCVCFIInstrInserter::updateSuccCFAInfo(MBBCFAInfo &MBBInfo) {
+  SmallVector<MachineBasicBlock *, 4> Worklist;
+  Worklist.push_back(MBBInfo.MBB);
+
+  while (!Worklist.empty()) {
+    MachineBasicBlock *Block = Worklist.pop_back_val();
+    MBBCFAInfo &Info = MBBVector[Block->getNumber()];
+    calculateOutgoingCFAInfo(Info);
+
+    for (MachineBasicBlock *Succ : Info.MBB->successors()) {
+      MBBCFAInfo &SuccInfo = MBBVector[Succ->getNumber()];
+      if (SuccInfo.Processed)
+        continue;
+      SuccInfo.IncomingCFAOffset = Info.OutgoingCFAOffset;
+      SuccInfo.IncomingCFARegister = Info.OutgoingCFARegister;
+      SuccInfo.IncomingCSRLocations = Info.OutgoingCSRLocations;
+      Worklist.push_back(Succ);
+    }
+  }
+}
+
+/// Append to MF a DW_CFA_expression escape stating that dwarf register
+/// DwarfReg is saved in memory at address DwarfFrameReg + Offset, and return
+/// the index of the new frame instruction.
+static unsigned addCSRSavedAtRegPlusOffsetCFI(MachineFunction &MF,
+                                              unsigned DwarfReg,
+                                              int DwarfFrameReg,
+                                              int64_t Offset) {
+  std::string CommentBuffer;
+  llvm::raw_string_ostream Comment(CommentBuffer);
+  Comment << Offset;
+
+  uint8_t Buffer[16];
+  // Address expression: DW_OP_consts Offset; DW_OP_bregx DwarfFrameReg, 0;
+  // DW_OP_plus.
+  SmallString<64> Expr;
+  Expr.push_back(dwarf::DW_OP_consts);
+  Expr.append(Buffer, Buffer + encodeSLEB128(Offset, Buffer));
+  Expr.push_back((uint8_t)dwarf::DW_OP_bregx);
+  Expr.append(Buffer, Buffer + encodeULEB128(DwarfFrameReg, Buffer));
+  Expr.push_back(0);
+  Expr.push_back((uint8_t)dwarf::DW_OP_plus);
+
+  // Wrap it into DW_CFA_expression: register, expression length, expression.
+  SmallString<64> CFAExpr;
+  CFAExpr.push_back(dwarf::DW_CFA_expression);
+  CFAExpr.append(Buffer, Buffer + encodeULEB128(DwarfReg, Buffer));
+  CFAExpr.append(Buffer, Buffer + encodeULEB128(Expr.size(), Buffer));
+  CFAExpr.append(Expr.str());
+  return MF.addFrameInst(MCCFIInstruction::createEscape(
+      nullptr, CFAExpr.str(), SMLoc(), Comment.str()));
+}
+
+/// Insert CFI instructions at the start of every block except the entry block
+/// so that the frame state the block expects on entry (CFA rule and
+/// callee-saved register locations) matches the state left by the block laid
+/// out before it.
+///
+/// Returns true if at least one CFI_INSTRUCTION was inserted.
+bool RISCVCFIInstrInserter::insertCFIInstrs(MachineFunction &MF) {
+  const MBBCFAInfo *PrevMBBInfo = &MBBVector[MF.front().getNumber()];
+  const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+  bool InsertedCFIInstr = false;
+
+  for (MachineBasicBlock &MBB : MF) {
+    // Skip the first MBB in a function
+    if (MBB.getNumber() == MF.front().getNumber())
+      continue;
+
+    const MBBCFAInfo &MBBInfo = MBBVector[MBB.getNumber()];
+    auto MBBI = MBBInfo.MBB->begin();
+    DebugLoc DL = MBBInfo.MBB->findDebugLoc(MBBI);
+
+    // If the current MBB will be placed in a unique section, a full DefCfa
+    // must be emitted.
+    const bool ForceFullCFA = MBB.isBeginSection();
+
+    if ((PrevMBBInfo->OutgoingCFAOffset != MBBInfo.IncomingCFAOffset &&
+         PrevMBBInfo->OutgoingCFARegister != MBBInfo.IncomingCFARegister) ||
+        ForceFullCFA) {
+      // If both outgoing offset and register of a previous block don't match
+      // incoming offset and register of this block, or if this block begins a
+      // section, add a def_cfa instruction with the correct offset and
+      // register for this block.
+      unsigned CFIIndex = MF.addFrameInst(MCCFIInstruction::cfiDefCfa(
+          nullptr, MBBInfo.IncomingCFARegister, getCorrectCFAOffset(&MBB)));
+      BuildMI(*MBBInfo.MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION))
+          .addCFIIndex(CFIIndex);
+      InsertedCFIInstr = true;
+    } else if (PrevMBBInfo->OutgoingCFAOffset != MBBInfo.IncomingCFAOffset) {
+      // If outgoing offset of a previous block doesn't match incoming offset
+      // of this block, add a def_cfa_offset instruction with the correct
+      // offset for this block.
+      unsigned CFIIndex = MF.addFrameInst(MCCFIInstruction::cfiDefCfaOffset(
+          nullptr, getCorrectCFAOffset(&MBB)));
+      BuildMI(*MBBInfo.MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION))
+          .addCFIIndex(CFIIndex);
+      InsertedCFIInstr = true;
+    } else if (PrevMBBInfo->OutgoingCFARegister !=
+               MBBInfo.IncomingCFARegister) {
+      unsigned CFIIndex =
+          MF.addFrameInst(MCCFIInstruction::createDefCfaRegister(
+              nullptr, MBBInfo.IncomingCFARegister));
+      BuildMI(*MBBInfo.MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION))
+          .addCFIIndex(CFIIndex);
+      InsertedCFIInstr = true;
+    }
+
+    if (ForceFullCFA) {
+      // NOTE(review): confirm RISCVFrameLowering implements this hook; if it
+      // relies on a no-op default, CSR locations are not re-described at the
+      // start of the new section.
+      MF.getSubtarget().getFrameLowering()->emitCalleeSavedFrameMovesFullCFA(
+          *MBBInfo.MBB, MBBI);
+      InsertedCFIInstr = true;
+      PrevMBBInfo = &MBBInfo;
+      continue;
+    }
+
+    // Re-describe every callee-saved register (indexed by dwarf number) whose
+    // location on entry to this block differs from the location left by the
+    // previous block.
+    for (unsigned i = 0; i < PrevMBBInfo->OutgoingCSRLocations.size(); ++i) {
+      const CSRLocation &OutgoingCSRLoc = PrevMBBInfo->OutgoingCSRLocations[i];
+      const CSRLocation &IncomingCSRLoc = MBBInfo.IncomingCSRLocations[i];
+      // A register location of 0 is never described.
+      // NOTE(review): presumably this is the "no location known" state.
+      if (IncomingCSRLoc.IsReg && IncomingCSRLoc.Reg == 0)
+        continue;
+      if (IncomingCSRLoc.isEqual(OutgoingCSRLoc))
+        continue;
+      unsigned CFIIndex =
+          IncomingCSRLoc.IsReg
+              ? MF.addFrameInst(MCCFIInstruction::createRegister(
+                    nullptr, i, IncomingCSRLoc.Reg))
+              : addCSRSavedAtRegPlusOffsetCFI(MF, i, IncomingCSRLoc.FrameReg,
+                                              IncomingCSRLoc.Offset);
+      BuildMI(*MBBInfo.MBB, MBBI, DL, TII->get(TargetOpcode::CFI_INSTRUCTION))
+          .addCFIIndex(CFIIndex);
+      InsertedCFIInstr = true;
+    }
+
+    PrevMBBInfo = &MBBInfo;
+  }
+  return InsertedCFIInstr;
+}
+
+//void RISCVCFIInstrInserter::reportCFAError(const MBBCFAInfo &Pred,
+//                                      const MBBCFAInfo &Succ) {
+//  errs() << "*** Inconsistent CFA register and/or offset between pred and succ "
+//            "***\n";
+//  errs() << "Pred: " << Pred.MBB->getName() << " #" << Pred.MBB->getNumber()
+//         << " in " << Pred.MBB->getParent()->getName()
+//         << " outgoing CFA Reg:" << Pred.OutgoingCFARegister << "\n";
+//  errs() << "Pred: " << Pred.MBB->getName() << " #" << Pred.MBB->getNumber()
+//         << " in " << Pred.MBB->getParent()->getName()
+//         << " outgoing CFA Offset:" << Pred.OutgoingCFAOffset << "\n";
+//  errs() << "Succ: " << Succ.MBB->getName() << " #" << Succ.MBB->getNumber()
+//         << " incoming CFA Reg:" << Succ.IncomingCFARegister << "\n";
+//  errs() << "Succ: " << Succ.MBB->getName() << " #" << Succ.MBB->getNumber()
+//         << " incoming CFA Offset:" << Succ.IncomingCFAOffset << "\n";
+//}
+//
+//void RISCVCFIInstrInserter::reportCSRError(const MBBCFAInfo &Pred,
+//                                      const MBBCFAInfo &Succ) {
+//  errs() << "*** Inconsistent CSR Saved between pred and succ in function "
+//         << Pred.MBB->getParent()->getName() << " ***\n";
+//  errs() << "Pred: " << Pred.MBB->getName() << " #" << Pred.MBB->getNumber()
+//         << " outgoing CSR Saved: ";
+//  for (int Reg : Pred.OutgoingCSRSaved.set_bits())
+//    errs() << Reg << " ";
+//  errs() << "\n";
+//  errs() << "Succ: " << Succ.MBB->getName() << " #" << Succ.MBB->getNumber()
+//         << " incoming CSR Saved: ";
+//  for (int Reg : Succ.IncomingCSRSaved.set_bits())
+//    errs() << Reg << " ";
+//  errs() << "\n";
+//}
+
+//unsigned RISCVCFIInstrInserter::verify(MachineFunction &MF) {
+//  unsigned ErrorNum = 0;
+//  for (auto *CurrMBB : depth_first(&MF)) {
+//    const MBBCFAInfo &CurrMBBInfo = MBBVector[CurrMBB->getNumber()];
+//    for (MachineBasicBlock *Succ : CurrMBB->successors()) {
+//      const MBBCFAInfo &SuccMBBInfo = MBBVector[Succ->getNumber()];
+//      // Check that incoming offset and register values of successors match the
+//      // outgoing offset and register values of CurrMBB
+//      if (SuccMBBInfo.IncomingCFAOffset != CurrMBBInfo.OutgoingCFAOffset ||
+//          SuccMBBInfo.IncomingCFARegister != CurrMBBInfo.OutgoingCFARegister) {
+//        // Inconsistent offsets/registers are ok for 'noreturn' blocks because
+//        // we don't generate epilogues inside such blocks.
+//        if (SuccMBBInfo.MBB->succ_empty() && !SuccMBBInfo.MBB->isReturnBlock())
+//          continue;
+//        reportCFAError(CurrMBBInfo, SuccMBBInfo);
+//        ErrorNum++;
+//      }
+//      // Check that IncomingCSRSaved of every successor matches the
+//      // OutgoingCSRSaved of CurrMBB
+//      if (SuccMBBInfo.IncomingCSRSaved != CurrMBBInfo.OutgoingCSRSaved) {
+//        reportCSRError(CurrMBBInfo, SuccMBBInfo);
+//        ErrorNum++;
+//      }
+//    }
+//  }
+//  return ErrorNum;
+//}
diff --git a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
index f5851f371545191..16d11d7d320b557 100644
--- a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
@@ -18,12 +18,14 @@
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/ReachingDefAnalysis.h"
 #include "llvm/CodeGen/RegisterScavenging.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/MC/MCDwarf.h"
 #include "llvm/Support/LEB128.h"
 
 #include <algorithm>
+#include <unordered_set>
 
 using namespace llvm;
 
@@ -525,6 +527,181 @@ static MCCFIInstruction createDefCFAOffset(const TargetRegisterInfo &TRI,
                                         Comment.str());
 }
 
+/// One CFI_INSTRUCTION queued by trackRegisterAndEmitCFIs and materialized
+/// later by emitCFIsForCSRsHandledByRA.
+struct CFIBuildInfo {
+  /// Block that receives the CFI instruction.
+  MachineBasicBlock *MBB = nullptr;
+  /// The CFI is inserted right after this instruction; nullptr means insert at
+  /// MBB.begin().
+  MachineInstr *InsertAfterMI = nullptr;
+  /// Debug location given to the CFI instruction.
+  DebugLoc DL;
+  /// Index returned by MachineFunction::addFrameInst for the CFI.
+  unsigned CFIIndex = 0;
+  /// When true, (DwarfEHRegNum, DwarfFrameReg, FixedOffset) is recorded in
+  /// RISCVMachineFunctionInfo for the created instruction.
+  bool ShouldRecord = false;
+  /// Dwarf number of the callee-saved register being described.
+  int DwarfEHRegNum = 0;
+  /// Dwarf number of the base register of the save slot.
+  int DwarfFrameReg = 0;
+  /// Fixed offset of the save slot from DwarfFrameReg.
+  int64_t FixedOffset = 0;
+};
+
+/// Append to MF a DW_CFA_expression escape stating that dwarf register
+/// DwarfEHRegNum is saved in memory at address DwarfFrameReg + FixedOffset,
+/// and return the index of the new frame instruction.
+static unsigned addCSRSavedAtRegPlusOffsetCFI(MachineFunction &MF,
+                                              int DwarfEHRegNum,
+                                              int DwarfFrameReg,
+                                              int64_t FixedOffset) {
+  std::string CommentBuffer;
+  llvm::raw_string_ostream Comment(CommentBuffer);
+  Comment << FixedOffset;
+
+  uint8_t Buffer[16];
+  // Address expression: DW_OP_consts FixedOffset; DW_OP_bregx DwarfFrameReg,
+  // 0; DW_OP_plus.
+  SmallString<64> Expr;
+  Expr.push_back(dwarf::DW_OP_consts);
+  Expr.append(Buffer, Buffer + encodeSLEB128(FixedOffset, Buffer));
+  Expr.push_back((uint8_t)dwarf::DW_OP_bregx);
+  Expr.append(Buffer, Buffer + encodeULEB128(DwarfFrameReg, Buffer));
+  Expr.push_back(0);
+  Expr.push_back((uint8_t)dwarf::DW_OP_plus);
+
+  // Wrap it into DW_CFA_expression: register, expression length, expression.
+  SmallString<64> CFAExpr;
+  CFAExpr.push_back(dwarf::DW_CFA_expression);
+  CFAExpr.append(Buffer, Buffer + encodeULEB128(DwarfEHRegNum, Buffer));
+  CFAExpr.append(Buffer, Buffer + encodeULEB128(Expr.size(), Buffer));
+  CFAExpr.append(Expr.str());
+  return MF.addFrameInst(MCCFIInstruction::createEscape(
+      nullptr, CFAExpr.str(), SMLoc(), Comment.str()));
+}
+
+/// Follow, backwards from MI, the spills, reloads and copies that carry the
+/// value of Reg to MI. For each such step, queue in CFIBuildInfos a CFI
+/// (placed right after the step) describing where the value of the
+/// callee-saved register DwarfEHRegNum lives from then on.
+///
+/// Reg is either a physical register or a stack slot encoded with
+/// Register::index2StackSlot. Every reaching definition is expected to be a
+/// spill, a reload or a copy. VisitedRestorePoints and VisitedDefs keep an
+/// instruction from being processed twice (e.g. around loops).
+static void trackRegisterAndEmitCFIs(
+    MachineFunction &MF, MachineInstr &MI, MCRegister Reg, int DwarfEHRegNum,
+    const ReachingDefAnalysis &RDA, const TargetInstrInfo &TII,
+    const MachineFrameInfo &MFI, const RISCVRegisterInfo &TRI,
+    std::vector<CFIBuildInfo> &CFIBuildInfos,
+    std::unordered_set<MachineInstr *> &VisitedRestorePoints,
+    std::unordered_set<MachineInstr *> &VisitedDefs) {
+  if (!VisitedRestorePoints.insert(&MI).second)
+    return;
+
+  SmallPtrSet<MachineInstr *, 2> Defs;
+  RDA.getGlobalReachingDefs(&MI, Reg, Defs);
+  // No reaching definition: it's a live-in register at the entry block.
+  if (Defs.empty())
+    return;
+
+  const TargetFrameLowering &TFL = *MF.getSubtarget().getFrameLowering();
+  for (MachineInstr *Def : Defs) {
+    if (!VisitedDefs.insert(Def).second)
+      continue;
+
+    MachineBasicBlock &MBB = *Def->getParent();
+    const DebugLoc &DL = Def->getDebugLoc();
+    int FrameIndex = std::numeric_limits<int>::min();
+
+    if (Register StoredReg = TII.isStoreToStackSlot(*Def, FrameIndex)) {
+      // Spill of the tracked value: the CSR is now saved in the stack slot.
+      assert(FrameIndex == Register::stackSlot2Index(Reg));
+      Register FrameReg;
+      StackOffset Offset =
+          TFL.getFrameIndexReference(MF, FrameIndex, FrameReg);
+      // TODO: support scalable offsets.
+      assert(Offset.getScalable() == 0);
+      int64_t FixedOffset = Offset.getFixed();
+      int DwarfFrameReg = TRI.getDwarfRegNum(FrameReg, true);
+      unsigned CFIIndex = addCSRSavedAtRegPlusOffsetCFI(
+          MF, DwarfEHRegNum, DwarfFrameReg, FixedOffset);
+      CFIBuildInfos.push_back({&MBB, Def, DL, CFIIndex, /*ShouldRecord=*/true,
+                               DwarfEHRegNum, DwarfFrameReg, FixedOffset});
+      // Keep following the register that was stored.
+      trackRegisterAndEmitCFIs(MF, *Def, StoredReg, DwarfEHRegNum, RDA, TII,
+                               MFI, TRI, CFIBuildInfos, VisitedRestorePoints,
+                               VisitedDefs);
+    } else if (Register LoadedReg = TII.isLoadFromStackSlot(*Def, FrameIndex)) {
+      // Reload of the tracked value: the CSR now lives in LoadedReg.
+      assert(LoadedReg == Reg);
+      unsigned CFIIndex = MF.addFrameInst(MCCFIInstruction::createRegister(
+          nullptr, DwarfEHRegNum, TRI.getDwarfRegNum(LoadedReg, true)));
+      CFIBuildInfos.push_back({&MBB, Def, DL, CFIIndex});
+      // Keep following the stack slot that was loaded.
+      trackRegisterAndEmitCFIs(MF, *Def, Register::index2StackSlot(FrameIndex),
+                               DwarfEHRegNum, RDA, TII, MFI, TRI,
+                               CFIBuildInfos, VisitedRestorePoints,
+                               VisitedDefs);
+    } else if (auto DstSrc = TII.isCopyInstr(*Def)) {
+      // Copy of the tracked value: the CSR now lives in the destination.
+      Register DstReg = DstSrc->Destination->getReg();
+      Register SrcReg = DstSrc->Source->getReg();
+      assert(DstReg == Reg);
+      unsigned CFIIndex = MF.addFrameInst(MCCFIInstruction::createRegister(
+          nullptr, DwarfEHRegNum, TRI.getDwarfRegNum(DstReg, true)));
+      CFIBuildInfos.push_back({&MBB, Def, DL, CFIIndex});
+      // Keep following the copy source.
+      trackRegisterAndEmitCFIs(MF, *Def, SrcReg, DwarfEHRegNum, RDA, TII, MFI,
+                               TRI, CFIBuildInfos, VisitedRestorePoints,
+                               VisitedDefs);
+    } else {
+      llvm_unreachable("Unexpected instruction");
+    }
+  }
+}
+
+/// Offset of the CFA from the initial CFA register (see getInitialCFARegister)
+/// at function entry: the CFA rule there is simply "register + 0".
+int RISCVFrameLowering::getInitialCFAOffset(
+    const MachineFunction & /*MF*/) const {
+  constexpr int EntryCFAOffset = 0;
+  return EntryCFAOffset;
+}
+
+/// Register the CFA is expressed relative to at function entry: the stack
+/// pointer (x2).
+Register RISCVFrameLowering::getInitialCFARegister(
+    const MachineFunction & /*MF*/) const {
+  const Register StackPointer = RISCV::X2;
+  return StackPointer;
+}
+
+/// When callee-saved registers are handled by the register allocator
+/// (doCSRSavesInRA), emit CFI instructions describing where the entry value of
+/// each such register lives.
+///
+/// The handled ("eligible") registers are the callee-saved registers that
+/// determineMustCalleeSaves does not mark. For each of them, the chain of
+/// spills/reloads/copies reaching every return instruction is followed with
+/// ReachingDefAnalysis (see trackRegisterAndEmitCFIs), and one CFI is queued
+/// per step. The queued CFIs are inserted only after all chains have been
+/// walked.
+void RISCVFrameLowering::emitCFIsForCSRsHandledByRA(
+    MachineFunction &MF, ReachingDefAnalysis *RDA) const {
+  if (!STI.doCSRSavesInRA())
+    return;
+  assert(RDA && "ReachingDefAnalysis is required to track CSR locations");
+
+  const RISCVInstrInfo &TII = *STI.getInstrInfo();
+  const RISCVRegisterInfo &TRI = *STI.getRegisterInfo();
+  const MachineFrameInfo &MFI = MF.getFrameInfo();
+  RISCVMachineFunctionInfo &RVFI = *MF.getInfo<RISCVMachineFunctionInfo>();
+
+  BitVector MustCalleeSavedRegs;
+  determineMustCalleeSaves(MF, MustCalleeSavedRegs);
+  const MCPhysReg *CSRegs = MF.getRegInfo().getCalleeSavedRegs();
+  if (!CSRegs)
+    return;
+  SmallVector<MCPhysReg, 4> EligibleRegs;
+  for (const MCPhysReg *R = CSRegs; *R; ++R)
+    if (!MustCalleeSavedRegs.test(*R))
+      EligibleRegs.push_back(*R);
+
+  // The last instruction of every return block is where the entry values of
+  // the eligible registers must be back in place.
+  SmallVector<MachineInstr *, 4> RestorePoints;
+  for (MachineBasicBlock &MBB : MF)
+    if (MBB.isReturnBlock())
+      RestorePoints.push_back(&MBB.back());
+
+  std::vector<CFIBuildInfo> CFIBuildInfos;
+  for (MCPhysReg Reg : EligibleRegs) {
+    int DwarfEHRegNum = TRI.getDwarfRegNum(Reg, true);
+    // Shared by all restore points of Reg, so each definition yields at most
+    // one CFI for this register.
+    std::unordered_set<MachineInstr *> VisitedDefs;
+    for (MachineInstr *RestorePoint : RestorePoints) {
+      std::unordered_set<MachineInstr *> VisitedRestorePoints;
+      trackRegisterAndEmitCFIs(MF, *RestorePoint, Reg, DwarfEHRegNum, *RDA,
+                               TII, MFI, TRI, CFIBuildInfos,
+                               VisitedRestorePoints, VisitedDefs);
+    }
+  }
+
+  // Materialize the queued CFIs; presumably deferred so that the reaching-def
+  // queries above never see the newly inserted instructions.
+  for (const CFIBuildInfo &Info : CFIBuildInfos) {
+    MachineBasicBlock::iterator InsertPos = Info.MBB->begin();
+    if (Info.InsertAfterMI)
+      InsertPos = std::next(Info.InsertAfterMI->getIterator());
+    MachineInstr *CFIInstr =
+        BuildMI(*Info.MBB, InsertPos, Info.DL,
+                TII.get(TargetOpcode::CFI_INSTRUCTION))
+            .addCFIIndex(Info.CFIIndex)
+            .setMIFlag(MachineInstr::FrameSetup);
+    if (Info.ShouldRecord)
+      RVFI.recordCFIInfo(CFIInstr, Info.DwarfEHRegNum, Info.DwarfFrameReg,
+                         Info.FixedOffset);
+  }
+}
+
+
 void RISCVFrameLowering::emitPrologue(MachineFunction &MF,
                                       MachineBasicBlock &MBB) const {
   MachineFrameInfo &MFI = MF.getFrameInfo();
@@ -1057,17 +1234,55 @@ RISCVFrameLowering::getFrameIndexReference(const MachineFunction &MF, int FI,
   return Offset;
 }
 
-void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
-                                              BitVector &SavedRegs,
-                                              RegScavenger *RS) const {
-  TargetFrameLowering::determineCalleeSaves(MF, SavedRegs, RS);
-  // Unconditionally spill RA and FP only if the function uses a frame
-  // pointer.
+void RISCVFrameLowering::determineMustCalleeSaves(MachineFunction &MF,
+                                              BitVector &SavedRegs) const {
+  const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo();
+
+  // Resize before the early returns. Some backends expect that
+  // SavedRegs.size() == TRI.getNumRegs() after this call even if there are no
+  // saved registers.
+  SavedRegs.resize(TRI.getNumRegs());
+
+  // When interprocedural register allocation is enabled caller saved registers
+  // are preferred over callee saved registers.
+  if (MF.getTarget().Options.EnableIPRA &&
+      isSafeForNoCSROpt(MF.getFunction()) &&
+      isProfitableForNoCSROpt(MF.getFunction()))
+    return;
+
+  // Get the callee saved register list...
+  const MCPhysReg *CSRegs = MF.getRegInfo().getCalleeSavedRegs();
+
+  // Early exit if there are no callee saved registers.
+  if (!CSRegs || CSRegs[0] == 0)
+    return;
+
+  // In Naked functions we aren't going to save any registers.
+  if (MF.getFunction().hasFnAttribute(Attribute::Naked))
+    return;
+
+  // Noreturn+nounwind functions never restore CSR, so no saves are needed.
+  // Purely noreturn functions may still return through throws, so those must
+  // save CSR for caller exception handlers.
+  //
+  // If the function uses longjmp to break out of its current path of
+  // execution we do not need the CSR spills either: setjmp stores all CSRs
+  // it was called with into the jmp_buf, which longjmp then restores.
+  if (MF.getFunction().hasFnAttribute(Attribute::NoReturn) &&
+        MF.getFunction().hasFnAttribute(Attribute::NoUnwind) &&
+        !MF.getFunction().hasFnAttribute(Attribute::UWTable) &&
+        enableCalleeSaveSkip(MF))
+    return;
+
+  // Functions which call __builtin_unwind_init get all their registers saved.
+  if (MF.callsUnwindInit()) {
+    SavedRegs.set();
+    return;
+  }
   if (hasFP(MF)) {
-    SavedRegs.set(RAReg);
-    SavedRegs.set(FPReg);
+    SavedRegs.set(RISCV::X1);
+    SavedRegs.set(RISCV::X8);
   }
-  // Mark BP as used if function has dedicated base pointer.
   if (hasBP(MF))
     SavedRegs.set(RISCVABI::getBPReg());
 
@@ -1077,6 +1292,17 @@ void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
     SavedRegs.set(RISCV::X27);
 }
 
+// Registers that determineMustCalleeSaves marks (e.g. RA and FP when a frame
+// pointer is used) are always saved by the prologue. The other callee-saved
+// registers go through the generic logic only when the register allocator is
+// not responsible for preserving them.
+void RISCVFrameLowering::determineCalleeSaves(MachineFunction &MF,
+                                              BitVector &SavedRegs,
+                                              RegScavenger *RS) const {
+  determineMustCalleeSaves(MF, SavedRegs);
+  if (!MF.getSubtarget<RISCVSubtarget>().doCSRSavesInRA())
+    TargetFrameLowering::determineCalleeSaves(MF, SavedRegs, RS);
+}
+
 std::pair<int64_t, Align>
 RISCVFrameLowering::assignRVVStackObjectOffsets(MachineFunction &MF) const {
   MachineFrameInfo &MFI = MF.getFrameInfo();
diff --git a/llvm/lib/Target/RISCV/RISCVFrameLowering.h b/llvm/lib/Target/RISCV/RISCVFrameLowering.h
index f45fcdb0acd6bc8..e97c6ca7335de37 100644
--- a/llvm/lib/Target/RISCV/RISCVFrameLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVFrameLowering.h
@@ -23,6 +23,11 @@ class RISCVFrameLowering : public TargetFrameLowering {
 public:
   explicit RISCVFrameLowering(const RISCVSubtarget &STI);
 
+  int getInitialCFAOffset(const MachineFunction &MF) const override;
+  Register
+  getInitialCFARegister(const MachineFunction &MF) const override;
+  void emitCFIsForCSRsHandledByRA(MachineFunction &MF, ReachingDefAnalysis *RDA) const override;
+
   void emitPrologue(MachineFunction &MF, MachineBasicBlock &MBB) const override;
   void emitEpilogue(MachineFunction &MF, MachineBasicBlock &MBB) const override;
 
@@ -31,6 +36,7 @@ class RISCVFrameLowering : public TargetFrameLowering {
   StackOffset getFrameIndexReference(const MachineFunction &MF, int FI,
                                      Register &FrameReg) const override;
 
+  void determineMustCalleeSaves(MachineFunction &MF, BitVector &SavedRegs) const;
   void determineCalleeSaves(MachineFunction &MF, BitVector &SavedRegs,
                             RegScavenger *RS) const override;
 
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index af7a39b2580a372..c21d8782d5aeb13 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -21874,6 +21874,125 @@ bool RISCVTargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
   return false;
 }
 
+/// Return the first direct call (PseudoCALL or PseudoTAIL) in MBB to
+/// __cxa_throw, __cxa_rethrow or _Unwind_Resume, or nullptr if there is none.
+///
+/// Some instructions may require (implicitly) all CSRs to be saved: e.g. a
+/// call to __cxa_throw is noreturn, but expects that all CSRs are taken care
+/// of.
+static MachineInstr *findInstrWhichNeedAllCSRs(MachineBasicBlock &MBB) {
+  // TODO: try to speed this up?
+  for (MachineInstr &MI : MBB) {
+    unsigned Opc = MI.getOpcode();
+    if (Opc != RISCV::PseudoCALL && Opc != RISCV::PseudoTAIL)
+      continue;
+    const MachineOperand &MO = MI.getOperand(0);
+    StringRef Name;
+    if (MO.isSymbol()) {
+      Name = MO.getSymbolName();
+    } else if (MO.isGlobal()) {
+      Name = MO.getGlobal()->getName();
+    } else {
+      // Other callee operand kinds are not inspected. Skipping them instead
+      // of hitting llvm_unreachable avoids undefined behavior in release
+      // builds.
+      continue;
+    }
+    if (Name == "__cxa_throw" || Name == "__cxa_rethrow" ||
+        Name == "_Unwind_Resume")
+      return &MI;
+  }
+  return nullptr;
+}
+
+/// When callee-saved registers are handled by the register allocator
+/// (doCSRSavesInRA), expose them to it as ordinary virtual registers:
+///  - at the start of the entry block, every eligible CSR (one that
+///    determineMustCalleeSaves does not mark) is copied into a fresh virtual
+///    register hinted to that CSR;
+///  - before every restore point, the virtual register is copied back into
+///    the CSR, which is also added as an implicit use of the restore point.
+/// Restore points are the returns of the function plus a placeholder
+/// UnreachableRET added to non-returning blocks that call a function found by
+/// findInstrWhichNeedAllCSRs.
+void RISCVTargetLowering::finalizeLowering(MachineFunction &MF) const {
+  if (!Subtarget.doCSRSavesInRA()) {
+    TargetLoweringBase::finalizeLowering(MF);
+    return;
+  }
+
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  const TargetInstrInfo &TII = *MF.getSubtarget().getInstrInfo();
+  const RISCVRegisterInfo &TRI = *Subtarget.getRegisterInfo();
+  const RISCVFrameLowering &TFI = *Subtarget.getFrameLowering();
+
+  SmallVector<MachineInstr *, 4> RestorePoints;
+  for (MachineBasicBlock &MBB : MF) {
+    if (MBB.isReturnBlock()) {
+      // The block already ends in a return, possibly a tail call such as a
+      // PseudoTAIL to __cxa_throw, so the CSRs are restored before it. The
+      // block must not be rewritten below: that would erase the return just
+      // recorded here, which drops the call and leaves a dangling pointer in
+      // RestorePoints.
+      RestorePoints.push_back(&MBB.back());
+      continue;
+    }
+    MachineInstr *UnwindingCall = findInstrWhichNeedAllCSRs(MBB);
+    if (!UnwindingCall)
+      continue;
+    // Turn the block into a (fake) return: insert the placeholder at the first
+    // terminator and drop everything after it, so the CSR restores end up
+    // after the call.
+    MachineInstr *NewRetMI =
+        BuildMI(MBB, MBB.getFirstTerminator(), UnwindingCall->getDebugLoc(),
+                TII.get(RISCV::UnreachableRET));
+    RestorePoints.push_back(NewRetMI);
+    MachineBasicBlock::iterator AfterNewRet = ++NewRetMI->getIterator();
+    MBB.erase(AfterNewRet, MBB.end());
+  }
+
+  const MCPhysReg *CSRegs = MRI.getCalleeSavedRegs();
+  BitVector MustCalleeSavedRegs;
+  TFI.determineMustCalleeSaves(MF, MustCalleeSavedRegs);
+  SmallVector<MCPhysReg, 4> EligibleRegs;
+  for (const MCPhysReg *R = CSRegs; R && *R; ++R)
+    if (!MustCalleeSavedRegs.test(*R))
+      EligibleRegs.push_back(*R);
+
+  // Capture the entry value of each eligible CSR in the entry block.
+  // VRegs[I] holds the entry value of EligibleRegs[I].
+  MachineBasicBlock &EntryMBB = MF.front();
+  SmallVector<Register, 4> VRegs;
+  for (MCPhysReg Reg : EligibleRegs) {
+    EntryMBB.addLiveIn(Reg);
+    // TODO: should we use Maximal register class instead?
+    Register VReg = MRI.createVirtualRegister(
+        TRI.getLargestLegalSuperClass(TRI.getMinimalPhysRegClass(Reg), MF));
+    VRegs.push_back(VReg);
+    BuildMI(EntryMBB, EntryMBB.begin(), EntryMBB.findDebugLoc(EntryMBB.begin()),
+            TII.get(TargetOpcode::COPY), VReg)
+        .addReg(Reg);
+    MRI.setSimpleHint(VReg, Reg);
+  }
+
+  // Put the entry values back right before every restore point and keep them
+  // live into it.
+  for (MachineInstr *RestorePoint : RestorePoints) {
+    MachineBasicBlock &MBB = *RestorePoint->getParent();
+    for (unsigned I = 0, E = EligibleRegs.size(); I != E; ++I) {
+      MCPhysReg Reg = EligibleRegs[I];
+      BuildMI(MBB, RestorePoint->getIterator(), RestorePoint->getDebugLoc(),
+              TII.get(TargetOpcode::COPY), Reg)
+          .addReg(VRegs[I]);
+      RestorePoint->addOperand(
+          MF, MachineOperand::CreateReg(Reg, /*isDef=*/false,
+                                        /*isImplicit=*/true));
+    }
+  }
+
+  TargetLoweringBase::finalizeLowering(MF);
+}
+
 SDValue
 RISCVTargetLowering::BuildSDIVPow2(SDNode *N, const APInt &Divisor,
                                    SelectionDAG &DAG,
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 0b07ad7d7a423f0..f625176bdcad403 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -883,6 +883,8 @@ class RISCVTargetLowering : public TargetLowering {
 
   bool fallBackToDAGISel(const Instruction &Inst) const override;
 
+  void finalizeLowering(MachineFunction &MF) const override;
+
   bool lowerInterleavedLoad(LoadInst *LI,
                             ArrayRef<ShuffleVectorInst *> Shuffles,
                             ArrayRef<unsigned> Indices,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index c3aa367486627a2..83eccedb204619d 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -293,6 +293,14 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
 
   unsigned getTailDuplicateSize(CodeGenOptLevel OptLevel) const override;
 
+  /// Post-RA pseudo expansion. Only UnreachableRET is handled: it is simply
+  /// deleted. Returns true if MI was expanded.
+  bool expandPostRAPseudo(MachineInstr &MI) const override {
+    if (MI.getOpcode() != RISCV::UnreachableRET)
+      return false;
+    MI.eraseFromParent();
+    return true;
+  }
+
 protected:
   const RISCVSubtarget &STI;
 
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
index a867368235584c0..0826f6daa390c35 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
@@ -1615,6 +1615,10 @@ let isBarrier = 1, isReturn = 1, isTerminator = 1 in
 def PseudoRET : Pseudo<(outs), (ins), [(riscv_ret_glue)]>,
                 PseudoInstExpansion<(JALR X0, X1, 0)>;
 
+// Operand-less, pattern-less return terminator. RISCVInstrInfo's
+// expandPostRAPseudo erases it after register allocation, so it never
+// reaches the emitted code.
+let isBarrier = 1, isReturn = 1, isTerminator = 1, isMeta = 1, hasSideEffects = 1, mayLoad = 0, mayStore = 0 in
+def UnreachableRET : Pseudo<(outs), (ins), []>;
+
+
 // PseudoTAIL is a pseudo instruction similar to PseudoCALL and will eventually
 // expand to auipc and jalr while encoding.
 // Define AsmString to print "tail" when compile with -S flag.
diff --git a/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.cpp b/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.cpp
index d0c363042f5118c..b2a582d0ae79f65 100644
--- a/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.cpp
@@ -43,3 +43,31 @@ void RISCVMachineFunctionInfo::addSExt32Register(Register Reg) {
 bool RISCVMachineFunctionInfo::isSExt32Register(Register Reg) const {
   return is_contained(SExt32Registers, Reg);
 }
+
+/// Remember the CFI data to emit for \p MI: the DWARF register number \p Reg,
+/// the DWARF frame register number \p FrameReg and \p Offset. Recording again
+/// for the same instruction overwrites the previous entry.
+void RISCVMachineFunctionInfo::recordCFIInfo(MachineInstr *MI, int Reg,
+                                             int FrameReg, int64_t Offset) {
+  // getCFIInfo asserts both register numbers are non-negative when reading;
+  // checking them here as well reports a bad value at the point it is
+  // produced instead of at some later lookup.
+  assert(Reg >= 0 && "Negative dwarf reg number!");
+  assert(FrameReg >= 0 && "Negative dwarf reg number!");
+  CFIInfoMap[MI] = std::make_tuple(Reg, FrameReg, Offset);
+}
+
+/// Look up the CFI data previously recorded for \p MI with recordCFIInfo().
+/// On a hit, the stored values are written to \p Reg, \p FrameReg and
+/// \p Offset and true is returned. On a miss, the out-parameters are left
+/// untouched and false is returned.
+bool RISCVMachineFunctionInfo::getCFIInfo(MachineInstr *MI, int &Reg,
+                                          int &FrameReg, int64_t &Offset) {
+  auto Found = CFIInfoMap.find(MI);
+  if (Found == CFIInfoMap.end())
+    return false;
+  // Unpack with the qualified std::tie. An unqualified get<N>(...) call with
+  // explicit template arguments only finds std::get via ADL from C++20 on.
+  std::tie(Reg, FrameReg, Offset) = Found->second;
+  assert(Reg >= 0 && "Negative dwarf reg number!");
+  assert(FrameReg >= 0 && "Negative dwarf reg number!");
+  return true;
+}
diff --git a/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h b/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h
index 779c652b4d8fc49..09aa81fdaaee1cb 100644
--- a/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h
@@ -14,6 +14,7 @@
 #define LLVM_LIB_TARGET_RISCV_RISCVMACHINEFUNCTIONINFO_H
 
 #include "RISCVSubtarget.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/CodeGen/MIRYamlMapping.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunction.h"
@@ -76,6 +77,9 @@ class RISCVMachineFunctionInfo : public MachineFunctionInfo {
   unsigned RVPushRegs = 0;
   int RVPushRlist = llvm::RISCVZC::RLISTENCODE::INVALID_RLIST;
 
+
+  /// Per-instruction CFI data as (DWARF reg, DWARF frame reg, offset),
+  /// written by recordCFIInfo() and read back by getCFIInfo().
+  SmallDenseMap<MachineInstr *, std::tuple<int, int, int64_t>> CFIInfoMap;
+
 public:
   RISCVMachineFunctionInfo(const Function &F, const TargetSubtargetInfo *STI) {}
 
@@ -157,6 +161,19 @@ class RISCVMachineFunctionInfo : public MachineFunctionInfo {
 
   bool isVectorCall() const { return IsVectorCall; }
   void setIsVectorCall() { IsVectorCall = true; }
+
+  /// Associate CFI data with \p MI: DWARF register number \p Reg, DWARF frame
+  /// register number \p FrameReg and \p Offset (presumably relative to
+  /// \p FrameReg; confirm against the CFI inserter). An existing entry for
+  /// \p MI is overwritten.
+  void recordCFIInfo(
+    MachineInstr* MI,
+    int Reg,
+    int FrameReg,
+    int64_t Offset
+  );
+  /// Retrieve the CFI data recorded for \p MI. Returns true and fills the
+  /// out-parameters if an entry exists; otherwise returns false and leaves
+  /// them unchanged.
+  bool getCFIInfo(
+    MachineInstr* MI,
+    int &Reg,
+    int &FrameReg,
+    int64_t &Offset
+  );
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
index 26195ef721db392..6c8dec484927273 100644
--- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
@@ -755,6 +755,14 @@ RISCVRegisterInfo::getCallPreservedMask(const MachineFunction & MF,
 const TargetRegisterClass *
 RISCVRegisterInfo::getLargestLegalSuperClass(const TargetRegisterClass *RC,
                                              const MachineFunction &) const {
+  if (RC == &RISCV::GPRX1RegClass)
+    return &RISCV::GPRRegClass;
+  if (RC == &RISCV::GPRCRegClass)
+    return &RISCV::GPRRegClass;
+  if (RC == &RISCV::SR07RegClass)
+    return &RISCV::GPRRegClass;
+  if (RC == &RISCV::GPRJALRRegClass)
+    return &RISCV::GPRRegClass;
   if (RC == &RISCV::VMV0RegClass)
     return &RISCV::VRRegClass;
   if (RC == &RISCV::VRNoV0RegClass)
diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
index e7db1ededf383b8..51bdd757f3006de 100644
--- a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
+++ b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp
@@ -61,6 +61,11 @@ static cl::opt<unsigned> RISCVMinimumJumpTableEntries(
     "riscv-min-jump-table-entries", cl::Hidden,
     cl::desc("Set minimum number of entries to use a jump table on RISCV"));
 
+// Experimental switch: hand saving/restoring of callee-saved registers to the
+// register allocator. Queried through RISCVSubtarget::doCSRSavesInRA().
+static cl::opt<bool> RISCVEnableSaveCSRByRA(
+    "riscv-enable-save-csr-in-ra",
+    cl::desc("Let the register allocator do CSR saves/restores"),
+    cl::init(false), cl::Hidden);
+
 void RISCVSubtarget::anchor() {}
 
 RISCVSubtarget &
@@ -129,6 +134,10 @@ bool RISCVSubtarget::useConstantPoolForLargeInts() const {
   return !RISCVDisableUsingConstantPoolForLargeInts;
 }
 
+/// True when the -riscv-enable-save-csr-in-ra option is enabled, i.e. when
+/// callee-saved register saves/restores are left to the register allocator.
+bool RISCVSubtarget::doCSRSavesInRA() const {
+  const bool SaveCSRInRA = RISCVEnableSaveCSRByRA.getValue();
+  return SaveCSRInRA;
+}
+
 unsigned RISCVSubtarget::getMaxBuildIntsCost() const {
   // Loading integer from constant pool needs two instructions (the reason why
   // the minimum cost is 2): an address calculation instruction and a load
diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.h b/llvm/lib/Target/RISCV/RISCVSubtarget.h
index bf9ed3f3d716558..72e693c19ad14bd 100644
--- a/llvm/lib/Target/RISCV/RISCVSubtarget.h
+++ b/llvm/lib/Target/RISCV/RISCVSubtarget.h
@@ -271,6 +271,8 @@ class RISCVSubtarget : public RISCVGenSubtargetInfo {
 
   bool useConstantPoolForLargeInts() const;
 
+  /// Whether callee-saved register saves/restores are done by the register
+  /// allocator; controlled by the -riscv-enable-save-csr-in-ra option.
+  bool doCSRSavesInRA() const override;
+
   // Maximum cost used for building integers, integers will be put into constant
   // pool if exceeded.
   unsigned getMaxBuildIntsCost() const;
diff --git a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
index 72d74d2d79b1d5a..5c6f8503b30acce 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
@@ -542,6 +542,8 @@ void RISCVPassConfig::addPreEmitPass2() {
   addPass(createUnpackMachineBundles([&](const MachineFunction &MF) {
     return MF.getFunction().getParent()->getModuleFlag("kcfi");
   }));
+
+  addPass(createRISCVCFIInstrInserter());
 }
 
 void RISCVPassConfig::addMachineSSAOptimization() {



More information about the llvm-commits mailing list