[llvm] 1d2b558 - [AArch64][PAC] Check authenticated LR value during tail call

Anatoly Trosinenko via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 11 07:39:41 PDT 2023


Author: Anatoly Trosinenko
Date: 2023-10-11T17:38:17+03:00
New Revision: 1d2b558265bd9c9c50599b78e210eeebc78a1ae3

URL: https://github.com/llvm/llvm-project/commit/1d2b558265bd9c9c50599b78e210eeebc78a1ae3
DIFF: https://github.com/llvm/llvm-project/commit/1d2b558265bd9c9c50599b78e210eeebc78a1ae3.diff

LOG: [AArch64][PAC] Check authenticated LR value during tail call

When performing a tail call, check the value of the LR register after
authentication to prevent the callee from signing and spilling an
untrusted value. This commit implements a few variants of the check;
more can be added later.

If it is safe to assume that executable pages are always readable,
LR can be checked simply by dereferencing it via LDR.
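
For illustration, this variant roughly lowers to the sketch below
(w16 is the scratch register the new test expects; any otherwise dead
temporary register would do):

    ; lowered AUT* instruction
    ldr w16, [x30]    ; faults unless LR holds a readable address
    ; lowered TCRETURN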

As an alternative, LR can be checked as follows:

    ; lowered AUT* instruction
    ; <some variant of check that LR contains a valid address>
    b.cond break_block
  ret_block:
    ; lowered TCRETURN
  break_block:
    brk 0xc471
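
For instance, with the high-bits-notbi variant the check sequence is
filled in as follows (matching the checker documented in
AArch64PointerAuth.h below; it assumes TBI is not used for code
pointers):

    ; lowered AUT* instruction
    eor x16, x30, x30, lsl #1
    tbnz x16, #62, break_block
  ret_block:
    ; lowered TCRETURN
  break_block:
    brk 0xc471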

As the existing methods either break compatibility with execute-only
memory mappings or may degrade performance, they are disabled by
default and can be explicitly enabled with a command-line option.
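
For example (option name as introduced in AArch64Subtarget.cpp below;
test.ll is a placeholder input file):

    llc -mtriple=aarch64 -aarch64-authenticated-lr-check-method=load test.ll

The accepted values are none, load, high-bits-notbi and xpac-hint (the
latter applies only to checking LR).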

Individual subtargets can opt in to one of the available methods by
updating AArch64Subtarget::getAuthenticatedLRCheckMethod().

Reviewed By: kristof.beyls

Differential Revision: https://reviews.llvm.org/D156716

Added: 
    llvm/lib/Target/AArch64/AArch64PointerAuth.h
    llvm/test/CodeGen/AArch64/sign-return-address-tailcall.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
    llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
    llvm/lib/Target/AArch64/AArch64InstrInfo.h
    llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp
    llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
    llvm/lib/Target/AArch64/AArch64PointerAuth.cpp
    llvm/lib/Target/AArch64/AArch64Subtarget.cpp
    llvm/lib/Target/AArch64/AArch64Subtarget.h

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
index 78f85faaf69bf96..e68d67c6e78de2c 100644
--- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
@@ -269,14 +269,10 @@ STATISTIC(NumRedZoneFunctions, "Number of functions using red zone");
 static int64_t getArgumentStackToRestore(MachineFunction &MF,
                                          MachineBasicBlock &MBB) {
   MachineBasicBlock::iterator MBBI = MBB.getLastNonDebugInstr();
-  bool IsTailCallReturn = false;
-  if (MBB.end() != MBBI) {
-    unsigned RetOpcode = MBBI->getOpcode();
-    IsTailCallReturn = RetOpcode == AArch64::TCRETURNdi ||
-                       RetOpcode == AArch64::TCRETURNri ||
-                       RetOpcode == AArch64::TCRETURNriBTI;
-  }
   AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
+  bool IsTailCallReturn = (MBB.end() != MBBI)
+                              ? AArch64InstrInfo::isTailCallReturnInst(*MBBI)
+                              : false;
 
   int64_t ArgumentPopSize = 0;
   if (IsTailCallReturn) {
@@ -300,7 +296,6 @@ static int64_t getArgumentStackToRestore(MachineFunction &MF,
 static bool produceCompactUnwindFrame(MachineFunction &MF);
 static bool needsWinCFI(const MachineFunction &MF);
 static StackOffset getSVEStackSize(const MachineFunction &MF);
-static bool needsShadowCallStackPrologueEpilogue(MachineFunction &MF);
 
 /// Returns true if a homogeneous prolog or epilog code can be emitted
 /// for the size optimization. If possible, a frame helper call is injected.
@@ -617,7 +612,7 @@ void AArch64FrameLowering::resetCFIToInitialState(
   }
 
   // Shadow call stack uses X18, reset it.
-  if (needsShadowCallStackPrologueEpilogue(MF))
+  if (MFI.needsShadowCallStackPrologueEpilogue(MF))
     insertCFISameValue(CFIDesc, MF, MBB, InsertPt,
                        TRI.getDwarfRegNum(AArch64::X18, true));
 
@@ -1290,19 +1285,6 @@ static bool IsSVECalleeSave(MachineBasicBlock::iterator I) {
   }
 }
 
-static bool needsShadowCallStackPrologueEpilogue(MachineFunction &MF) {
-  if (!(llvm::any_of(
-            MF.getFrameInfo().getCalleeSavedInfo(),
-            [](const auto &Info) { return Info.getReg() == AArch64::LR; }) &&
-        MF.getFunction().hasFnAttribute(Attribute::ShadowCallStack)))
-    return false;
-
-  if (!MF.getSubtarget<AArch64Subtarget>().isXRegisterReserved(18))
-    report_fatal_error("Must reserve x18 to use shadow call stack");
-
-  return true;
-}
-
 static void emitShadowCallStackPrologue(const TargetInstrInfo &TII,
                                         MachineFunction &MF,
                                         MachineBasicBlock &MBB,
@@ -1414,7 +1396,7 @@ void AArch64FrameLowering::emitPrologue(MachineFunction &MF,
   DebugLoc DL;
 
   const auto &MFnI = *MF.getInfo<AArch64FunctionInfo>();
-  if (needsShadowCallStackPrologueEpilogue(MF))
+  if (MFnI.needsShadowCallStackPrologueEpilogue(MF))
     emitShadowCallStackPrologue(*TII, MF, MBB, MBBI, DL, NeedsWinCFI,
                                 MFnI.needsDwarfUnwindInfo(MF));
 
@@ -1945,7 +1927,7 @@ void AArch64FrameLowering::emitEpilogue(MachineFunction &MF,
       if (NeedsWinCFI)
         HasWinCFI = true; // AArch64PointerAuth pass will insert SEH_PACSignLR
     }
-    if (needsShadowCallStackPrologueEpilogue(MF))
+    if (AFI->needsShadowCallStackPrologueEpilogue(MF))
       emitShadowCallStackEpilogue(*TII, MF, MBB, MBB.getFirstTerminator(), DL);
     if (EmitCFI)
       emitCalleeSavedGPRRestores(MBB, MBB.getFirstTerminator());

diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 5a234bceb25ed0a..c804312da5369ef 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -14,6 +14,7 @@
 #include "AArch64InstrInfo.h"
 #include "AArch64FrameLowering.h"
 #include "AArch64MachineFunctionInfo.h"
+#include "AArch64PointerAuth.h"
 #include "AArch64Subtarget.h"
 #include "MCTargetDesc/AArch64AddressingModes.h"
 #include "Utils/AArch64BaseInfo.h"
@@ -2490,6 +2491,20 @@ bool AArch64InstrInfo::isPairableLdStInst(const MachineInstr &MI) {
   }
 }
 
+bool AArch64InstrInfo::isTailCallReturnInst(const MachineInstr &MI) {
+  switch (MI.getOpcode()) {
+  default:
+    assert((!MI.isCall() || !MI.isReturn()) &&
+           "Unexpected instruction - was a new tail call opcode introduced?");
+    return false;
+  case AArch64::TCRETURNdi:
+  case AArch64::TCRETURNri:
+  case AArch64::TCRETURNriBTI:
+  case AArch64::TCRETURNriALL:
+    return true;
+  }
+}
+
 unsigned AArch64InstrInfo::convertToFlagSettingOpc(unsigned Opc) {
   switch (Opc) {
   default:
@@ -8217,12 +8232,24 @@ AArch64InstrInfo::getOutliningCandidateInfo(
   // necessary. However, at this point we don't know if the outlined function
   // will have a RET instruction so we assume the worst.
   const TargetRegisterInfo &TRI = getRegisterInfo();
+  // Performing a tail call may require extra checks when PAuth is enabled.
+  // If PAuth is disabled, set it to zero for uniformity.
+  unsigned NumBytesToCheckLRInTCEpilogue = 0;
   if (FirstCand.getMF()
           ->getInfo<AArch64FunctionInfo>()
           ->shouldSignReturnAddress(true)) {
     // One PAC and one AUT instructions
     NumBytesToCreateFrame += 8;
 
+    // PAuth is enabled - set extra tail call cost, if any.
+    auto LRCheckMethod = Subtarget.getAuthenticatedLRCheckMethod();
+    NumBytesToCheckLRInTCEpilogue =
+        AArch64PAuth::getCheckerSizeInBytes(LRCheckMethod);
+    // Checking the authenticated LR value may significantly impact
+    // SequenceSize, so account for it for more precise results.
+    if (isTailCallReturnInst(*RepeatedSequenceLocs[0].back()))
+      SequenceSize += NumBytesToCheckLRInTCEpilogue;
+
     // We have to check if sp modifying instructions would get outlined.
     // If so we only allow outlining if sp is unchanged overall, so matching
     // sub and add instructions are okay to outline, all other sp modifications
@@ -8393,7 +8420,8 @@ AArch64InstrInfo::getOutliningCandidateInfo(
   if (RepeatedSequenceLocs[0].back()->isTerminator()) {
     FrameID = MachineOutlinerTailCall;
     NumBytesToCreateFrame = 0;
-    SetCandidateCallInfo(MachineOutlinerTailCall, 4);
+    unsigned NumBytesForCall = 4 + NumBytesToCheckLRInTCEpilogue;
+    SetCandidateCallInfo(MachineOutlinerTailCall, NumBytesForCall);
   }
 
   else if (LastInstrOpcode == AArch64::BL ||
@@ -8402,7 +8430,7 @@ AArch64InstrInfo::getOutliningCandidateInfo(
             !HasBTI)) {
     // FIXME: Do we need to check if the code after this uses the value of LR?
     FrameID = MachineOutlinerThunk;
-    NumBytesToCreateFrame = 0;
+    NumBytesToCreateFrame = NumBytesToCheckLRInTCEpilogue;
     SetCandidateCallInfo(MachineOutlinerThunk, 4);
   }
 

diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
index f5874d7856f8d24..4a40b2fa122159f 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -129,6 +129,9 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
   /// Return true if pairing the given load or store may be paired with another.
   static bool isPairableLdStInst(const MachineInstr &MI);
 
+  /// Returns true if MI is one of the TCRETURN* instructions.
+  static bool isTailCallReturnInst(const MachineInstr &MI);
+
   /// Return the opcode that set flags when possible.  The caller is
   /// responsible for ensuring the opc has a flag setting equivalent.
   static unsigned convertToFlagSettingOpc(unsigned Opc);

diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp
index 961a19317d6660b..7bb5041b8ba9481 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp
@@ -122,11 +122,27 @@ bool AArch64FunctionInfo::shouldSignReturnAddress(bool SpillsLR) const {
   return SpillsLR;
 }
 
+static bool isLRSpilled(const MachineFunction &MF) {
+  return llvm::any_of(
+      MF.getFrameInfo().getCalleeSavedInfo(),
+      [](const auto &Info) { return Info.getReg() == AArch64::LR; });
+}
+
 bool AArch64FunctionInfo::shouldSignReturnAddress(
     const MachineFunction &MF) const {
-  return shouldSignReturnAddress(llvm::any_of(
-      MF.getFrameInfo().getCalleeSavedInfo(),
-      [](const auto &Info) { return Info.getReg() == AArch64::LR; }));
+  return shouldSignReturnAddress(isLRSpilled(MF));
+}
+
+bool AArch64FunctionInfo::needsShadowCallStackPrologueEpilogue(
+    MachineFunction &MF) const {
+  if (!(isLRSpilled(MF) &&
+        MF.getFunction().hasFnAttribute(Attribute::ShadowCallStack)))
+    return false;
+
+  if (!MF.getSubtarget<AArch64Subtarget>().isXRegisterReserved(18))
+    report_fatal_error("Must reserve x18 to use shadow call stack");
+
+  return true;
 }
 
 bool AArch64FunctionInfo::needsDwarfUnwindInfo(

diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index 8df95ff1e6eaea2..0b8bfb04a572c77 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -431,6 +431,8 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
   bool shouldSignReturnAddress(const MachineFunction &MF) const;
   bool shouldSignReturnAddress(bool SpillsLR) const;
 
+  bool needsShadowCallStackPrologueEpilogue(MachineFunction &MF) const;
+
   bool shouldSignWithBKey() const { return SignWithBKey; }
   bool isMTETagged() const { return IsMTETagged; }
 

diff --git a/llvm/lib/Target/AArch64/AArch64PointerAuth.cpp b/llvm/lib/Target/AArch64/AArch64PointerAuth.cpp
index f8f11b43e35be51..f9b3027c35bb3dd 100644
--- a/llvm/lib/Target/AArch64/AArch64PointerAuth.cpp
+++ b/llvm/lib/Target/AArch64/AArch64PointerAuth.cpp
@@ -6,7 +6,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "AArch64PointerAuth.h"
+
 #include "AArch64.h"
+#include "AArch64InstrInfo.h"
 #include "AArch64MachineFunctionInfo.h"
 #include "AArch64Subtarget.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
@@ -14,6 +17,7 @@
 #include "llvm/CodeGen/MachineModuleInfo.h"
 
 using namespace llvm;
+using namespace llvm::AArch64PAuth;
 
 #define AARCH64_POINTER_AUTH_NAME "AArch64 Pointer Authentication"
 
@@ -30,13 +34,19 @@ class AArch64PointerAuth : public MachineFunctionPass {
   StringRef getPassName() const override { return AARCH64_POINTER_AUTH_NAME; }
 
 private:
+  /// An immediate operand passed to BRK instruction, if it is ever emitted.
+  const unsigned BrkOperand = 0xc471;
+
   const AArch64Subtarget *Subtarget = nullptr;
   const AArch64InstrInfo *TII = nullptr;
+  const AArch64RegisterInfo *TRI = nullptr;
 
   void signLR(MachineFunction &MF, MachineBasicBlock::iterator MBBI) const;
 
   void authenticateLR(MachineFunction &MF,
                       MachineBasicBlock::iterator MBBI) const;
+
+  bool checkAuthenticatedLR(MachineBasicBlock::iterator TI) const;
 };
 
 } // end anonymous namespace
@@ -132,22 +142,179 @@ void AArch64PointerAuth::authenticateLR(
   }
 }
 
+namespace {
+
+// Mark dummy LDR instruction as volatile to prevent removing it as dead code.
+MachineMemOperand *createCheckMemOperand(MachineFunction &MF,
+                                         const AArch64Subtarget &Subtarget) {
+  MachinePointerInfo PointerInfo(Subtarget.getAddressCheckPSV());
+  auto MOVolatileLoad =
+      MachineMemOperand::MOLoad | MachineMemOperand::MOVolatile;
+
+  return MF.getMachineMemOperand(PointerInfo, MOVolatileLoad, 4, Align(4));
+}
+
+} // namespace
+
+MachineBasicBlock &llvm::AArch64PAuth::checkAuthenticatedRegister(
+    MachineBasicBlock::iterator MBBI, AuthCheckMethod Method,
+    Register AuthenticatedReg, Register TmpReg, bool UseIKey, unsigned BrkImm) {
+
+  MachineBasicBlock &MBB = *MBBI->getParent();
+  MachineFunction &MF = *MBB.getParent();
+  const AArch64Subtarget &Subtarget = MF.getSubtarget<AArch64Subtarget>();
+  const AArch64InstrInfo *TII = Subtarget.getInstrInfo();
+  DebugLoc DL = MBBI->getDebugLoc();
+
+  // First, handle the methods not requiring creating extra MBBs.
+  switch (Method) {
+  default:
+    break;
+  case AuthCheckMethod::None:
+    return MBB;
+  case AuthCheckMethod::DummyLoad:
+    BuildMI(MBB, MBBI, DL, TII->get(AArch64::LDRWui), getWRegFromXReg(TmpReg))
+        .addReg(AArch64::LR)
+        .addImm(0)
+        .addMemOperand(createCheckMemOperand(MF, Subtarget));
+    return MBB;
+  }
+
+  // Control flow has to be changed, so arrange new MBBs.
+
+  // At this point, at least an AUT* instruction is expected before MBBI
+  assert(MBBI != MBB.begin() &&
+         "Cannot insert the check at the very beginning of MBB");
+  // The block to insert check into.
+  MachineBasicBlock *CheckBlock = &MBB;
+  // The remaining part of the original MBB that is executed on success.
+  MachineBasicBlock *SuccessBlock = MBB.splitAt(*std::prev(MBBI));
+
+  // The block that explicitly generates a break-point exception on failure.
+  MachineBasicBlock *BreakBlock =
+      MF.CreateMachineBasicBlock(MBB.getBasicBlock());
+  MF.push_back(BreakBlock);
+  MBB.splitSuccessor(SuccessBlock, BreakBlock);
+
+  assert(CheckBlock->getFallThrough() == SuccessBlock);
+  BuildMI(BreakBlock, DL, TII->get(AArch64::BRK)).addImm(BrkImm);
+
+  switch (Method) {
+  case AuthCheckMethod::None:
+  case AuthCheckMethod::DummyLoad:
+    llvm_unreachable("Should be handled above");
+  case AuthCheckMethod::HighBitsNoTBI:
+    BuildMI(CheckBlock, DL, TII->get(AArch64::EORXrs), TmpReg)
+        .addReg(AuthenticatedReg)
+        .addReg(AuthenticatedReg)
+        .addImm(1);
+    BuildMI(CheckBlock, DL, TII->get(AArch64::TBNZX))
+        .addReg(TmpReg)
+        .addImm(62)
+        .addMBB(BreakBlock);
+    return *SuccessBlock;
+  case AuthCheckMethod::XPACHint:
+    assert(AuthenticatedReg == AArch64::LR &&
+           "XPACHint mode is only compatible with checking the LR register");
+    assert(UseIKey && "XPACHint mode is only compatible with I-keys");
+    BuildMI(CheckBlock, DL, TII->get(AArch64::ORRXrs), TmpReg)
+        .addReg(AArch64::XZR)
+        .addReg(AArch64::LR)
+        .addImm(0);
+    BuildMI(CheckBlock, DL, TII->get(AArch64::XPACLRI));
+    BuildMI(CheckBlock, DL, TII->get(AArch64::SUBSXrs), AArch64::XZR)
+        .addReg(TmpReg)
+        .addReg(AArch64::LR)
+        .addImm(0);
+    BuildMI(CheckBlock, DL, TII->get(AArch64::Bcc))
+        .addImm(AArch64CC::NE)
+        .addMBB(BreakBlock);
+    return *SuccessBlock;
+  }
+}
+
+unsigned llvm::AArch64PAuth::getCheckerSizeInBytes(AuthCheckMethod Method) {
+  switch (Method) {
+  case AuthCheckMethod::None:
+    return 0;
+  case AuthCheckMethod::DummyLoad:
+    return 4;
+  case AuthCheckMethod::HighBitsNoTBI:
+    return 12;
+  case AuthCheckMethod::XPACHint:
+    return 20;
+  }
+}
+
+bool AArch64PointerAuth::checkAuthenticatedLR(
+    MachineBasicBlock::iterator TI) const {
+  AuthCheckMethod Method = Subtarget->getAuthenticatedLRCheckMethod();
+
+  if (Method == AuthCheckMethod::None)
+    return false;
+
+  // FIXME If FEAT_FPAC is implemented by the CPU, this check can be skipped.
+
+  assert(!TI->getMF()->hasWinCFI() && "WinCFI is not yet supported");
+
+  // The following code may create a signing oracle:
+  //
+  //   <authenticate LR>
+  //   TCRETURN          ; the callee may sign and spill the LR in its prologue
+  //
+  // To avoid generating a signing oracle, check the authenticated value
+  // before possibly re-signing it in the callee, as follows:
+  //
+  //   <authenticate LR>
+  //   <check if LR contains a valid address>
+  //   b.<cond> break_block
+  // ret_block:
+  //   TCRETURN
+  // break_block:
+  //   brk <BrkOperand>
+  //
+  // or just
+  //
+  //   <authenticate LR>
+  //   ldr tmp, [lr]
+  //   TCRETURN
+
+  // TmpReg is chosen assuming X16 and X17 are dead after TI.
+  assert(AArch64InstrInfo::isTailCallReturnInst(*TI) &&
+         "Tail call is expected");
+  Register TmpReg =
+      TI->readsRegister(AArch64::X16, TRI) ? AArch64::X17 : AArch64::X16;
+  assert(!TI->readsRegister(TmpReg, TRI) &&
+         "More than a single register is used by TCRETURN");
+
+  checkAuthenticatedRegister(TI, Method, AArch64::LR, TmpReg, /*UseIKey=*/true,
+                             BrkOperand);
+
+  return true;
+}
+
 bool AArch64PointerAuth::runOnMachineFunction(MachineFunction &MF) {
-  if (!MF.getInfo<AArch64FunctionInfo>()->shouldSignReturnAddress(true))
+  const auto *MFnI = MF.getInfo<AArch64FunctionInfo>();
+  if (!MFnI->shouldSignReturnAddress(true))
     return false;
 
   Subtarget = &MF.getSubtarget<AArch64Subtarget>();
   TII = Subtarget->getInstrInfo();
+  TRI = Subtarget->getRegisterInfo();
 
   SmallVector<MachineBasicBlock::iterator> DeletedInstrs;
+  SmallVector<MachineBasicBlock::iterator> TailCallInstrs;
+
   bool Modified = false;
+  bool HasAuthenticationInstrs = false;
 
   for (auto &MBB : MF) {
     for (auto &MI : MBB) {
       auto It = MI.getIterator();
       switch (MI.getOpcode()) {
       default:
-        // do nothing
+        if (AArch64InstrInfo::isTailCallReturnInst(MI))
+          TailCallInstrs.push_back(It);
         break;
       case AArch64::PAUTH_PROLOGUE:
         signLR(MF, It);
@@ -158,11 +325,20 @@ bool AArch64PointerAuth::runOnMachineFunction(MachineFunction &MF) {
         authenticateLR(MF, It);
         DeletedInstrs.push_back(It);
         Modified = true;
+        HasAuthenticationInstrs = true;
         break;
       }
     }
   }
 
+  // FIXME Do we need to emit any PAuth-related epilogue code at all
+  //       when SCS is enabled?
+  if (HasAuthenticationInstrs &&
+      !MFnI->needsShadowCallStackPrologueEpilogue(MF)) {
+    for (auto TailCall : TailCallInstrs)
+      Modified |= checkAuthenticatedLR(TailCall);
+  }
+
   for (auto MI : DeletedInstrs)
     MI->eraseFromParent();
 

diff --git a/llvm/lib/Target/AArch64/AArch64PointerAuth.h b/llvm/lib/Target/AArch64/AArch64PointerAuth.h
new file mode 100644
index 000000000000000..e1ceaed58abe47c
--- /dev/null
+++ b/llvm/lib/Target/AArch64/AArch64PointerAuth.h
@@ -0,0 +1,116 @@
+//===-- AArch64PointerAuth.h -- Harden code using PAuth ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_AARCH64_AARCH64POINTERAUTH_H
+#define LLVM_LIB_TARGET_AARCH64_AARCH64POINTERAUTH_H
+
+#include "llvm/CodeGen/MachineBasicBlock.h"
+#include "llvm/CodeGen/Register.h"
+
+namespace llvm {
+namespace AArch64PAuth {
+
+/// Variants of check performed on an authenticated pointer.
+///
+/// In cases such as authenticating the LR value when performing a tail call
+/// or when re-signing a signed pointer with a different signing schema,
+/// a failed authentication may not generate an exception on its own and may
+/// create an authentication or signing oracle if not checked explicitly.
+///
+/// A number of check methods modify control flow in a similar way by
+/// rewriting the code
+///
+/// ```
+///   <authenticate LR>
+///   <more instructions>
+/// ```
+///
+/// as follows:
+///
+/// ```
+///   <authenticate LR>
+///   <method-specific checker>
+/// ret_block:
+///   <more instructions>
+///   ...
+///
+/// break_block:
+///   brk <code>
+/// ```
+enum class AuthCheckMethod {
+  /// Do not check the value at all
+  None,
+  /// Perform a load to a temporary register
+  DummyLoad,
+  /// Check by comparing bits 62 and 61 of the authenticated address.
+  ///
+  /// This method modifies control flow and inserts the following checker:
+  ///
+  /// ```
+  ///   eor Xtmp, Xn, Xn, lsl #1
+  ///   tbnz Xtmp, #62, break_block
+  /// ```
+  HighBitsNoTBI,
+  /// Check by comparing the authenticated value with an XPAC-ed one without
+  /// using PAuth instructions not encoded as HINT. Can only be applied to LR.
+  ///
+  /// This method modifies control flow and inserts the following checker:
+  ///
+  /// ```
+  ///   mov Xtmp, LR
+  ///   xpaclri           ; encoded as "hint #7"
+  ///   ; Note: at this point, the LR register contains the address as if
+  ///   ; the authentication succeeded and the temporary register contains the
+  ///   ; *real* result of authentication.
+  ///   cmp Xtmp, LR
+  ///   b.ne break_block
+  /// ```
+  XPACHint,
+};
+
+#define AUTH_CHECK_METHOD_CL_VALUES_COMMON                                     \
+      clEnumValN(AArch64PAuth::AuthCheckMethod::None, "none",                  \
+                 "Do not check authenticated address"),                        \
+      clEnumValN(AArch64PAuth::AuthCheckMethod::DummyLoad, "load",             \
+                 "Perform dummy load from authenticated address"),             \
+      clEnumValN(AArch64PAuth::AuthCheckMethod::HighBitsNoTBI,                 \
+                 "high-bits-notbi",                                            \
+                 "Compare bits 62 and 61 of address (TBI should be disabled)")
+
+#define AUTH_CHECK_METHOD_CL_VALUES_LR                                         \
+      AUTH_CHECK_METHOD_CL_VALUES_COMMON,                                      \
+      clEnumValN(AArch64PAuth::AuthCheckMethod::XPACHint, "xpac-hint",         \
+                 "Compare with the result of XPACLRI")
+
+/// Explicitly checks that pointer authentication succeeded.
+///
+/// Assuming AuthenticatedReg contains a value returned by one of the AUT*
+/// instructions, check the value using Method just before the instruction
+/// pointed to by MBBI. If the check succeeds, execution proceeds to the
+/// instruction pointed to by MBBI, otherwise a CPU exception is generated.
+///
+/// Some of the methods may need to know if the pointer was authenticated
+/// using an I-key or D-key and which register can be used as temporary.
+/// If an explicit BRK instruction is used to generate an exception, BrkImm
+/// specifies its immediate operand.
+///
+/// \returns The machine basic block containing the code that is executed
+///          after the check succeeds.
+MachineBasicBlock &checkAuthenticatedRegister(MachineBasicBlock::iterator MBBI,
+                                              AuthCheckMethod Method,
+                                              Register AuthenticatedReg,
+                                              Register TmpReg, bool UseIKey,
+                                              unsigned BrkImm);
+
+/// Returns the number of bytes added by checkAuthenticatedRegister.
+unsigned getCheckerSizeInBytes(AuthCheckMethod Method);
+
+} // end namespace AArch64PAuth
+} // end namespace llvm
+
+#endif

diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
index 8946c0b71e2ba46..e3c3bff8e32984e 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
@@ -70,6 +70,13 @@ static cl::opt<bool> ForceStreamingCompatibleSVE(
         "Force the use of streaming-compatible SVE code for all functions"),
     cl::Hidden);
 
+static cl::opt<AArch64PAuth::AuthCheckMethod>
+    AuthenticatedLRCheckMethod("aarch64-authenticated-lr-check-method",
+                               cl::Hidden,
+                               cl::desc("Override the variant of check applied "
+                                        "to authenticated LR during tail call"),
+                               cl::values(AUTH_CHECK_METHOD_CL_VALUES_LR));
+
 unsigned AArch64Subtarget::getVectorInsertExtractBaseCost() const {
   if (OverrideVectorInsertExtractBaseCost.getNumOccurrences() > 0)
     return OverrideVectorInsertExtractBaseCost;
@@ -335,6 +342,8 @@ AArch64Subtarget::AArch64Subtarget(const Triple &TT, StringRef CPU,
   // X29 is named FP, so we can't use TRI->getName to check X29.
   if (ReservedRegNames.count("X29") || ReservedRegNames.count("FP"))
     ReserveXRegisterForRA.set(29);
+
+  AddressCheckPSV.reset(new AddressCheckPseudoSourceValue(TM));
 }
 
 const CallLowering *AArch64Subtarget::getCallLowering() const {
@@ -490,3 +499,26 @@ bool AArch64Subtarget::isSVEAvailable() const{
   // as we don't yet support the feature in LLVM.
   return hasSVE() && !isStreaming() && !isStreamingCompatible();
 }
+
+// If return address signing is enabled, tail calls are emitted as follows:
+//
+// ```
+//   <authenticate LR>
+//   <check LR>
+//   TCRETURN          ; the callee may sign and spill the LR in its prologue
+// ```
+//
+// LR may require explicit checking because if FEAT_FPAC is not implemented
+// and LR was tampered with, then `<authenticate LR>` will not generate an
+// exception on its own. Later, if the callee spills the signed LR value and
+// neither FEAT_PAuth2 nor FEAT_EPAC are implemented, the valid PAC replaces
+// the higher bits of LR thus hiding the authentication failure.
+AArch64PAuth::AuthCheckMethod
+AArch64Subtarget::getAuthenticatedLRCheckMethod() const {
+  if (AuthenticatedLRCheckMethod.getNumOccurrences())
+    return AuthenticatedLRCheckMethod;
+
+  // For now, use None by default because checks may introduce an unexpected
+  // performance regression or incompatibility with execute-only mappings.
+  return AArch64PAuth::AuthCheckMethod::None;
+}

diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.h b/llvm/lib/Target/AArch64/AArch64Subtarget.h
index 5d0fd7f9f45b59c..b91c5c81ed4d274 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.h
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.h
@@ -16,6 +16,7 @@
 #include "AArch64FrameLowering.h"
 #include "AArch64ISelLowering.h"
 #include "AArch64InstrInfo.h"
+#include "AArch64PointerAuth.h"
 #include "AArch64RegisterInfo.h"
 #include "AArch64SelectionDAGInfo.h"
 #include "llvm/CodeGen/GlobalISel/CallLowering.h"
@@ -432,6 +433,32 @@ class AArch64Subtarget final : public AArch64GenSubtargetInfo {
       return "__security_check_cookie_arm64ec";
     return "__security_check_cookie";
   }
+
+  /// Choose a method of checking LR before performing a tail call.
+  AArch64PAuth::AuthCheckMethod getAuthenticatedLRCheckMethod() const;
+
+  const PseudoSourceValue *getAddressCheckPSV() const {
+    return AddressCheckPSV.get();
+  }
+
+private:
+  /// Pseudo value representing memory load performed to check an address.
+  ///
+  /// This load operation is solely used for its side-effects: if the address
+  /// is not mapped (or not readable), it triggers CPU exception, otherwise
+  /// execution proceeds and the value is not used.
+  class AddressCheckPseudoSourceValue : public PseudoSourceValue {
+  public:
+    AddressCheckPseudoSourceValue(const TargetMachine &TM)
+        : PseudoSourceValue(TargetCustom, TM) {}
+
+    bool isConstant(const MachineFrameInfo *) const override { return false; }
+    bool isAliased(const MachineFrameInfo *) const override { return true; }
+    bool mayAlias(const MachineFrameInfo *) const override { return true; }
+    void printCustom(raw_ostream &OS) const override { OS << "AddressCheck"; }
+  };
+
+  std::unique_ptr<AddressCheckPseudoSourceValue> AddressCheckPSV;
 };
 } // End llvm namespace
 

diff --git a/llvm/test/CodeGen/AArch64/sign-return-address-tailcall.ll b/llvm/test/CodeGen/AArch64/sign-return-address-tailcall.ll
new file mode 100644
index 000000000000000..ec04e553cac6e37
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sign-return-address-tailcall.ll
@@ -0,0 +1,121 @@
+; RUN: llc -mtriple=aarch64 -asm-verbose=0 < %s | FileCheck -DAUTIASP="hint #29" --check-prefixes=COMMON %s
+; RUN: llc -mtriple=aarch64 -asm-verbose=0 -aarch64-authenticated-lr-check-method=load                   < %s | FileCheck -DAUTIASP="hint #29" --check-prefixes=COMMON,LDR %s
+; RUN: llc -mtriple=aarch64 -asm-verbose=0 -aarch64-authenticated-lr-check-method=high-bits-notbi        < %s | FileCheck -DAUTIASP="hint #29" --check-prefixes=COMMON,BITS-NOTBI,BRK %s
+; RUN: llc -mtriple=aarch64 -asm-verbose=0 -aarch64-authenticated-lr-check-method=xpac-hint              < %s | FileCheck -DAUTIASP="hint #29" -DXPACLRI="hint #7" --check-prefixes=COMMON,XPAC,BRK %s
+; RUN: llc -mtriple=aarch64 -asm-verbose=0 -aarch64-authenticated-lr-check-method=xpac-hint -mattr=v8.3a < %s | FileCheck -DAUTIASP="autiasp"  -DXPACLRI="xpaclri" --check-prefixes=COMMON,XPAC,BRK %s
+
+define i32 @tailcall_direct() "sign-return-address"="non-leaf" {
+; COMMON-LABEL: tailcall_direct:
+; COMMON:         str x30, [sp, #-16]!
+; COMMON:         ldr x30, [sp], #16
+;
+; COMMON-NEXT:    [[AUTIASP]]
+;
+; LDR-NEXT:       ldr w16, [x30]
+;
+; BITS-NOTBI-NEXT: eor x16, x30, x30, lsl #1
+; BITS-NOTBI-NEXT: tbnz x16, #62, .[[FAIL:LBB[_0-9]+]]
+;
+; XPAC-NEXT:      mov x16, x30
+; XPAC-NEXT:      [[XPACLRI]]
+; XPAC-NEXT:      cmp x16, x30
+; XPAC-NEXT:      b.ne .[[FAIL:LBB[_0-9]+]]
+;
+; COMMON-NEXT:    b callee
+; BRK-NEXT:     .[[FAIL]]:
+; BRK-NEXT:       brk #0xc471
+  tail call void asm sideeffect "", "~{lr}"()
+  %call = tail call i32 @callee()
+  ret i32 %call
+}
+
+define i32 @tailcall_indirect(ptr %fptr) "sign-return-address"="non-leaf" {
+; COMMON-LABEL: tailcall_indirect:
+; COMMON:         str x30, [sp, #-16]!
+; COMMON:         ldr x30, [sp], #16
+;
+; COMMON-NEXT:    [[AUTIASP]]
+;
+; LDR-NEXT:       ldr w16, [x30]
+;
+; BITS-NOTBI-NEXT: eor x16, x30, x30, lsl #1
+; BITS-NOTBI-NEXT: tbnz x16, #62, .[[FAIL:LBB[_0-9]+]]
+;
+; XPAC-NEXT:      mov x16, x30
+; XPAC-NEXT:      [[XPACLRI]]
+; XPAC-NEXT:      cmp x16, x30
+; XPAC-NEXT:      b.ne .[[FAIL:LBB[_0-9]+]]
+;
+; COMMON-NEXT:    br x0
+; BRK-NEXT:     .[[FAIL]]:
+; BRK-NEXT:       brk #0xc471
+  tail call void asm sideeffect "", "~{lr}"()
+  %call = tail call i32 %fptr()
+  ret i32 %call
+}
+
+define i32 @tailcall_direct_noframe() "sign-return-address"="non-leaf" {
+; COMMON-LABEL: tailcall_direct_noframe:
+; COMMON-NEXT:    .cfi_startproc
+; COMMON-NEXT:    b callee
+  %call = tail call i32 @callee()
+  ret i32 %call
+}
+
+define i32 @tailcall_indirect_noframe(ptr %fptr) "sign-return-address"="non-leaf" {
+; COMMON-LABEL: tailcall_indirect_noframe:
+; COMMON-NEXT:    .cfi_startproc
+; COMMON-NEXT:    br x0
+  %call = tail call i32 %fptr()
+  ret i32 %call
+}
+
+define i32 @tailcall_direct_noframe_sign_all() "sign-return-address"="all" {
+; COMMON-LABEL: tailcall_direct_noframe_sign_all:
+; COMMON-NOT:     str{{.*}}x30
+; COMMON-NOT:     ldr{{.*}}x30
+;
+; COMMON:         [[AUTIASP]]
+;
+; LDR-NEXT:       ldr w16, [x30]
+;
+; BITS-NOTBI-NEXT: eor x16, x30, x30, lsl #1
+; BITS-NOTBI-NEXT: tbnz x16, #62, .[[FAIL:LBB[_0-9]+]]
+;
+; XPAC-NEXT:      mov x16, x30
+; XPAC-NEXT:      [[XPACLRI]]
+; XPAC-NEXT:      cmp x16, x30
+; XPAC-NEXT:      b.ne .[[FAIL:LBB[_0-9]+]]
+;
+; COMMON-NEXT:    b callee
+; BRK-NEXT:     .[[FAIL]]:
+; BRK-NEXT:       brk #0xc471
+  %call = tail call i32 @callee()
+  ret i32 %call
+}
+
+define i32 @tailcall_indirect_noframe_sign_all(ptr %fptr) "sign-return-address"="all" {
+; COMMON-LABEL: tailcall_indirect_noframe_sign_all:
+; COMMON-NOT:     str{{.*}}x30
+; COMMON-NOT:     ldr{{.*}}x30
+;
+; COMMON:         [[AUTIASP]]
+;
+; LDR-NEXT:       ldr w16, [x30]
+;
+; BITS-NOTBI-NEXT: eor x16, x30, x30, lsl #1
+; BITS-NOTBI-NEXT: tbnz x16, #62, .[[FAIL:LBB[_0-9]+]]
+;
+; XPAC-NEXT:      mov x16, x30
+; XPAC-NEXT:      [[XPACLRI]]
+; XPAC-NEXT:      cmp x16, x30
+; XPAC-NEXT:      b.ne .[[FAIL:LBB[_0-9]+]]
+;
+; COMMON-NEXT:    br x0
+; BRK-NEXT:     .[[FAIL]]:
+; BRK-NEXT:       brk #0xc471
+  %call = tail call i32 %fptr()
+  ret i32 %call
+}
+
+declare i32 @callee()


        

