[llvm] [llvm][RISCV] Support multiple save/restore points in prolog-epilog (PR #119358)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 10 13:16:34 PST 2024


================
@@ -809,21 +847,105 @@ class MachineFrameInfo {
   /// \copydoc getCalleeSavedInfo()
   std::vector<CalleeSavedInfo> &getCalleeSavedInfo() { return CSInfo; }
 
+  /// Returns the callee-saved info vector recorded for the save point \p MBB
+  /// in the current function.
+  ///
+  /// The vector is returned by value, so the caller receives its own object.
+  /// NOTE(review): what is returned for a block that has no recorded entry
+  /// is decided by CSInfoPerSave.get(), which is not visible here -- confirm.
+  std::vector<CalleeSavedInfo> getCSInfoPerSave(MachineBasicBlock *MBB) const {
+    return CSInfoPerSave.get(MBB);
+  }
+
+  /// Returns the callee-saved info vector recorded for the restore point
+  /// \p MBB in the current function.
+  ///
+  /// The vector is returned by value, so the caller receives its own object.
+  /// NOTE(review): what is returned for a block that has no recorded entry
+  /// is decided by CSInfoPerRestore.get(), which is not visible here --
+  /// confirm.
+  std::vector<CalleeSavedInfo>
+  getCSInfoPerRestore(MachineBasicBlock *MBB) const {
+    return CSInfoPerRestore.get(MBB);
+  }
+
   /// Used by prolog/epilog inserter to set the function's callee saved
   /// information.
   void setCalleeSavedInfo(std::vector<CalleeSavedInfo> CSI) {
     CSInfo = std::move(CSI);
   }
 
+  /// Used by the prolog/epilog inserter to set the function's callee-saved
+  /// information on a per-save-point basis.
+  ///
+  /// \param CSI map from a basic block to the CalleeSavedInfo vector that
+  /// belongs to it. The map is taken by value and handed to
+  /// CSInfoPerSave.set().
+  void setCSInfoPerSave(
+      DenseMap<MachineBasicBlock *, std::vector<CalleeSavedInfo>> CSI) {
+    CSInfoPerSave.set(CSI);
+  }
+
+  /// Used by the prolog/epilog inserter to set the function's callee-saved
+  /// information on a per-restore-point basis.
+  ///
+  /// \param CSI map from a basic block to the CalleeSavedInfo vector that
+  /// belongs to it. The map is taken by value and handed to
+  /// CSInfoPerRestore.set().
+  void setCSInfoPerRestore(
+      DenseMap<MachineBasicBlock *, std::vector<CalleeSavedInfo>> CSI) {
+    CSInfoPerRestore.set(CSI);
+  }
+
   /// Has the callee saved info been calculated yet?
   bool isCalleeSavedInfoValid() const { return CSIValid; }
 
   void setCalleeSavedInfoValid(bool v) { CSIValid = v; }
 
-  MachineBasicBlock *getSavePoint() const { return Save; }
-  void setSavePoint(MachineBasicBlock *NewSave) { Save = NewSave; }
-  MachineBasicBlock *getRestorePoint() const { return Restore; }
-  void setRestorePoint(MachineBasicBlock *NewRestore) { Restore = NewRestore; }
+  /// Returns a read-only reference to the restore points held by this
+  /// object. No copy is made.
+  const SaveRestorePoints &getRestorePoints() const { return RestorePoints; }
+
+  /// Returns a read-only reference to the save points held by this object.
+  /// No copy is made.
+  const SaveRestorePoints &getSavePoints() const { return SavePoints; }
+
+  /// Looks up \p MBB among the restore points.
+  ///
+  /// \returns a copy of the (block, registers) entry stored for \p MBB, or a
+  /// pair holding a null block and an empty register list when \p MBB is not
+  /// a restore point.
+  std::pair<MachineBasicBlock *, std::vector<Register>>
+  getRestorePoint(MachineBasicBlock *MBB) const {
+    auto Entry = RestorePoints.find(MBB);
+    if (Entry == RestorePoints.end())
+      return {nullptr, {}};
+
+    return *Entry;
+  }
+
+  /// Looks up \p MBB among the save points.
+  ///
+  /// \returns a copy of the (block, registers) entry stored for \p MBB, or a
+  /// pair holding a null block and an empty register list when \p MBB is not
+  /// a save point.
+  std::pair<MachineBasicBlock *, std::vector<Register>>
+  getSavePoint(MachineBasicBlock *MBB) const {
+    auto Entry = SavePoints.find(MBB);
+    if (Entry == SavePoints.end())
+      return {nullptr, {}};
+
+    return *Entry;
+  }
+
+  /// Replaces the complete set of save points with \p NewSavePoints. The
+  /// argument is taken by value and moved into place, so callers may pass an
+  /// rvalue to avoid a copy.
+  void setSavePoints(SaveRestorePoints NewSavePoints) {
+    SavePoints = std::move(NewSavePoints);
+  }
+
+  /// Replaces the complete set of restore points with \p NewRestorePoints.
+  /// The argument is taken by value and moved into place, so callers may
+  /// pass an rvalue to avoid a copy.
+  void setRestorePoints(SaveRestorePoints NewRestorePoints) {
+    RestorePoints = std::move(NewRestorePoints);
+  }
+
+  /// Records \p Regs as the register list associated with save point \p MBB.
+  ///
+  /// An existing entry for \p MBB is overwritten; otherwise a new entry is
+  /// created. operator[] covers both cases with a single lookup, so no
+  /// separate contains() probe is needed. \p Regs is only copied from, never
+  /// modified.
+  void setSavePoint(MachineBasicBlock *MBB, const std::vector<Register> &Regs) {
+    SavePoints[MBB] = Regs;
+  }
+
+  /// Builds a new set of save/restore points from \p SRP in which every block
+  /// key is replaced by the block that \p BBMap maps it to. The value stored
+  /// for each block is copied unchanged.
+  ///
+  /// Every block that appears in \p SRP must have an entry in \p BBMap.
+  /// DenseMap::at() asserts this in assertion-enabled builds instead of
+  /// silently dereferencing an unchecked find() result.
+  static const SaveRestorePoints constructSaveRestorePoints(
+      const SaveRestorePoints &SRP,
+      const DenseMap<MachineBasicBlock *, MachineBasicBlock *> &BBMap) {
+    SaveRestorePoints Pts;
+    for (const auto &Src : SRP)
+      Pts.insert(std::make_pair(BBMap.at(Src.first), Src.second));
+    return Pts;
+  }
+
+  void setRestorePoint(MachineBasicBlock *MBB, std::vector<Register> &Regs) {
+    if (RestorePoints.contains(MBB))
----------------
topperc wrote:

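Why do we need to check `contains` first? Assigning through `operator[]` (`RestorePoints[MBB] = Regs;`) already creates the entry when it is missing and overwrites it when it is present.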

https://github.com/llvm/llvm-project/pull/119358


More information about the llvm-commits mailing list