[llvm] [AArch64] Lower EH_RETURN (PR #76775)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 2 21:04:04 PST 2024


https://github.com/SihangZhu updated https://github.com/llvm/llvm-project/pull/76775

From bff55af33bf35573501a5bae1c34eb4e989ee2fa Mon Sep 17 00:00:00 2001
From: SihangZhu <zhusihang at huawei.com>
Date: Wed, 3 Jan 2024 09:29:13 +0800
Subject: [PATCH] [AArch64] Lower EH_RETURN

---
 .../Target/AArch64/AArch64FrameLowering.cpp   | 50 ++++++++++++++++++-
 .../Target/AArch64/AArch64ISelLowering.cpp    | 25 ++++++++++
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  4 ++
 llvm/lib/Target/AArch64/AArch64InstrInfo.cpp  | 16 +++++-
 llvm/lib/Target/AArch64/AArch64InstrInfo.td   | 12 +++++
 .../AArch64/AArch64MachineFunctionInfo.cpp    | 21 ++++++++
 .../AArch64/AArch64MachineFunctionInfo.h      | 14 ++++++
 7 files changed, 140 insertions(+), 2 deletions(-)
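
For context: the llvm.eh.return intrinsic is what front ends emit for GCC's
__builtin_eh_return, and it reaches the backend as an ISD::EH_RETURN node
whose meaning is "adjust the stack by OFFSET, then resume at HANDLER instead
of the normal return address". A minimal caller sketch, assuming this patch
enables the builtin for AArch64 (function and parameter names are
hypothetical; in practice only unwinder runtimes such as libgcc use it):

extern "C" void jump_to_landing_pad(long stack_adjust, void *handler) {
  // Pops stack_adjust extra bytes off the caller's frame and transfers
  // control to handler; execution never falls through this call.
  __builtin_eh_return(stack_adjust, handler);
}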

diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
index caab59201a8d69..98a6021be39c94 100644
--- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
@@ -455,7 +455,7 @@ bool AArch64FrameLowering::hasFP(const MachineFunction &MF) const {
     return true;
   if (MFI.hasVarSizedObjects() || MFI.isFrameAddressTaken() ||
       MFI.hasStackMap() || MFI.hasPatchPoint() ||
-      RegInfo->hasStackRealignment(MF))
+      RegInfo->hasStackRealignment(MF) || AFI->callsEhReturn())
     return true;
   // With large callframes around we may need to use FP to access the scavenging
   // emergency spillslot.
@@ -1095,6 +1095,8 @@ bool AArch64FrameLowering::shouldCombineCSRLocalStackBump(
   if (AFI->getLocalStackSize() == 0)
     return false;
 
+  if (AFI->callsEhReturn())
+    return false;
   // For WinCFI, if optimizing for size, prefer to not combine the stack bump
   // (to force a stp with predecrement) to match the packed unwind format,
   // provided that there actually are any callee saved registers to merge the
@@ -2054,6 +2056,26 @@ void AArch64FrameLowering::emitPrologue(MachineFunction &MF,
                        CFAOffset, MFI.hasVarSizedObjects());
   }
 
+  // Insert instructions that spill eh data registers.
+  if (AFI->callsEhReturn()) {
+    for (int I = 0; I < 4; ++I) {
+      if (!MBB.isLiveIn(AFI->GetEhDataReg(I)))
+        MBB.addLiveIn(AFI->GetEhDataReg(I));
+    }
+    BuildMI(MBB, MBBI, DL, TII->get(AArch64::STPXi))
+        .addReg(AFI->GetEhDataReg(0))
+        .addReg(AFI->GetEhDataReg(1))
+        .addReg(AArch64::SP)
+        .addImm(0)
+        .setMIFlags(MachineInstr::FrameSetup);
+    BuildMI(MBB, MBBI, DL, TII->get(AArch64::STPXi))
+        .addReg(AFI->GetEhDataReg(2))
+        .addReg(AFI->GetEhDataReg(3))
+        .addReg(AArch64::SP)
+        .addImm(2)
+        .setMIFlags(MachineInstr::FrameSetup);
+  }
+
   // If we need a base pointer, set it up here. It's whatever the value of the
   // stack pointer is at this point. Any variable size objects will be allocated
   // after this, so we can still use the base pointer to reference locals.
@@ -2266,6 +2288,13 @@ void AArch64FrameLowering::emitEpilogue(MachineFunction &MF,
     --EpilogStartI;
   }
 
+  if (AFI->callsEhReturn()) {
+    BuildMI(MBB, LastPopI, DL, TII->get(AArch64::STURXi))
+        .addReg(AFI->GetEhDataReg(0))
+        .addReg(AArch64::FP)
+        .addImm((NumBytes - PrologueSaveSize) / 8);
+  }
+
   if (hasFP(MF) && AFI->hasSwiftAsyncContext()) {
     switch (MF.getTarget().Options.SwiftAsyncFramePointer) {
     case SwiftAsyncFramePointerMode::DeploymentBased:
@@ -2439,6 +2468,21 @@ void AArch64FrameLowering::emitEpilogue(MachineFunction &MF,
         .setMIFlags(MachineInstr::FrameDestroy);
   }
 
+  if (AFI->callsEhReturn()) {
+    BuildMI(MBB, LastPopI, DL, TII->get(AArch64::LDPXi))
+        .addReg(AFI->GetEhDataReg(0), RegState::Define)
+        .addReg(AFI->GetEhDataReg(1), RegState::Define)
+        .addReg(AArch64::FP)
+        .addImm(-NumBytes / 8)
+        .setMIFlags(MachineInstr::FrameDestroy);
+    BuildMI(MBB, LastPopI, DL, TII->get(AArch64::LDPXi))
+        .addReg(AFI->GetEhDataReg(2), RegState::Define)
+        .addReg(AFI->GetEhDataReg(3), RegState::Define)
+        .addReg(AArch64::FP)
+        .addImm(-NumBytes / 8 + 2)
+        .setMIFlags(MachineInstr::FrameDestroy);
+  }
+
   // This must be placed after the callee-save restore code because that code
   // assumes the SP is at the same location as it was after the callee-save save
   // code in the prologue.
@@ -3312,6 +3356,10 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
     SavedRegs.set(AArch64::LR);
   }
 
+  // Create spill slots for the eh data registers if the function calls eh_return.
+  if (AFI->callsEhReturn())
+    AFI->createEhDataRegsFI(MF);
+
   LLVM_DEBUG(dbgs() << "*** determineCalleeSaves\nSaved CSRs:";
              for (unsigned Reg
                   : SavedRegs.set_bits()) dbgs()
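
Taken together, the AArch64FrameLowering changes force a frame pointer for
functions that call eh_return, keep the callee-save and local stack bumps
separate, spill the eh data registers below SP in the prologue, and reload
them (plus store the handler address relative to FP) in the epilogue. One
subtlety the immediates rely on: STPXi/LDPXi immediates are scaled by the
8-byte register size, which is why the second pair uses imm 2 and the
reloads divide byte offsets by 8. A self-checking sketch of that arithmetic
(helper name hypothetical):

#include <cstdint>

constexpr int64_t pairSlotByteOffset(int64_t ScaledImm) {
  return ScaledImm * 8; // STPXi/LDPXi immediates are in 8-byte units
}
static_assert(pairSlotByteOffset(0) == 0, "stp x0, x1, [sp]");
static_assert(pairSlotByteOffset(2) == 16, "stp x2, x3, [sp, #16]");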
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 102fd0c3dae2ab..ae06b5df21b4a3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -553,6 +553,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::FP_TO_SINT_SAT, MVT::i32, Custom);
   setOperationAction(ISD::FP_TO_SINT_SAT, MVT::i64, Custom);
 
+  setOperationAction(ISD::EH_RETURN, MVT::Other, Custom);
   // Variable arguments.
   setOperationAction(ISD::VASTART, MVT::Other, Custom);
   setOperationAction(ISD::VAARG, MVT::Other, Custom);
@@ -2502,6 +2503,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::RDSVL)
     MAKE_CASE(AArch64ISD::BIC)
     MAKE_CASE(AArch64ISD::BIT)
+    MAKE_CASE(AArch64ISD::EH_RETURN)
     MAKE_CASE(AArch64ISD::CBZ)
     MAKE_CASE(AArch64ISD::CBNZ)
     MAKE_CASE(AArch64ISD::TBZ)
@@ -5484,6 +5486,8 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
 
     return DAG.getZExtOrTrunc(NewCttzElts, dl, Op.getValueType());
   }
+  case ISD::EH_RETURN:
+    return LowerEH_RETURN(Op, DAG);
   }
 }
 
@@ -6595,6 +6599,27 @@ AArch64TargetLowering::allocateLazySaveBuffer(SDValue &Chain, const SDLoc &DL,
   return TPIDR2Obj;
 }
 
+SDValue AArch64TargetLowering::LowerEH_RETURN(SDValue Op,
+                                              SelectionDAG &DAG) const {
+  SDValue Chain = Op.getOperand(0);
+  SDValue Offset = Op.getOperand(1);
+  SDValue Handler = Op.getOperand(2);
+  SDLoc dl(Op);
+
+  auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
+  AFI->setCallsEhReturn();
+
+  EVT PtrVT = getPointerTy(DAG.getDataLayout());
+  Register OffsetReg = AArch64::X4;
+  Register AddrReg = AArch64::X0;
+  Chain = DAG.getCopyToReg(Chain, dl, OffsetReg, Offset);
+  Chain = DAG.getCopyToReg(Chain, dl, AddrReg, Handler);
+
+  return DAG.getNode(AArch64ISD::EH_RETURN, dl, MVT::Other, Chain,
+                     DAG.getRegister(OffsetReg, MVT::i64),
+                     DAG.getRegister(AddrReg, PtrVT));
+}
+
 SDValue AArch64TargetLowering::LowerFormalArguments(
     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
     const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
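
The custom hook unpacks the node's (chain, offset, handler) operands, pins
the stack adjustment to X4 and the handler address to X0, and rebuilds the
node as AArch64ISD::EH_RETURN for the pseudo below to match. For reference,
a hedged sketch of the producer side, i.e. the generic node this hook
receives (not part of the patch; ISD::EH_RETURN is documented in
llvm/CodeGen/ISDOpcodes.h):

#include "llvm/CodeGen/SelectionDAG.h"
using namespace llvm;

SDValue buildGenericEhReturn(SelectionDAG &DAG, const SDLoc &DL,
                             SDValue Chain, SDValue Offset, SDValue Handler) {
  // Operand order matches the getOperand(0..2) unpacking above.
  return DAG.getNode(ISD::EH_RETURN, DL, MVT::Other, Chain, Offset, Handler);
}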
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 6ddbcd41dcb769..23a68a6b2eae5b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -279,6 +279,9 @@ enum NodeType : unsigned {
   // Vector bitwise insertion
   BIT,
 
+  // Exception Handling helpers.
+  EH_RETURN,
+
   // Compare-and-branch
   CBZ,
   CBNZ,
@@ -1077,6 +1080,7 @@ class AArch64TargetLowering : public TargetLowering {
   template <class NodeTy>
   SDValue getAddrTiny(NodeTy *N, SelectionDAG &DAG, unsigned Flags = 0) const;
   SDValue LowerADDROFRETURNADDR(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerEH_RETURN(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerGlobalTLSAddress(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerDarwinGlobalTLSAddress(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 1cfbf4737a6f72..11f56943e95bd1 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1943,7 +1943,8 @@ bool AArch64InstrInfo::removeCmpToZeroOrOne(
 
 bool AArch64InstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
   if (MI.getOpcode() != TargetOpcode::LOAD_STACK_GUARD &&
-      MI.getOpcode() != AArch64::CATCHRET)
+      MI.getOpcode() != AArch64::CATCHRET &&
+      MI.getOpcode() != AArch64::EH_RETURN)
     return false;
 
   MachineBasicBlock &MBB = *MI.getParent();
@@ -1974,6 +1975,19 @@ bool AArch64InstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
     return true;
   }
 
+  if (MI.getOpcode() == AArch64::EH_RETURN) {
+    Register OffsetReg = MI.getOperand(0).getReg();
+    BuildMI(MBB, MI, DL, get(AArch64::ADDXrx64))
+        .addReg(AArch64::SP, RegState::Define)
+        .addReg(AArch64::SP)
+        .addReg(OffsetReg)
+        .addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0))
+        .setMIFlags(MachineInstr::FrameDestroy);
+    BuildMI(MBB, MI, DL, get(AArch64::RET)).addReg(AArch64::LR);
+    MBB.erase(MI);
+    return true;
+  }
+
   Register Reg = MI.getOperand(0).getReg();
   Module &M = *MBB.getParent()->getFunction().getParent();
   if (M.getStackProtectorGuard() == "sysreg") {
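
After register allocation the EH_RETURN pseudo expands to a stack-pointer
bump by the offset register (an ADDXrx64 with a UXTX #0 extend, i.e. an
unshifted register add) followed by a plain ret. The handler address is not
copied into LR here; the epilogue store added in AArch64FrameLowering
appears intended to place it in the saved-LR slot so the normal callee-save
reload picks it up. The net effect on SP, as a sketch (names hypothetical):

#include <cstdint>

// Models "add sp, sp, x<offset>; ret": the extra adjustment requested by
// eh_return is applied after the normal epilogue has run.
uint64_t spAfterEhReturn(uint64_t SP, uint64_t StackAdjust) {
  return SP + StackAdjust;
}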
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 62b2bf490f37a2..2266737b09b394 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -401,6 +401,9 @@ def SDT_AArch64binvec : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>,
 def SDT_AArch64trivec : SDTypeProfile<1, 3, [SDTCisVec<0>, SDTCisSameAs<0,1>,
                                            SDTCisSameAs<0,2>,
                                            SDTCisSameAs<0,3>]>;
+
+def SDT_AArch64EHRET : SDTypeProfile<0, 2, [SDTCisInt<0>, SDTCisPtrTy<1>]>;
+
 def SDT_AArch64TCRET : SDTypeProfile<0, 2, [SDTCisPtrTy<0>]>;
 def SDT_AArch64PREFETCH : SDTypeProfile<0, 2, [SDTCisVT<0, i32>, SDTCisPtrTy<1>]>;
 
@@ -755,6 +758,9 @@ def AArch64fcmltz: SDNode<"AArch64ISD::FCMLTz", SDT_AArch64fcmpz>;
 def AArch64bici: SDNode<"AArch64ISD::BICi", SDT_AArch64vecimm>;
 def AArch64orri: SDNode<"AArch64ISD::ORRi", SDT_AArch64vecimm>;
 
+def AArch64ehret : SDNode<"AArch64ISD::EH_RETURN", SDT_AArch64EHRET,
+                          [SDNPHasChain, SDNPOptInGlue, SDNPVariadic]>;
+
 def AArch64tcret: SDNode<"AArch64ISD::TC_RETURN", SDT_AArch64TCRET,
                   [SDNPHasChain,  SDNPOptInGlue, SDNPVariadic]>;
 
@@ -4794,6 +4800,12 @@ let isPseudo = 1 in {
 
 // Pseudo instructions for Windows EH
 //===----------------------------------------------------------------------===//
+let isTerminator = 1, isReturn = 1, isBarrier = 1,
+   hasCtrlDep = 1, isCodeGenOnly = 1, isPseudo = 1, AsmString = "ret" in {
+def EH_RETURN   : Pseudo<(outs), (ins GPR64sp:$addr, GPR64:$dst),
+                         [(AArch64ehret GPR64sp:$addr, GPR64:$dst)]>, Sched<[]>;
+}
+
 let isTerminator = 1, hasSideEffects = 1, isBarrier = 1, hasCtrlDep = 1,
     isCodeGenOnly = 1, isReturn = 1, isEHScopeReturn = 1, isPseudo = 1 in {
    def CLEANUPRET : Pseudo<(outs), (ins), [(cleanupret)]>, Sched<[]>;
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp
index 1a8c71888a852f..8c34907897e121 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp
@@ -205,3 +205,24 @@ bool AArch64FunctionInfo::needsAsyncDwarfUnwindInfo(
   }
   return *NeedsAsyncDwarfUnwindInfo;
 }
+
+bool AArch64FunctionInfo::isEhDataRegFI(int FI) const {
+  return CallsEhReturn && (FI == EhDataRegFI[0] || FI == EhDataRegFI[1] ||
+                           FI == EhDataRegFI[2] || FI == EhDataRegFI[3]);
+}
+
+void AArch64FunctionInfo::createEhDataRegsFI(MachineFunction &MF) {
+  const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo();
+  for (int &I : EhDataRegFI) {
+    const TargetRegisterClass &RC = AArch64::GPR64RegClass;
+
+    I = MF.getFrameInfo().CreateStackObject(TRI.getSpillSize(RC),
+                                            TRI.getSpillAlign(RC), false);
+  }
+}
+
+unsigned AArch64FunctionInfo::GetEhDataReg(unsigned I) const {
+  static const unsigned EhDataReg[] = {AArch64::X0, AArch64::X1, AArch64::X2,
+                                       AArch64::X3};
+  return EhDataReg[I];
+}
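
These helpers mirror the long-standing Mips eh_return support: four 8-byte
spill slots, and a fixed X0-X3 register map matching GCC's
EH_RETURN_DATA_REGNO choice for AArch64. A slightly hardened variant of the
lookup with a bounds assertion (illustrative only; assumes the generated
AArch64 register enum is in scope, as it is in this file):

#include <cassert>
#include <iterator>

unsigned getEhDataReg(unsigned I) {
  static const unsigned EhDataReg[] = {AArch64::X0, AArch64::X1, AArch64::X2,
                                       AArch64::X3};
  assert(I < std::size(EhDataReg) && "EH data register index out of range");
  return EhDataReg[I];
}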
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index cd4a18bfbc23a8..927e8b68458d99 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -120,6 +120,12 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
   /// which the sret argument is passed.
   Register SRetReturnReg;
 
+  /// CallsEhReturn - Whether the function calls llvm.eh.return.
+  bool CallsEhReturn = false;
+
+  /// Frame objects for spilling eh data registers.
+  int EhDataRegFI[4];
+
   /// SVE stack size (for predicates and data vectors) are maintained here
   /// rather than in FrameInfo, as the placement and Stack IDs are target
   /// specific.
@@ -274,6 +280,14 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
     HasCalleeSavedStackSize = true;
   }
 
+  bool callsEhReturn() const { return CallsEhReturn; }
+  void setCallsEhReturn() { CallsEhReturn = true; }
+
+  void createEhDataRegsFI(MachineFunction &MF);
+  int getEhDataRegFI(unsigned Reg) const { return EhDataRegFI[Reg]; }
+  bool isEhDataRegFI(int FI) const;
+  unsigned GetEhDataReg(unsigned I) const;
+
   // When CalleeSavedStackSize has not been set (for example when
   // some MachineIR pass is run in isolation), then recalculate
   // the CalleeSavedStackSize directly from the CalleeSavedInfo.
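
End to end, a function using the builtin should now select on AArch64,
which previously had no lowering for ISD::EH_RETURN. A sketch of the kind
of source a regression test for this patch could start from (the expected
codegen shape is illustrative, not taken from the patch):

// clang --target=aarch64-linux-gnu -O2 -S eh_ret.cpp
extern "C" void do_eh_return(long adj, void *handler) {
  __builtin_eh_return(adj, handler);
  // Expected epilogue shape per the hunks above: x0-x3 reloaded, the
  // handler stored where LR is restored from, then
  //   add sp, sp, x4
  //   ret
}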


