[llvm] [llvm] Support multiple save/restore points in mir (PR #119357)

Elizaveta Noskova via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 12 01:16:06 PST 2025


https://github.com/enoskova-sc updated https://github.com/llvm/llvm-project/pull/119357

>From 1daee32b5ee8f05f1a86f8a40cd76e085c3780fa Mon Sep 17 00:00:00 2001
From: ens-sc <elizaveta.noskova at syntacore.com>
Date: Wed, 12 Feb 2025 12:11:01 +0300
Subject: [PATCH] [llvm] support multiple save/restore points in mir

Currently, MIR supports specifying only one save point and one restore point:

```
  savePoint:       '%bb.1'
  restorePoint:    '%bb.2'
```

This patch makes it possible to specify multiple save points and multiple restore points in MIR:

```
  savePoint:
    - point:           '%bb.1'
  restorePoint:
    - point:           '%bb.2'
```
The old single-point scalar syntax shown above is still accepted for backward compatibility.
---
 llvm/include/llvm/CodeGen/MIRYamlMapping.h    | 74 +++++++++++++++++--
 llvm/lib/CodeGen/MIRParser/MIRParser.cpp      | 55 ++++++++++----
 llvm/lib/CodeGen/MIRPrinter.cpp               | 26 +++++--
 .../dont-shrink-wrap-stack-mayloadorstore.mir | 20 +++--
 .../CodeGen/ARM/invalidated-save-point.ll     |  4 +-
 llvm/test/CodeGen/MIR/Generic/frame-info.mir  |  4 +-
 .../X86/frame-info-save-restore-points.mir    |  6 +-
 .../CodeGen/X86/shrink_wrap_dbg_value.mir     |  6 +-
 .../llvm-reduce/mir/preserve-frame-info.mir   |  8 +-
 9 files changed, 156 insertions(+), 47 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/MIRYamlMapping.h b/llvm/include/llvm/CodeGen/MIRYamlMapping.h
index dbad3469d047d..6f7a1f1d65e46 100644
--- a/llvm/include/llvm/CodeGen/MIRYamlMapping.h
+++ b/llvm/include/llvm/CodeGen/MIRYamlMapping.h
@@ -631,6 +631,64 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::CalledGlobal)
 namespace llvm {
 namespace yaml {
 
+/// A single entry of a MIR save/restore point list, written as
+/// `- point: '%bb.N'`.
+struct SaveRestorePointEntry {
+  StringValue Point;
+
+  bool operator==(const SaveRestorePointEntry &Other) const {
+    return Point == Other.Point;
+  }
+};
+
+/// Save or restore points as written in MIR: either a list of entries or, for
+/// backward compatibility, a single scalar block reference such as '%bb.1'.
+using SaveRestorePoints =
+    std::variant<std::vector<SaveRestorePointEntry>, StringValue>;
+
+/// Lets YAML I/O read and write SaveRestorePoints as a sequence or a scalar.
+template <> struct PolymorphicTraits<SaveRestorePoints> {
+  static NodeKind getKind(const SaveRestorePoints &SRPoints) {
+    if (std::holds_alternative<std::vector<SaveRestorePointEntry>>(SRPoints))
+      return NodeKind::Sequence;
+    if (std::holds_alternative<StringValue>(SRPoints))
+      return NodeKind::Scalar;
+    llvm_unreachable("Unknown map value kind in SaveRestorePoints");
+  }
+
+  /// Only reached when parsing a mapping (getKind never reports one). Accept
+  /// a single `{ point: '%bb.N' }` mapping as a one-entry list instead of
+  /// hitting llvm_unreachable on malformed input.
+  static SaveRestorePointEntry &getAsMap(SaveRestorePoints &SRPoints) {
+    std::vector<SaveRestorePointEntry> &Entries = getAsSequence(SRPoints);
+    Entries.assign(1, SaveRestorePointEntry());
+    return Entries.front();
+  }
+
+  /// Switches \p SRPoints to the list form if needed and returns the list.
+  static std::vector<SaveRestorePointEntry> &
+  getAsSequence(SaveRestorePoints &SRPoints) {
+    if (!std::holds_alternative<std::vector<SaveRestorePointEntry>>(SRPoints))
+      SRPoints = std::vector<SaveRestorePointEntry>();
+
+    return std::get<std::vector<SaveRestorePointEntry>>(SRPoints);
+  }
+
+  /// Switches \p SRPoints to the scalar form if needed and returns it.
+  static StringValue &getAsScalar(SaveRestorePoints &SRPoints) {
+    if (!std::holds_alternative<StringValue>(SRPoints))
+      SRPoints = StringValue();
+
+    return std::get<StringValue>(SRPoints);
+  }
+};
+
+template <> struct MappingTraits<SaveRestorePointEntry> {
+  static void mapping(IO &YamlIO, SaveRestorePointEntry &Entry) {
+    YamlIO.mapRequired("point", Entry.Point);
+  }
+};
+
 template <> struct MappingTraits<MachineJumpTable> {
   static void mapping(IO &YamlIO, MachineJumpTable &JT) {
     YamlIO.mapRequired("kind", JT.Kind);
@@ -639,6 +697,14 @@ template <> struct MappingTraits<MachineJumpTable> {
   }
 };
 
+} // namespace yaml
+} // namespace llvm
+
+LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::SaveRestorePointEntry)
+
+namespace llvm {
+namespace yaml {
+
 /// Serializable representation of MachineFrameInfo.
 ///
 /// Doesn't serialize attributes like 'StackAlignment', 'IsStackRealignable' and
@@ -666,8 +732,8 @@ struct MachineFrameInfo {
   bool HasTailCall = false;
   bool IsCalleeSavedInfoValid = false;
   unsigned LocalFrameSize = 0;
-  StringValue SavePoint;
-  StringValue RestorePoint;
+  SaveRestorePoints SavePoints;
+  SaveRestorePoints RestorePoints;
 
   bool operator==(const MachineFrameInfo &Other) const {
     return IsFrameAddressTaken == Other.IsFrameAddressTaken &&
@@ -688,7 +754,8 @@ struct MachineFrameInfo {
            HasMustTailInVarArgFunc == Other.HasMustTailInVarArgFunc &&
            HasTailCall == Other.HasTailCall &&
            LocalFrameSize == Other.LocalFrameSize &&
-           SavePoint == Other.SavePoint && RestorePoint == Other.RestorePoint &&
+           SavePoints == Other.SavePoints &&
+           RestorePoints == Other.RestorePoints &&
            IsCalleeSavedInfoValid == Other.IsCalleeSavedInfoValid;
   }
 };
@@ -720,10 +787,11 @@ template <> struct MappingTraits<MachineFrameInfo> {
     YamlIO.mapOptional("isCalleeSavedInfoValid", MFI.IsCalleeSavedInfoValid,
                        false);
     YamlIO.mapOptional("localFrameSize", MFI.LocalFrameSize, (unsigned)0);
-    YamlIO.mapOptional("savePoint", MFI.SavePoint,
-                       StringValue()); // Don't print it out when it's empty.
-    YamlIO.mapOptional("restorePoint", MFI.RestorePoint,
-                       StringValue()); // Don't print it out when it's empty.
+    // Default to an empty list: an absent key means "no point", and an empty
+    // list is not printed when default values are omitted.
+    const SaveRestorePoints NoPoints = std::vector<SaveRestorePointEntry>();
+    YamlIO.mapOptional("savePoint", MFI.SavePoints, NoPoints);
+    YamlIO.mapOptional("restorePoint", MFI.RestorePoints, NoPoints);
   }
 };
 
diff --git a/llvm/lib/CodeGen/MIRParser/MIRParser.cpp b/llvm/lib/CodeGen/MIRParser/MIRParser.cpp
index de2fe925c2d5c..e7f709135291e 100644
--- a/llvm/lib/CodeGen/MIRParser/MIRParser.cpp
+++ b/llvm/lib/CodeGen/MIRParser/MIRParser.cpp
@@ -124,6 +124,12 @@ class MIRParserImpl {
   bool initializeFrameInfo(PerFunctionMIParsingState &PFS,
                            const yaml::MachineFunction &YamlMF);
 
+  /// Parse the save points (if \p IsSavePoints) or restore points in
+  /// \p YamlSRP into the frame info. Return true if an error occurred.
+  bool initializeSaveRestorePoints(PerFunctionMIParsingState &PFS,
+                                   const yaml::SaveRestorePoints &YamlSRP,
+                                   bool IsSavePoints);
+
   bool initializeCallSiteInfo(PerFunctionMIParsingState &PFS,
                               const yaml::MachineFunction &YamlMF);
 
@@ -851,18 +857,11 @@ bool MIRParserImpl::initializeFrameInfo(PerFunctionMIParsingState &PFS,
   MFI.setHasTailCall(YamlMFI.HasTailCall);
   MFI.setCalleeSavedInfoValid(YamlMFI.IsCalleeSavedInfoValid);
   MFI.setLocalFrameSize(YamlMFI.LocalFrameSize);
-  if (!YamlMFI.SavePoint.Value.empty()) {
-    MachineBasicBlock *MBB = nullptr;
-    if (parseMBBReference(PFS, MBB, YamlMFI.SavePoint))
-      return true;
-    MFI.setSavePoint(MBB);
-  }
-  if (!YamlMFI.RestorePoint.Value.empty()) {
-    MachineBasicBlock *MBB = nullptr;
-    if (parseMBBReference(PFS, MBB, YamlMFI.RestorePoint))
-      return true;
-    MFI.setRestorePoint(MBB);
-  }
+  if (initializeSaveRestorePoints(PFS, YamlMFI.SavePoints,
+                                  /*IsSavePoints=*/true) ||
+      initializeSaveRestorePoints(PFS, YamlMFI.RestorePoints,
+                                  /*IsSavePoints=*/false))
+    return true;
 
   std::vector<CalleeSavedInfo> CSIInfo;
   // Initialize the fixed frame objects.
@@ -1077,8 +1076,44 @@ bool MIRParserImpl::initializeConstantPool(PerFunctionMIParsingState &PFS,
   return false;
 }
 
+bool MIRParserImpl::initializeSaveRestorePoints(
+    PerFunctionMIParsingState &PFS, const yaml::SaveRestorePoints &YamlSRP,
+    bool IsSavePoints) {
+  const yaml::StringValue *Source = nullptr;
+  if (const auto *Entries =
+          std::get_if<std::vector<yaml::SaveRestorePointEntry>>(&YamlSRP)) {
+    if (Entries->empty())
+      return false;
+    // MachineFrameInfo holds a single save point and a single restore point;
+    // reject extra entries instead of silently dropping them.
+    if (Entries->size() > 1)
+      return error((*Entries)[1].Point.SourceRange.Start,
+                   IsSavePoints ? "multiple save points are not supported"
+                                : "multiple restore points are not supported");
+    Source = &Entries->front().Point;
+  } else {
+    // Legacy scalar form; an empty string means there is no point.
+    Source = &std::get<yaml::StringValue>(YamlSRP);
+    if (Source->Value.empty())
+      return false;
+  }
+
+  // Pass the StringValue itself, not just its text, so that diagnostics refer
+  // to the right location in the MIR file.
+  MachineBasicBlock *MBB = nullptr;
+  if (parseMBBReference(PFS, MBB, *Source))
+    return true;
+
+  MachineFrameInfo &MFI = PFS.MF.getFrameInfo();
+  if (IsSavePoints)
+    MFI.setSavePoint(MBB);
+  else
+    MFI.setRestorePoint(MBB);
+  return false;
+}
+
 bool MIRParserImpl::initializeJumpTableInfo(PerFunctionMIParsingState &PFS,
     const yaml::MachineJumpTable &YamlJTI) {
   MachineJumpTableInfo *JTI = PFS.MF.getOrCreateJumpTableInfo(YamlJTI.Kind);
   for (const auto &Entry : YamlJTI.Entries) {
     std::vector<MachineBasicBlock *> Blocks;
diff --git a/llvm/lib/CodeGen/MIRPrinter.cpp b/llvm/lib/CodeGen/MIRPrinter.cpp
index 0b41c90442a5d..8675d754aac28 100644
--- a/llvm/lib/CodeGen/MIRPrinter.cpp
+++ b/llvm/lib/CodeGen/MIRPrinter.cpp
@@ -118,6 +118,8 @@ class MIRPrinter {
                const TargetRegisterInfo *TRI);
   void convert(ModuleSlotTracker &MST, yaml::MachineFrameInfo &YamlMFI,
                const MachineFrameInfo &MFI);
+  void convert(ModuleSlotTracker &MST, yaml::SaveRestorePoints &YamlSRP,
+               const MachineBasicBlock *SaveRestorePoint);
   void convert(yaml::MachineFunction &MF,
                const MachineConstantPool &ConstantPool);
   void convert(ModuleSlotTracker &MST, yaml::MachineJumpTable &YamlJTI,
@@ -397,14 +399,10 @@ void MIRPrinter::convert(ModuleSlotTracker &MST,
   YamlMFI.HasTailCall = MFI.hasTailCall();
   YamlMFI.IsCalleeSavedInfoValid = MFI.isCalleeSavedInfoValid();
   YamlMFI.LocalFrameSize = MFI.getLocalFrameSize();
-  if (MFI.getSavePoint()) {
-    raw_string_ostream StrOS(YamlMFI.SavePoint.Value);
-    StrOS << printMBBReference(*MFI.getSavePoint());
-  }
-  if (MFI.getRestorePoint()) {
-    raw_string_ostream StrOS(YamlMFI.RestorePoint.Value);
-    StrOS << printMBBReference(*MFI.getRestorePoint());
-  }
+  if (MFI.getSavePoint())
+    convert(MST, YamlMFI.SavePoints, MFI.getSavePoint());
+  if (MFI.getRestorePoint())
+    convert(MST, YamlMFI.RestorePoints, MFI.getRestorePoint());
 }
 
 void MIRPrinter::convertEntryValueObjects(yaml::MachineFunction &YMF,
@@ -646,6 +644,20 @@ void MIRPrinter::convert(yaml::MachineFunction &MF,
   }
 }
 
+/// Append a reference to \p SRP to the save/restore point list \p YamlSRP,
+/// switching \p YamlSRP to the list form first if necessary.
+void MIRPrinter::convert(ModuleSlotTracker &MST,
+                         yaml::SaveRestorePoints &YamlSRP,
+                         const MachineBasicBlock *SRP) {
+  using EntryList = std::vector<yaml::SaveRestorePointEntry>;
+  if (!std::holds_alternative<EntryList>(YamlSRP))
+    YamlSRP = EntryList();
+  auto &Points = std::get<EntryList>(YamlSRP);
+  // Print the block reference directly into the newly appended entry.
+  raw_string_ostream StrOS(Points.emplace_back().Point.Value);
+  StrOS << printMBBReference(*SRP);
+}
+
 void MIRPrinter::convert(ModuleSlotTracker &MST,
                          yaml::MachineJumpTable &YamlJTI,
                          const MachineJumpTableInfo &JTI) {
diff --git a/llvm/test/CodeGen/AArch64/dont-shrink-wrap-stack-mayloadorstore.mir b/llvm/test/CodeGen/AArch64/dont-shrink-wrap-stack-mayloadorstore.mir
index 1c4447bffd872..6122cf843c230 100644
--- a/llvm/test/CodeGen/AArch64/dont-shrink-wrap-stack-mayloadorstore.mir
+++ b/llvm/test/CodeGen/AArch64/dont-shrink-wrap-stack-mayloadorstore.mir
@@ -6,17 +6,23 @@
  ; RUN: llc -x=mir -simplify-mir -run-pass=shrink-wrap -o - %s | FileCheck %s
  ; CHECK:      name:            compiler_pop_stack
  ; CHECK:      frameInfo:
- ; CHECK:      savePoint:       '%bb.1'
- ; CHECK:      restorePoint:    '%bb.7'
+ ; CHECK:        savePoint:
+ ; CHECK-NEXT:     - point:           '%bb.1'
+ ; CHECK:        restorePoint:
+ ; CHECK-NEXT:     - point:           '%bb.7'
  ; CHECK:      name:            compiler_pop_stack_no_memoperands
  ; CHECK:      frameInfo:
- ; CHECK:      savePoint:       '%bb.1'
- ; CHECK:      restorePoint:    '%bb.7'
+ ; CHECK:        savePoint:
+ ; CHECK-NEXT:     - point:           '%bb.1'
+ ; CHECK:        restorePoint:
+ ; CHECK-NEXT:     - point:           '%bb.7'
  ; CHECK:      name:            f
  ; CHECK:      frameInfo:
- ; CHECK:      savePoint:       '%bb.2'
- ; CHECK-NEXT: restorePoint:    '%bb.4'
+ ; CHECK:        savePoint:
+ ; CHECK-NEXT:     - point:           '%bb.2'
+ ; CHECK-NEXT:   restorePoint:
+ ; CHECK-NEXT:     - point:           '%bb.4'
  ; CHECK-NEXT: stack:
 
   target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
   target triple = "aarch64"
diff --git a/llvm/test/CodeGen/ARM/invalidated-save-point.ll b/llvm/test/CodeGen/ARM/invalidated-save-point.ll
index bb602308a1793..4179316572990 100644
--- a/llvm/test/CodeGen/ARM/invalidated-save-point.ll
+++ b/llvm/test/CodeGen/ARM/invalidated-save-point.ll
@@ -4,8 +4,8 @@
 ; this point. Notably, if it isn't is will be invalid and reference a
 ; deleted block (%bb.-1.if.end)
 
-; CHECK: savePoint: ''
-; CHECK: restorePoint: ''
+; CHECK: savePoint: []
+; CHECK: restorePoint: []
 
 target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64"
 target triple = "thumbv7"
diff --git a/llvm/test/CodeGen/MIR/Generic/frame-info.mir b/llvm/test/CodeGen/MIR/Generic/frame-info.mir
index d5e014cf62991..e8f3a83fcda89 100644
--- a/llvm/test/CodeGen/MIR/Generic/frame-info.mir
+++ b/llvm/test/CodeGen/MIR/Generic/frame-info.mir
@@ -46,8 +46,8 @@ tracksRegLiveness: true
 # CHECK-NEXT: hasTailCall: false
 # CHECK-NEXT: isCalleeSavedInfoValid: false
 # CHECK-NEXT: localFrameSize: 0
-# CHECK-NEXT: savePoint:       ''
-# CHECK-NEXT: restorePoint:    ''
+# CHECK-NEXT: savePoint:       []
+# CHECK-NEXT: restorePoint:    []
 # CHECK: body
 frameInfo:
   maxAlignment:    4
diff --git a/llvm/test/CodeGen/MIR/X86/frame-info-save-restore-points.mir b/llvm/test/CodeGen/MIR/X86/frame-info-save-restore-points.mir
index e26233f946606..bd2d45046123a 100644
--- a/llvm/test/CodeGen/MIR/X86/frame-info-save-restore-points.mir
+++ b/llvm/test/CodeGen/MIR/X86/frame-info-save-restore-points.mir
@@ -30,8 +30,10 @@ liveins:
   - { reg: '$edi' }
   - { reg: '$esi' }
 # CHECK: frameInfo:
-# CHECK:      savePoint: '%bb.2'
-# CHECK-NEXT: restorePoint: '%bb.2'
+# CHECK:      savePoint:
+# CHECK-NEXT:   - point:           '%bb.2'
+# CHECK-NEXT: restorePoint:
+# CHECK-NEXT:   - point:           '%bb.2'
 # CHECK: stack
 frameInfo:
   maxAlignment:  4
diff --git a/llvm/test/CodeGen/X86/shrink_wrap_dbg_value.mir b/llvm/test/CodeGen/X86/shrink_wrap_dbg_value.mir
index aa7befc18d4fe..66110c75be145 100644
--- a/llvm/test/CodeGen/X86/shrink_wrap_dbg_value.mir
+++ b/llvm/test/CodeGen/X86/shrink_wrap_dbg_value.mir
@@ -119,8 +119,10 @@ frameInfo:
   hasOpaqueSPAdjustment: false
   hasVAStart:      false
   hasMustTailInVarArgFunc: false
-  # CHECK: savePoint:       '%bb.1'
-  # CHECK: restorePoint:    '%bb.3'
+  # CHECK:      savePoint:
+  # CHECK-NEXT:   - point:           '%bb.1'
+  # CHECK:      restorePoint:
+  # CHECK-NEXT:   - point:           '%bb.3'
   savePoint:       ''
   restorePoint:    ''
 fixedStack:      
diff --git a/llvm/test/tools/llvm-reduce/mir/preserve-frame-info.mir b/llvm/test/tools/llvm-reduce/mir/preserve-frame-info.mir
index d7ad5f88874d7..957c90929d3fe 100644
--- a/llvm/test/tools/llvm-reduce/mir/preserve-frame-info.mir
+++ b/llvm/test/tools/llvm-reduce/mir/preserve-frame-info.mir
@@ -20,8 +20,10 @@
 # RESULT-NEXT: hasVAStart:      true
 # RESULT-NEXT: hasMustTailInVarArgFunc: true
 # RESULT-NEXT: hasTailCall:     true
-# RESULT-NEXT: savePoint:       '%bb.1'
-# RESULT-NEXT: restorePoint:    '%bb.2'
+# RESULT-NEXT: savePoint:
+# RESULT-NEXT:   - point:           '%bb.1'
+# RESULT-NEXT: restorePoint:
+# RESULT-NEXT:   - point:           '%bb.1'
 
 # RESULT-NEXT: fixedStack:
 # RESULT-NEXT:  - { id: 0, offset: 56, size: 4, alignment: 8, callee-saved-register: '$sgpr44',
@@ -117,7 +119,7 @@ frameInfo:
   hasTailCall:     true
   localFrameSize:  0
   savePoint:       '%bb.1'
-  restorePoint:    '%bb.2'
+  restorePoint:    '%bb.1'
 
 fixedStack:
   - { id: 0, offset: 0, size: 8, alignment: 4, isImmutable: true, isAliased: false }



More information about the llvm-commits mailing list