[llvm] [llvm] Support multiple save/restore points in mir (PR #119357)

Elizaveta Noskova via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 6 07:25:03 PDT 2025


https://github.com/enoskova-sc updated https://github.com/llvm/llvm-project/pull/119357

>From 71aaf5a608a2a033feb4e384de7bf0988903044a Mon Sep 17 00:00:00 2001
From: ens-sc <elizaveta.noskova at syntacore.com>
Date: Wed, 12 Feb 2025 12:11:01 +0300
Subject: [PATCH] [llvm] support multiple save/restore points in mir

Currently, MIR supports specifying only one save point and one restore point:

```
  savePoint:       '%bb.1'
  restorePoint:    '%bb.2'
```

This patch makes it possible to specify multiple save points and multiple restore points in MIR:

```
  savePoint:
    - point:           '%bb.1'
  restorePoint:
    - point:           '%bb.2'
```
while maintaining backward compatibility: the parser still accepts the old scalar form.
---
 llvm/include/llvm/CodeGen/MIRYamlMapping.h    | 76 +++++++++++++++--
 llvm/include/llvm/CodeGen/MachineFrameInfo.h  | 27 +++++--
 llvm/lib/CodeGen/MIRParser/MIRParser.cpp      | 60 +++++++++++---
 llvm/lib/CodeGen/MIRPrinter.cpp               | 30 +++++--
 llvm/lib/CodeGen/MachineFrameInfo.cpp         | 16 ++++
 llvm/lib/CodeGen/PrologEpilogInserter.cpp     | 45 +++++++----
 llvm/lib/CodeGen/ShrinkWrap.cpp               | 10 ++-
 llvm/lib/Target/AMDGPU/SILowerSGPRSpills.cpp  | 13 ++-
 llvm/lib/Target/PowerPC/PPCFrameLowering.cpp  |  6 +-
 .../dont-shrink-wrap-stack-mayloadorstore.mir | 20 +++--
 .../CodeGen/ARM/invalidated-save-point.ll     |  4 +-
 llvm/test/CodeGen/MIR/Generic/frame-info.mir  |  4 +-
 ...nfo-multiple-save-restore-points-parse.mir | 81 +++++++++++++++++++
 .../X86/frame-info-save-restore-points.mir    |  6 +-
 .../CodeGen/X86/shrink_wrap_dbg_value.mir     |  6 +-
 .../llvm-reduce/mir/preserve-frame-info.mir   |  6 +-
 llvm/tools/llvm-reduce/ReducerWorkItem.cpp    | 20 ++++-
 17 files changed, 350 insertions(+), 80 deletions(-)
 create mode 100644 llvm/test/CodeGen/MIR/X86/frame-info-multiple-save-restore-points-parse.mir

diff --git a/llvm/include/llvm/CodeGen/MIRYamlMapping.h b/llvm/include/llvm/CodeGen/MIRYamlMapping.h
index 0f3f945905d3a..2fe5496235903 100644
--- a/llvm/include/llvm/CodeGen/MIRYamlMapping.h
+++ b/llvm/include/llvm/CodeGen/MIRYamlMapping.h
@@ -634,6 +634,55 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::CalledGlobal)
 namespace llvm {
 namespace yaml {
 
+// Struct representing one save/restore point in the 'savePoint'/'restorePoint'
+// list
+struct SaveRestorePointEntry {
+  StringValue Point;
+
+  bool operator==(const SaveRestorePointEntry &Other) const {
+    return Point == Other.Point;
+  }
+};
+
+using SaveRestorePoints =
+    std::variant<std::vector<SaveRestorePointEntry>, StringValue>;
+
+template <> struct PolymorphicTraits<SaveRestorePoints> {
+
+  static NodeKind getKind(const SaveRestorePoints &SRPoints) {
+    if (std::holds_alternative<std::vector<SaveRestorePointEntry>>(SRPoints))
+      return NodeKind::Sequence;
+    if (std::holds_alternative<StringValue>(SRPoints))
+      return NodeKind::Scalar;
+    llvm_unreachable("Unsupported NodeKind of SaveRestorePoints");
+  }
+
+  static SaveRestorePointEntry &getAsMap(SaveRestorePoints &SRPoints) {
+    llvm_unreachable("SaveRestorePoints can't be represented as Map");
+  }
+
+  static std::vector<SaveRestorePointEntry> &
+  getAsSequence(SaveRestorePoints &SRPoints) {
+    if (!std::holds_alternative<std::vector<SaveRestorePointEntry>>(SRPoints))
+      SRPoints = std::vector<SaveRestorePointEntry>();
+
+    return std::get<std::vector<SaveRestorePointEntry>>(SRPoints);
+  }
+
+  static StringValue &getAsScalar(SaveRestorePoints &SRPoints) {
+    if (!std::holds_alternative<StringValue>(SRPoints))
+      SRPoints = StringValue();
+
+    return std::get<StringValue>(SRPoints);
+  }
+};
+
+template <> struct MappingTraits<SaveRestorePointEntry> {
+  static void mapping(IO &YamlIO, SaveRestorePointEntry &Entry) {
+    YamlIO.mapRequired("point", Entry.Point);
+  }
+};
+
 template <> struct MappingTraits<MachineJumpTable> {
   static void mapping(IO &YamlIO, MachineJumpTable &JT) {
     YamlIO.mapRequired("kind", JT.Kind);
@@ -642,6 +691,14 @@ template <> struct MappingTraits<MachineJumpTable> {
   }
 };
 
+} // namespace yaml
+} // namespace llvm
+
+LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::SaveRestorePointEntry)
+
+namespace llvm {
+namespace yaml {
+
 /// Serializable representation of MachineFrameInfo.
 ///
 /// Doesn't serialize attributes like 'StackAlignment', 'IsStackRealignable' and
@@ -669,8 +726,8 @@ struct MachineFrameInfo {
   bool HasTailCall = false;
   bool IsCalleeSavedInfoValid = false;
   unsigned LocalFrameSize = 0;
-  StringValue SavePoint;
-  StringValue RestorePoint;
+  SaveRestorePoints SavePoints;
+  SaveRestorePoints RestorePoints;
 
   bool operator==(const MachineFrameInfo &Other) const {
     return IsFrameAddressTaken == Other.IsFrameAddressTaken &&
@@ -691,7 +748,8 @@ struct MachineFrameInfo {
            HasMustTailInVarArgFunc == Other.HasMustTailInVarArgFunc &&
            HasTailCall == Other.HasTailCall &&
            LocalFrameSize == Other.LocalFrameSize &&
-           SavePoint == Other.SavePoint && RestorePoint == Other.RestorePoint &&
+           SavePoints == Other.SavePoints &&
+           RestorePoints == Other.RestorePoints &&
            IsCalleeSavedInfoValid == Other.IsCalleeSavedInfoValid;
   }
 };
@@ -723,10 +781,14 @@ template <> struct MappingTraits<MachineFrameInfo> {
     YamlIO.mapOptional("isCalleeSavedInfoValid", MFI.IsCalleeSavedInfoValid,
                        false);
     YamlIO.mapOptional("localFrameSize", MFI.LocalFrameSize, (unsigned)0);
-    YamlIO.mapOptional("savePoint", MFI.SavePoint,
-                       StringValue()); // Don't print it out when it's empty.
-    YamlIO.mapOptional("restorePoint", MFI.RestorePoint,
-                       StringValue()); // Don't print it out when it's empty.
+    YamlIO.mapOptional(
+        "savePoint", MFI.SavePoints,
+        SaveRestorePoints(
+            StringValue())); // Don't print it out when it's empty.
+    YamlIO.mapOptional(
+        "restorePoint", MFI.RestorePoints,
+        SaveRestorePoints(
+            StringValue())); // Don't print it out when it's empty.
   }
 };
 
diff --git a/llvm/include/llvm/CodeGen/MachineFrameInfo.h b/llvm/include/llvm/CodeGen/MachineFrameInfo.h
index 403e5eda949f1..b413d3f2424fd 100644
--- a/llvm/include/llvm/CodeGen/MachineFrameInfo.h
+++ b/llvm/include/llvm/CodeGen/MachineFrameInfo.h
@@ -332,10 +332,10 @@ class MachineFrameInfo {
   /// stack objects like arguments so we can't treat them as immutable.
   bool HasTailCall = false;
 
-  /// Not null, if shrink-wrapping found a better place for the prologue.
-  MachineBasicBlock *Save = nullptr;
-  /// Not null, if shrink-wrapping found a better place for the epilogue.
-  MachineBasicBlock *Restore = nullptr;
+  /// Not empty, if shrink-wrapping found a better place for the prologue.
+  std::vector<MachineBasicBlock *> SavePoints;
+  /// Not empty, if shrink-wrapping found a better place for the epilogue.
+  std::vector<MachineBasicBlock *> RestorePoints;
 
   /// Size of the UnsafeStack Frame
   uint64_t UnsafeStackSize = 0;
@@ -825,10 +825,21 @@ class MachineFrameInfo {
 
   void setCalleeSavedInfoValid(bool v) { CSIValid = v; }
 
-  MachineBasicBlock *getSavePoint() const { return Save; }
-  void setSavePoint(MachineBasicBlock *NewSave) { Save = NewSave; }
-  MachineBasicBlock *getRestorePoint() const { return Restore; }
-  void setRestorePoint(MachineBasicBlock *NewRestore) { Restore = NewRestore; }
+  const std::vector<MachineBasicBlock *> &getSavePoints() const {
+    return SavePoints;
+  }
+  void setSavePoints(std::vector<MachineBasicBlock *> NewSavePoints) {
+    SavePoints = std::move(NewSavePoints);
+  }
+  const std::vector<MachineBasicBlock *> getRestorePoints() const {
+    return RestorePoints;
+  }
+  void setRestorePoints(std::vector<MachineBasicBlock *> NewRestorePoints) {
+    RestorePoints = std::move(NewRestorePoints);
+  }
+
+  void clearSavePoints() { SavePoints.clear(); }
+  void clearRestorePoints() { RestorePoints.clear(); }
 
   uint64_t getUnsafeStackSize() const { return UnsafeStackSize; }
   void setUnsafeStackSize(uint64_t Size) { UnsafeStackSize = Size; }
diff --git a/llvm/lib/CodeGen/MIRParser/MIRParser.cpp b/llvm/lib/CodeGen/MIRParser/MIRParser.cpp
index 3e99e576fa364..acaff8b71ba2b 100644
--- a/llvm/lib/CodeGen/MIRParser/MIRParser.cpp
+++ b/llvm/lib/CodeGen/MIRParser/MIRParser.cpp
@@ -124,6 +124,10 @@ class MIRParserImpl {
   bool initializeFrameInfo(PerFunctionMIParsingState &PFS,
                            const yaml::MachineFunction &YamlMF);
 
+  bool initializeSaveRestorePoints(PerFunctionMIParsingState &PFS,
+                                   const yaml::SaveRestorePoints &YamlSRPoints,
+                                   bool IsSavePoints);
+
   bool initializeCallSiteInfo(PerFunctionMIParsingState &PFS,
                               const yaml::MachineFunction &YamlMF);
 
@@ -867,18 +871,12 @@ bool MIRParserImpl::initializeFrameInfo(PerFunctionMIParsingState &PFS,
   MFI.setHasTailCall(YamlMFI.HasTailCall);
   MFI.setCalleeSavedInfoValid(YamlMFI.IsCalleeSavedInfoValid);
   MFI.setLocalFrameSize(YamlMFI.LocalFrameSize);
-  if (!YamlMFI.SavePoint.Value.empty()) {
-    MachineBasicBlock *MBB = nullptr;
-    if (parseMBBReference(PFS, MBB, YamlMFI.SavePoint))
-      return true;
-    MFI.setSavePoint(MBB);
-  }
-  if (!YamlMFI.RestorePoint.Value.empty()) {
-    MachineBasicBlock *MBB = nullptr;
-    if (parseMBBReference(PFS, MBB, YamlMFI.RestorePoint))
-      return true;
-    MFI.setRestorePoint(MBB);
-  }
+  if (initializeSaveRestorePoints(PFS, YamlMFI.SavePoints,
+                                  /*IsSavePoints=*/true))
+    return true;
+  if (initializeSaveRestorePoints(PFS, YamlMFI.RestorePoints,
+                                  /*IsSavePoints=*/false))
+    return true;
 
   std::vector<CalleeSavedInfo> CSIInfo;
   // Initialize the fixed frame objects.
@@ -1093,6 +1091,44 @@ bool MIRParserImpl::initializeConstantPool(PerFunctionMIParsingState &PFS,
   return false;
 }
 
+// Returns true (error) if a basic block reference in the MIR fails to parse.
+bool MIRParserImpl::initializeSaveRestorePoints(
+    PerFunctionMIParsingState &PFS, const yaml::SaveRestorePoints &YamlSRPoints,
+    bool IsSavePoints) {
+  MachineBasicBlock *MBB = nullptr;
+  std::vector<MachineBasicBlock *> SaveRestorePoints;
+  if (std::holds_alternative<std::vector<yaml::SaveRestorePointEntry>>(
+          YamlSRPoints)) {
+    const auto &VectorRepr =
+        std::get<std::vector<yaml::SaveRestorePointEntry>>(YamlSRPoints);
+    if (VectorRepr.empty())
+      return false;
+    for (const yaml::SaveRestorePointEntry &Entry : VectorRepr) {
+      const yaml::StringValue &MBBSource = Entry.Point;
+      if (parseMBBReference(PFS, MBB, MBBSource.Value))
+        return true;
+      SaveRestorePoints.push_back(MBB);
+    }
+  } else {
+    yaml::StringValue StringRepr = std::get<yaml::StringValue>(YamlSRPoints);
+    if (StringRepr.Value.empty())
+      return false;
+    if (parseMBBReference(PFS, MBB, StringRepr))
+      return true;
+    SaveRestorePoints.push_back(MBB);
+  }
+
+  MachineFunction &MF = PFS.MF;
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+
+  if (IsSavePoints)
+    MFI.setSavePoints(SaveRestorePoints);
+  else
+    MFI.setRestorePoints(SaveRestorePoints);
+
+  return false;
+}
+
 bool MIRParserImpl::initializeJumpTableInfo(PerFunctionMIParsingState &PFS,
     const yaml::MachineJumpTable &YamlJTI) {
   MachineJumpTableInfo *JTI = PFS.MF.getOrCreateJumpTableInfo(YamlJTI.Kind);
diff --git a/llvm/lib/CodeGen/MIRPrinter.cpp b/llvm/lib/CodeGen/MIRPrinter.cpp
index ce1834a90ca54..de9e8a4a147ee 100644
--- a/llvm/lib/CodeGen/MIRPrinter.cpp
+++ b/llvm/lib/CodeGen/MIRPrinter.cpp
@@ -150,6 +150,8 @@ static void convertMJTI(ModuleSlotTracker &MST, yaml::MachineJumpTable &YamlJTI,
                         const MachineJumpTableInfo &JTI);
 static void convertMFI(ModuleSlotTracker &MST, yaml::MachineFrameInfo &YamlMFI,
                        const MachineFrameInfo &MFI);
+static void convertSRPoints(ModuleSlotTracker &MST, yaml::SaveRestorePoints &YamlSRPoints,
+                            std::vector<MachineBasicBlock *> SaveRestorePoints);
 static void convertStackObjects(yaml::MachineFunction &YMF,
                                 const MachineFunction &MF,
                                 ModuleSlotTracker &MST, MFPrintState &State);
@@ -355,14 +357,10 @@ static void convertMFI(ModuleSlotTracker &MST, yaml::MachineFrameInfo &YamlMFI,
   YamlMFI.HasTailCall = MFI.hasTailCall();
   YamlMFI.IsCalleeSavedInfoValid = MFI.isCalleeSavedInfoValid();
   YamlMFI.LocalFrameSize = MFI.getLocalFrameSize();
-  if (MFI.getSavePoint()) {
-    raw_string_ostream StrOS(YamlMFI.SavePoint.Value);
-    StrOS << printMBBReference(*MFI.getSavePoint());
-  }
-  if (MFI.getRestorePoint()) {
-    raw_string_ostream StrOS(YamlMFI.RestorePoint.Value);
-    StrOS << printMBBReference(*MFI.getRestorePoint());
-  }
+  if (!MFI.getSavePoints().empty())
+    convertSRPoints(MST, YamlMFI.SavePoints, MFI.getSavePoints());
+  if (!MFI.getRestorePoints().empty())
+    convertSRPoints(MST, YamlMFI.RestorePoints, MFI.getRestorePoints());
 }
 
 static void convertEntryValueObjects(yaml::MachineFunction &YMF,
@@ -616,6 +614,22 @@ static void convertMCP(yaml::MachineFunction &MF,
   }
 }
 
+static void convertSRPoints(ModuleSlotTracker &MST,
+                            yaml::SaveRestorePoints &YamlSRPoints,
+                            std::vector<MachineBasicBlock *> SRPoints) {
+  auto &Points =
+      std::get<std::vector<yaml::SaveRestorePointEntry>>(YamlSRPoints);
+  for (const auto &MBB : SRPoints) {
+    SmallString<16> Str;
+    yaml::SaveRestorePointEntry Entry;
+    raw_svector_ostream StrOS(Str);
+    StrOS << printMBBReference(*MBB);
+    Entry.Point = StrOS.str().str();
+    Str.clear();
+    Points.push_back(Entry);
+  }
+}
+
 static void convertMJTI(ModuleSlotTracker &MST, yaml::MachineJumpTable &YamlJTI,
                         const MachineJumpTableInfo &JTI) {
   YamlJTI.Kind = JTI.getEntryKind();
diff --git a/llvm/lib/CodeGen/MachineFrameInfo.cpp b/llvm/lib/CodeGen/MachineFrameInfo.cpp
index e4b993850f73d..a8306b2ef2e5b 100644
--- a/llvm/lib/CodeGen/MachineFrameInfo.cpp
+++ b/llvm/lib/CodeGen/MachineFrameInfo.cpp
@@ -244,6 +244,22 @@ void MachineFrameInfo::print(const MachineFunction &MF, raw_ostream &OS) const{
     }
     OS << "\n";
   }
+  OS << "save/restore points:\n";
+
+  if (!SavePoints.empty()) {
+    OS << "save points:\n";
+
+    for (auto &item : SavePoints)
+      OS << printMBBReference(*item) << "\n";
+  } else
+    OS << "save points are empty\n";
+
+  if (!RestorePoints.empty()) {
+    OS << "restore points:\n";
+    for (auto &item : RestorePoints)
+      OS << printMBBReference(*item) << "\n";
+  } else
+    OS << "restore points are empty\n";
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
diff --git a/llvm/lib/CodeGen/PrologEpilogInserter.cpp b/llvm/lib/CodeGen/PrologEpilogInserter.cpp
index f66f54682c84c..3582a101deaaa 100644
--- a/llvm/lib/CodeGen/PrologEpilogInserter.cpp
+++ b/llvm/lib/CodeGen/PrologEpilogInserter.cpp
@@ -351,8 +351,8 @@ bool PEIImpl::run(MachineFunction &MF) {
   delete RS;
   SaveBlocks.clear();
   RestoreBlocks.clear();
-  MFI.setSavePoint(nullptr);
-  MFI.setRestorePoint(nullptr);
+  MFI.clearSavePoints();
+  MFI.clearRestorePoints();
   return true;
 }
 
@@ -423,16 +423,20 @@ void PEIImpl::calculateCallFrameInfo(MachineFunction &MF) {
 /// callee-saved registers, and placing prolog and epilog code.
 void PEIImpl::calculateSaveRestoreBlocks(MachineFunction &MF) {
   const MachineFrameInfo &MFI = MF.getFrameInfo();
-
   // Even when we do not change any CSR, we still want to insert the
   // prologue and epilogue of the function.
   // So set the save points for those.
 
   // Use the points found by shrink-wrapping, if any.
-  if (MFI.getSavePoint()) {
-    SaveBlocks.push_back(MFI.getSavePoint());
-    assert(MFI.getRestorePoint() && "Both restore and save must be set");
-    MachineBasicBlock *RestoreBlock = MFI.getRestorePoint();
+  if (!MFI.getSavePoints().empty()) {
+    assert(MFI.getSavePoints().size() < 2 &&
+           "MFI can't contain multiple save points!");
+    SaveBlocks.push_back(MFI.getSavePoints().front());
+    assert(!MFI.getRestorePoints().empty() &&
+           "Both restore and save must be set");
+    assert(MFI.getRestorePoints().size() < 2 &&
+           "MFI can't contain multiple restore points!");
+    MachineBasicBlock *RestoreBlock = MFI.getRestorePoints().front();
     // If RestoreBlock does not have any successor and is not a return block
     // then the end point is unreachable and we do not need to insert any
     // epilogue.
@@ -441,13 +445,15 @@ void PEIImpl::calculateSaveRestoreBlocks(MachineFunction &MF) {
     return;
   }
 
-  // Save refs to entry and return blocks.
-  SaveBlocks.push_back(&MF.front());
-  for (MachineBasicBlock &MBB : MF) {
-    if (MBB.isEHFuncletEntry())
-      SaveBlocks.push_back(&MBB);
-    if (MBB.isReturnBlock())
-      RestoreBlocks.push_back(&MBB);
+  if (MFI.getSavePoints().empty()) {
+    // Save refs to entry and return blocks.
+    SaveBlocks.push_back(&MF.front());
+    for (MachineBasicBlock &MBB : MF) {
+      if (MBB.isEHFuncletEntry())
+        SaveBlocks.push_back(&MBB);
+      if (MBB.isReturnBlock())
+        RestoreBlocks.push_back(&MBB);
+    }
   }
 }
 
@@ -558,7 +564,11 @@ static void updateLiveness(MachineFunction &MF) {
   SmallPtrSet<MachineBasicBlock *, 8> Visited;
   SmallVector<MachineBasicBlock *, 8> WorkList;
   MachineBasicBlock *Entry = &MF.front();
-  MachineBasicBlock *Save = MFI.getSavePoint();
+
+  assert(MFI.getSavePoints().size() < 2 &&
+         "MFI can't contain multiple save points!");
+  MachineBasicBlock *Save =
+      MFI.getSavePoints().empty() ? nullptr : MFI.getSavePoints().front();
 
   if (!Save)
     Save = Entry;
@@ -569,7 +579,10 @@ static void updateLiveness(MachineFunction &MF) {
   }
   Visited.insert(Save);
 
-  MachineBasicBlock *Restore = MFI.getRestorePoint();
+  assert(MFI.getRestorePoints().size() < 2 &&
+         "MFI can't contain multiple restore points!");
+  MachineBasicBlock *Restore =
+      MFI.getRestorePoints().empty() ? nullptr : MFI.getRestorePoints().front();
   if (Restore)
     // By construction Restore cannot be visited, otherwise it
     // means there exists a path to Restore that does not go
diff --git a/llvm/lib/CodeGen/ShrinkWrap.cpp b/llvm/lib/CodeGen/ShrinkWrap.cpp
index 41e956caa7b34..6121f8b70d9ea 100644
--- a/llvm/lib/CodeGen/ShrinkWrap.cpp
+++ b/llvm/lib/CodeGen/ShrinkWrap.cpp
@@ -967,8 +967,14 @@ bool ShrinkWrapImpl::run(MachineFunction &MF) {
                     << "\nRestore: " << printMBBReference(*Restore) << '\n');
 
   MachineFrameInfo &MFI = MF.getFrameInfo();
-  MFI.setSavePoint(Save);
-  MFI.setRestorePoint(Restore);
+  std::vector<MachineBasicBlock *> SavePoints;
+  std::vector<MachineBasicBlock *> RestorePoints;
+  if (Save) {
+    SavePoints.push_back(Save);
+    RestorePoints.push_back(Restore);
+  }
+  MFI.setSavePoints(SavePoints);
+  MFI.setRestorePoints(RestorePoints);
   ++NumCandidates;
   return Changed;
 }
diff --git a/llvm/lib/Target/AMDGPU/SILowerSGPRSpills.cpp b/llvm/lib/Target/AMDGPU/SILowerSGPRSpills.cpp
index 9509199a8d5ef..b3de4fe45ebb6 100644
--- a/llvm/lib/Target/AMDGPU/SILowerSGPRSpills.cpp
+++ b/llvm/lib/Target/AMDGPU/SILowerSGPRSpills.cpp
@@ -209,10 +209,15 @@ void SILowerSGPRSpills::calculateSaveRestoreBlocks(MachineFunction &MF) {
   // So set the save points for those.
 
   // Use the points found by shrink-wrapping, if any.
-  if (MFI.getSavePoint()) {
-    SaveBlocks.push_back(MFI.getSavePoint());
-    assert(MFI.getRestorePoint() && "Both restore and save must be set");
-    MachineBasicBlock *RestoreBlock = MFI.getRestorePoint();
+  if (!MFI.getSavePoints().empty()) {
+    assert(MFI.getSavePoints().size() < 2 &&
+           "MFI can't contain multiple save points!");
+    SaveBlocks.push_back(MFI.getSavePoints().front());
+    assert(!MFI.getRestorePoints().empty() &&
+           "Both restore and save must be set");
+    assert(MFI.getRestorePoints().size() < 2 &&
+           "MFI can't contain multiple restore points!");
+    MachineBasicBlock *RestoreBlock = MFI.getRestorePoints().front();
     // If RestoreBlock does not have any successor and is not a return block
     // then the end point is unreachable and we do not need to insert any
     // epilogue.
diff --git a/llvm/lib/Target/PowerPC/PPCFrameLowering.cpp b/llvm/lib/Target/PowerPC/PPCFrameLowering.cpp
index c0860fcdaa3b0..2ad3ed21732ed 100644
--- a/llvm/lib/Target/PowerPC/PPCFrameLowering.cpp
+++ b/llvm/lib/Target/PowerPC/PPCFrameLowering.cpp
@@ -2078,8 +2078,10 @@ void PPCFrameLowering::processFunctionBeforeFrameFinalized(MachineFunction &MF,
   // tail call might not be in the new RestoreBlock, so real branch instruction
   // won't be generated by emitEpilogue(), because shrink-wrap has chosen new
   // RestoreBlock. So we handle this case here.
-  if (MFI.getSavePoint() && MFI.hasTailCall()) {
-    MachineBasicBlock *RestoreBlock = MFI.getRestorePoint();
+  if (!MFI.getSavePoints().empty() && MFI.hasTailCall()) {
+    assert(MFI.getRestorePoints().size() < 2 &&
+           "MFI can't contain multiple restore points!");
+    MachineBasicBlock *RestoreBlock = MFI.getRestorePoints().front();
     for (MachineBasicBlock &MBB : MF) {
       if (MBB.isReturnBlock() && (&MBB) != RestoreBlock)
         createTailCallBranchInstr(MBB);
diff --git a/llvm/test/CodeGen/AArch64/dont-shrink-wrap-stack-mayloadorstore.mir b/llvm/test/CodeGen/AArch64/dont-shrink-wrap-stack-mayloadorstore.mir
index b8fd76681c5eb..b912a346c1b66 100644
--- a/llvm/test/CodeGen/AArch64/dont-shrink-wrap-stack-mayloadorstore.mir
+++ b/llvm/test/CodeGen/AArch64/dont-shrink-wrap-stack-mayloadorstore.mir
@@ -7,17 +7,23 @@
  ; RUN: llc -x=mir -simplify-mir -passes='shrink-wrap' -o - %s | FileCheck %s
  ; CHECK:      name:            compiler_pop_stack
  ; CHECK:      frameInfo:
- ; CHECK:      savePoint:       '%bb.1'
- ; CHECK:      restorePoint:    '%bb.7'
+ ; CHECK:        savePoint:
+ ; CHECK-NEXT:     - point:           '%bb.1'
+ ; CHECK:        restorePoint:
+ ; CHECK-NEXT:     - point:           '%bb.7'
  ; CHECK:      name:            compiler_pop_stack_no_memoperands
  ; CHECK:      frameInfo:
- ; CHECK:      savePoint:       '%bb.1'
- ; CHECK:      restorePoint:    '%bb.7'
+ ; CHECK:        savePoint:
+ ; CHECK-NEXT:     - point:           '%bb.1'
+ ; CHECK:        restorePoint:
+ ; CHECK-NEXT:     - point:           '%bb.7'
  ; CHECK:      name:            f
  ; CHECK:      frameInfo:
- ; CHECK:      savePoint:       '%bb.2'
- ; CHECK-NEXT: restorePoint:    '%bb.4'
- ; CHECK-NEXT: stack:
+ ; CHECK:        savePoint:
+ ; CHECK-NEXT:     - point:           '%bb.2'
+ ; CHECK:        restorePoint:
+ ; CHECK-NEXT:     - point:           '%bb.4'
+ ; CHECK:      stack:
 
   target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
   target triple = "aarch64"
diff --git a/llvm/test/CodeGen/ARM/invalidated-save-point.ll b/llvm/test/CodeGen/ARM/invalidated-save-point.ll
index bb602308a1793..4179316572990 100644
--- a/llvm/test/CodeGen/ARM/invalidated-save-point.ll
+++ b/llvm/test/CodeGen/ARM/invalidated-save-point.ll
@@ -4,8 +4,8 @@
 ; this point. Notably, if it isn't is will be invalid and reference a
 ; deleted block (%bb.-1.if.end)
 
-; CHECK: savePoint: ''
-; CHECK: restorePoint: ''
+; CHECK: savePoint: []
+; CHECK: restorePoint: []
 
 target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64"
 target triple = "thumbv7"
diff --git a/llvm/test/CodeGen/MIR/Generic/frame-info.mir b/llvm/test/CodeGen/MIR/Generic/frame-info.mir
index d5e014cf62991..e8f3a83fcda89 100644
--- a/llvm/test/CodeGen/MIR/Generic/frame-info.mir
+++ b/llvm/test/CodeGen/MIR/Generic/frame-info.mir
@@ -46,8 +46,8 @@ tracksRegLiveness: true
 # CHECK-NEXT: hasTailCall: false
 # CHECK-NEXT: isCalleeSavedInfoValid: false
 # CHECK-NEXT: localFrameSize: 0
-# CHECK-NEXT: savePoint:       ''
-# CHECK-NEXT: restorePoint:    ''
+# CHECK-NEXT: savePoint:       []
+# CHECK-NEXT: restorePoint:    []
 # CHECK: body
 frameInfo:
   maxAlignment:    4
diff --git a/llvm/test/CodeGen/MIR/X86/frame-info-multiple-save-restore-points-parse.mir b/llvm/test/CodeGen/MIR/X86/frame-info-multiple-save-restore-points-parse.mir
new file mode 100644
index 0000000000000..4c60ccd573595
--- /dev/null
+++ b/llvm/test/CodeGen/MIR/X86/frame-info-multiple-save-restore-points-parse.mir
@@ -0,0 +1,81 @@
+# RUN: llc -mtriple=x86_64 -run-pass none -o - %s | FileCheck %s
+# This test ensures that the MIR parser parses the save and restore points in
+# the machine frame info correctly.
+
+--- |
+
+  define i32 @foo(i32 %a, i32 %b) {
+  entry:
+    %tmp = alloca i32, align 4
+    %tmp2 = icmp slt i32 %a, %b
+    br i1 %tmp2, label %true, label %false
+
+  true:
+    store i32 %a, ptr %tmp, align 4
+    %tmp4 = call i32 @doSomething(i32 0, ptr %tmp)
+    br label %false
+
+  false:
+    %tmp.0 = phi i32 [ %tmp4, %true ], [ %a, %entry ]
+    ret i32 %tmp.0
+  }
+
+  declare i32 @doSomething(i32, ptr)
+
+...
+---
+name:            foo
+tracksRegLiveness: true
+liveins:
+  - { reg: '$edi' }
+  - { reg: '$esi' }
+# CHECK: frameInfo:
+# CHECK:      savePoint:
+# CHECK-NEXT:   - point:           '%bb.1'
+# CHECK-NEXT:   - point:           '%bb.2'
+# CHECK:      restorePoint:
+# CHECK-NEXT:   - point:           '%bb.2'
+# CHECK-NEXT:   - point:           '%bb.3'
+# CHECK: stack
+frameInfo:
+  maxAlignment:  4
+  hasCalls:      true
+  savePoint:
+    - point:     '%bb.1'
+    - point:     '%bb.2'
+  restorePoint:  
+    - point:     '%bb.2'
+    - point:     '%bb.3'
+stack:
+  - { id: 0, name: tmp, offset: 0, size: 4, alignment: 4 }
+body: |
+  bb.0:
+    successors: %bb.2, %bb.1
+    liveins: $edi, $esi
+
+    $eax = COPY $edi
+    CMP32rr $eax, killed $esi, implicit-def $eflags
+    JCC_1 %bb.2, 12, implicit killed $eflags
+
+  bb.1:
+    successors: %bb.3
+    liveins: $eax
+
+    JMP_1 %bb.3
+
+  bb.2.true:
+    successors: %bb.3
+    liveins: $eax
+
+    MOV32mr %stack.0.tmp, 1, _, 0, _, killed $eax
+    ADJCALLSTACKDOWN64 0, 0, 0, implicit-def $rsp, implicit-def $ssp, implicit-def dead $eflags, implicit $rsp, implicit $ssp
+    $rsi = LEA64r %stack.0.tmp, 1, _, 0, _
+    $edi = MOV32r0 implicit-def dead $eflags
+    CALL64pcrel32 @doSomething, csr_64, implicit $rsp, implicit $ssp, implicit $edi, implicit $rsi, implicit-def $rsp, implicit-def $ssp, implicit-def $eax
+    ADJCALLSTACKUP64 0, 0, implicit-def $rsp, implicit-def $ssp, implicit-def dead $eflags, implicit $rsp, implicit $ssp
+
+  bb.3.false:
+    liveins: $eax
+
+    RET64 $eax
+...
diff --git a/llvm/test/CodeGen/MIR/X86/frame-info-save-restore-points.mir b/llvm/test/CodeGen/MIR/X86/frame-info-save-restore-points.mir
index e26233f946606..bd2d45046123a 100644
--- a/llvm/test/CodeGen/MIR/X86/frame-info-save-restore-points.mir
+++ b/llvm/test/CodeGen/MIR/X86/frame-info-save-restore-points.mir
@@ -30,8 +30,10 @@ liveins:
   - { reg: '$edi' }
   - { reg: '$esi' }
 # CHECK: frameInfo:
-# CHECK:      savePoint: '%bb.2'
-# CHECK-NEXT: restorePoint: '%bb.2'
+# CHECK:      savePoint:
+# CHECK-NEXT:   - point:           '%bb.2'
+# CHECK:      restorePoint:
+# CHECK-NEXT:   - point:           '%bb.2'
 # CHECK: stack
 frameInfo:
   maxAlignment:  4
diff --git a/llvm/test/CodeGen/X86/shrink_wrap_dbg_value.mir b/llvm/test/CodeGen/X86/shrink_wrap_dbg_value.mir
index af2a18d92b34f..139a3c46c51c1 100644
--- a/llvm/test/CodeGen/X86/shrink_wrap_dbg_value.mir
+++ b/llvm/test/CodeGen/X86/shrink_wrap_dbg_value.mir
@@ -120,8 +120,10 @@ frameInfo:
   hasOpaqueSPAdjustment: false
   hasVAStart:      false
   hasMustTailInVarArgFunc: false
-  # CHECK: savePoint:       '%bb.1'
-  # CHECK: restorePoint:    '%bb.3'
+  # CHECK:      savePoint:
+  # CHECK-NEXT:   - point:           '%bb.1'
+  # CHECK:      restorePoint:
+  # CHECK-NEXT:   - point:           '%bb.3'
   savePoint:       ''
   restorePoint:    ''
 fixedStack:      
diff --git a/llvm/test/tools/llvm-reduce/mir/preserve-frame-info.mir b/llvm/test/tools/llvm-reduce/mir/preserve-frame-info.mir
index d7ad5f88874d7..e47fd689912b1 100644
--- a/llvm/test/tools/llvm-reduce/mir/preserve-frame-info.mir
+++ b/llvm/test/tools/llvm-reduce/mir/preserve-frame-info.mir
@@ -20,8 +20,10 @@
 # RESULT-NEXT: hasVAStart:      true
 # RESULT-NEXT: hasMustTailInVarArgFunc: true
 # RESULT-NEXT: hasTailCall:     true
-# RESULT-NEXT: savePoint:       '%bb.1'
-# RESULT-NEXT: restorePoint:    '%bb.2'
+# RESULT-NEXT: savePoint:
+# RESULT-NEXT:   - point:           '%bb.1'
+# RESULT-NEXT: restorePoint:
+# RESULT-NEXT:   - point:           '%bb.2'
 
 # RESULT-NEXT: fixedStack:
 # RESULT-NEXT:  - { id: 0, offset: 56, size: 4, alignment: 8, callee-saved-register: '$sgpr44',
diff --git a/llvm/tools/llvm-reduce/ReducerWorkItem.cpp b/llvm/tools/llvm-reduce/ReducerWorkItem.cpp
index 3f84e8f6c8d40..6178e7cb79a2d 100644
--- a/llvm/tools/llvm-reduce/ReducerWorkItem.cpp
+++ b/llvm/tools/llvm-reduce/ReducerWorkItem.cpp
@@ -92,11 +92,23 @@ static void cloneFrameInfo(
   DstMFI.setCVBytesOfCalleeSavedRegisters(
       SrcMFI.getCVBytesOfCalleeSavedRegisters());
 
-  if (MachineBasicBlock *SavePt = SrcMFI.getSavePoint())
-    DstMFI.setSavePoint(Src2DstMBB.find(SavePt)->second);
-  if (MachineBasicBlock *RestorePt = SrcMFI.getRestorePoint())
-    DstMFI.setRestorePoint(Src2DstMBB.find(RestorePt)->second);
+  assert(SrcMFI.getSavePoints().size() < 2 &&
+         "MFI can't contain multiple save points!");
+  if (!SrcMFI.getSavePoints().empty()) {
+    MachineBasicBlock *SavePt = SrcMFI.getSavePoints().front();
+    std::vector<MachineBasicBlock *> SavePts;
+    SavePts.push_back(Src2DstMBB.find(SavePt)->second);
+    DstMFI.setSavePoints(SavePts);
+  }
 
+  assert(SrcMFI.getRestorePoints().size() < 2 &&
+         "MFI can't contain multiple restore points!");
+  if (!SrcMFI.getRestorePoints().empty()) {
+    MachineBasicBlock *RestorePt = SrcMFI.getRestorePoints().front();
+    std::vector<MachineBasicBlock *> RestorePts;
+    RestorePts.push_back(Src2DstMBB.find(RestorePt)->second);
+    DstMFI.setRestorePoints(RestorePts);
+  }
 
   auto CopyObjectProperties = [](MachineFrameInfo &DstMFI,
                                  const MachineFrameInfo &SrcMFI, int FI) {



More information about the llvm-commits mailing list