[llvm-branch-commits] [llvm] [AArch64][PAC] Move emission of LR checks in tail calls to AsmPrinter (PR #110705)

Anatoly Trosinenko via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Oct 25 05:28:33 PDT 2024


https://github.com/atrosinenko updated https://github.com/llvm/llvm-project/pull/110705

>From aec7d908c567a857d63a731eab044bbdd2925558 Mon Sep 17 00:00:00 2001
From: Anatoly Trosinenko <atrosinenko at accesssoftek.com>
Date: Mon, 23 Sep 2024 19:51:55 +0300
Subject: [PATCH 1/3] [AArch64][PAC] Move emission of LR checks in tail calls
 to AsmPrinter

Move the emission of the checks performed on the authenticated LR value
during tail calls to the AArch64AsmPrinter class, so that the different
checker sequences can be reused by other pseudo instructions expanded there.
This adds one more option to the AuthCheckMethod enumeration: the generic
XPAC variant, which is not restricted to checking the LR register.
---
 llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp | 143 +++++++++++---
 llvm/lib/Target/AArch64/AArch64InstrInfo.cpp  |  13 ++
 llvm/lib/Target/AArch64/AArch64InstrInfo.td   |   2 +
 .../lib/Target/AArch64/AArch64PointerAuth.cpp | 182 +-----------------
 llvm/lib/Target/AArch64/AArch64PointerAuth.h  |  40 ++--
 llvm/lib/Target/AArch64/AArch64Subtarget.cpp  |   2 -
 llvm/lib/Target/AArch64/AArch64Subtarget.h    |  23 ---
 llvm/test/CodeGen/AArch64/ptrauth-ret-trap.ll |  36 ++--
 .../AArch64/sign-return-address-tailcall.ll   |  54 +++---
 9 files changed, 192 insertions(+), 303 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index 6d2dd0ecbccf31..50502477706ccf 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -153,6 +153,7 @@ class AArch64AsmPrinter : public AsmPrinter {
   void emitPtrauthCheckAuthenticatedValue(Register TestedReg,
                                           Register ScratchReg,
                                           AArch64PACKey::ID Key,
+                                          AArch64PAuth::AuthCheckMethod Method,
                                           bool ShouldTrap,
                                           const MCSymbol *OnFailure);
 
@@ -1731,7 +1732,8 @@ unsigned AArch64AsmPrinter::emitPtrauthDiscriminator(uint16_t Disc,
 /// of proceeding to the next instruction (only if ShouldTrap is false).
 void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
     Register TestedReg, Register ScratchReg, AArch64PACKey::ID Key,
-    bool ShouldTrap, const MCSymbol *OnFailure) {
+    AArch64PAuth::AuthCheckMethod Method, bool ShouldTrap,
+    const MCSymbol *OnFailure) {
   // Insert a sequence to check if authentication of TestedReg succeeded,
   // such as:
   //
@@ -1757,38 +1759,70 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
   //    Lsuccess:
   //      ...
   //
-  // This sequence is expensive, but we need more information to be able to
-  // do better.
-  //
-  // We can't TBZ the poison bit because EnhancedPAC2 XORs the PAC bits
-  // on failure.
-  // We can't TST the PAC bits because we don't always know how the address
-  // space is setup for the target environment (and the bottom PAC bit is
-  // based on that).
-  // Either way, we also don't always know whether TBI is enabled or not for
-  // the specific target environment.
+  // See the documentation on AuthCheckMethod enumeration constants for
+  // the specific code sequences that can be used to perform the check.
+  using AArch64PAuth::AuthCheckMethod;
 
-  unsigned XPACOpc = getXPACOpcodeForKey(Key);
+  if (Method == AuthCheckMethod::None)
+    return;
+  if (Method == AuthCheckMethod::DummyLoad) {
+    EmitToStreamer(MCInstBuilder(AArch64::LDRWui)
+                       .addReg(getWRegFromXReg(ScratchReg))
+                       .addReg(TestedReg)
+                       .addImm(0));
+    assert(ShouldTrap && !OnFailure && "DummyLoad always traps on error");
+    return;
+  }
 
   MCSymbol *SuccessSym = createTempSymbol("auth_success_");
+  if (Method == AuthCheckMethod::XPAC || Method == AuthCheckMethod::XPACHint) {
+    //  mov Xscratch, Xtested
+    emitMovXReg(ScratchReg, TestedReg);
 
-  //  mov Xscratch, Xtested
-  emitMovXReg(ScratchReg, TestedReg);
-
-  //  xpac(i|d) Xscratch
-  EmitToStreamer(MCInstBuilder(XPACOpc).addReg(ScratchReg).addReg(ScratchReg));
+    if (Method == AuthCheckMethod::XPAC) {
+      //  xpac(i|d) Xscratch
+      unsigned XPACOpc = getXPACOpcodeForKey(Key);
+      EmitToStreamer(
+          MCInstBuilder(XPACOpc).addReg(ScratchReg).addReg(ScratchReg));
+    } else {
+      //  xpaclri
+
+      // Note that this method applies XPAC to TestedReg instead of ScratchReg.
+      assert(TestedReg == AArch64::LR &&
+             "XPACHint mode is only compatible with checking the LR register");
+      assert((Key == AArch64PACKey::IA || Key == AArch64PACKey::IB) &&
+             "XPACHint mode is only compatible with I-keys");
+      EmitToStreamer(MCInstBuilder(AArch64::XPACLRI));
+    }
 
-  //  cmp Xtested, Xscratch
-  EmitToStreamer(MCInstBuilder(AArch64::SUBSXrs)
-                     .addReg(AArch64::XZR)
-                     .addReg(TestedReg)
-                     .addReg(ScratchReg)
-                     .addImm(0));
+    //  cmp Xtested, Xscratch
+    EmitToStreamer(MCInstBuilder(AArch64::SUBSXrs)
+                       .addReg(AArch64::XZR)
+                       .addReg(TestedReg)
+                       .addReg(ScratchReg)
+                       .addImm(0));
 
-  //  b.eq Lsuccess
-  EmitToStreamer(MCInstBuilder(AArch64::Bcc)
-                     .addImm(AArch64CC::EQ)
-                     .addExpr(MCSymbolRefExpr::create(SuccessSym, OutContext)));
+    //  b.eq Lsuccess
+    EmitToStreamer(
+        MCInstBuilder(AArch64::Bcc)
+            .addImm(AArch64CC::EQ)
+            .addExpr(MCSymbolRefExpr::create(SuccessSym, OutContext)));
+  } else if (Method == AuthCheckMethod::HighBitsNoTBI) {
+    //  eor Xscratch, Xtested, Xtested, lsl #1
+    EmitToStreamer(MCInstBuilder(AArch64::EORXrs)
+                       .addReg(ScratchReg)
+                       .addReg(TestedReg)
+                       .addReg(TestedReg)
+                       .addImm(1));
+    //  tbz Xscratch, #62, Lsuccess
+    EmitToStreamer(
+        MCInstBuilder(AArch64::TBZX)
+            .addReg(ScratchReg)
+            .addImm(62)
+            .addExpr(MCSymbolRefExpr::create(SuccessSym, OutContext)));
+  } else {
+    llvm_unreachable("Unsupported check method");
+  }
 
   if (ShouldTrap) {
     assert(!OnFailure && "Cannot specify OnFailure with ShouldTrap");
@@ -1802,9 +1836,26 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
     // Note that this can introduce an authentication oracle (such as based on
     // the high bits of the re-signed value).
 
-    // FIXME: Can we simply return the AUT result, already in TestedReg?
-    //  mov Xtested, Xscratch
-    emitMovXReg(TestedReg, ScratchReg);
+    // FIXME: The XPAC method can be optimized by applying XPAC to TestedReg
+    //        instead of ScratchReg, thus eliminating one `mov` instruction.
+    //        Both XPAC and XPACHint can be further optimized by not using a
+    //        conditional branch jumping over an unconditional one.
+
+    switch (Method) {
+    case AuthCheckMethod::XPACHint:
+      // LR is already XPAC-ed at this point.
+      break;
+    case AuthCheckMethod::XPAC:
+      //  mov Xtested, Xscratch
+      emitMovXReg(TestedReg, ScratchReg);
+      break;
+    default:
+      // If Xtested was not XPAC-ed so far, emit XPAC here.
+      //  xpac(i|d) Xtested
+      unsigned XPACOpc = getXPACOpcodeForKey(Key);
+      EmitToStreamer(
+          MCInstBuilder(XPACOpc).addReg(TestedReg).addReg(TestedReg));
+    }
 
     if (OnFailure) {
       //  b Lend
@@ -1830,7 +1881,7 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
   //      ; sign x16 (if AUTPAC)
   //    Lend:   ; if not trapping on failure
   //
-  // with the checking sequence chosen depending on whether we should check
+  // with the checking sequence chosen depending on whether/how we should check
   // the pointer and whether we should trap on failure.
 
   // By default, auth/resign sequences check for auth failures.
@@ -1890,6 +1941,7 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
       EndSym = createTempSymbol("resign_end_");
 
     emitPtrauthCheckAuthenticatedValue(AArch64::X16, AArch64::X17, AUTKey,
+                                       AArch64PAuth::AuthCheckMethod::XPAC,
                                        ShouldTrap, EndSym);
   }
 
@@ -2260,11 +2312,34 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
     OutStreamer->emitLabel(LOHLabel);
   }
 
+  // With Pointer Authentication, it may be necessary to explicitly check the
+  // authenticated value in LR before performing a tail call. Otherwise, the
+  // callee may re-sign an invalid return address, introducing a signing
+  // oracle.
+  auto CheckLRInTailCall = [this](Register CallDestinationReg) {
+    if (!AArch64FI->shouldSignReturnAddress(*MF))
+      return;
+
+    auto LRCheckMethod = STI->getAuthenticatedLRCheckMethod(*MF);
+    if (LRCheckMethod == AArch64PAuth::AuthCheckMethod::None)
+      return;
+
+    Register ScratchReg =
+        CallDestinationReg == AArch64::X16 ? AArch64::X17 : AArch64::X16;
+    AArch64PACKey::ID Key =
+        AArch64FI->shouldSignWithBKey() ? AArch64PACKey::IB : AArch64PACKey::IA;
+    emitPtrauthCheckAuthenticatedValue(
+        AArch64::LR, ScratchReg, Key, LRCheckMethod,
+        /*ShouldTrap=*/true, /*OnFailure=*/nullptr);
+  };
+
   AArch64TargetStreamer *TS =
     static_cast<AArch64TargetStreamer *>(OutStreamer->getTargetStreamer());
   // Do any manual lowerings.
   switch (MI->getOpcode()) {
   default:
+    assert(!AArch64InstrInfo::isTailCallReturnInst(*MI) &&
+           "Unhandled tail call instruction");
     break;
   case AArch64::HINT: {
     // CurrentPatchableFunctionEntrySym can be CurrentFnBegin only for
@@ -2404,6 +2479,8 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
                               ? AArch64::X17
                               : AArch64::X16;
 
+    CheckLRInTailCall(MI->getOperand(0).getReg());
+
     unsigned DiscReg = AddrDisc;
     if (Disc) {
       if (AddrDisc != AArch64::NoRegister) {
@@ -2434,6 +2511,8 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
   case AArch64::TCRETURNrix17:
   case AArch64::TCRETURNrinotx16:
   case AArch64::TCRETURNriALL: {
+    CheckLRInTailCall(MI->getOperand(0).getReg());
+
     MCInst TmpInst;
     TmpInst.setOpcode(AArch64::BR);
     TmpInst.addOperand(MCOperand::createReg(MI->getOperand(0).getReg()));
@@ -2441,6 +2520,8 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
     return;
   }
   case AArch64::TCRETURNdi: {
+    CheckLRInTailCall(AArch64::NoRegister);
+
     MCOperand Dest;
     MCInstLowering.lowerOperand(MI->getOperand(0), Dest);
     MCInst TmpInst;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 32bc0e7d0d6475..d54582df819604 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -107,6 +107,19 @@ unsigned AArch64InstrInfo::getInstSizeInBytes(const MachineInstr &MI) const {
   unsigned NumBytes = 0;
   const MCInstrDesc &Desc = MI.getDesc();
 
+  if (!MI.isBundle() && isTailCallReturnInst(MI)) {
+    NumBytes = Desc.getSize() ? Desc.getSize() : 4;
+    // Reserve room for the LR check AsmPrinter may emit before the tail call.
+    // Pass *MF (as AsmPrinter does): a pointer argument would convert to bool.
+    const auto *MFI = MF->getInfo<AArch64FunctionInfo>();
+    if (!MFI->shouldSignReturnAddress(*MF))
+      return NumBytes;
+
+    auto &STI = MF->getSubtarget<AArch64Subtarget>();
+    auto Method = STI.getAuthenticatedLRCheckMethod(*MF);
+    return NumBytes + AArch64PAuth::getCheckerSizeInBytes(Method);
+  }
+
   // Size should be preferably set in
   // llvm/lib/Target/AArch64/AArch64InstrInfo.td (default case).
   // Specific cases handle instructions of variable sizes
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index fe3c8578b52aa4..d2d132dd421a10 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1952,6 +1952,8 @@ let Predicates = [HasPAuth] in {
   }
 
   // Size 16: 4 fixed + 8 variable, to compute discriminator.
+  // The size returned by getInstSizeInBytes() is incremented according
+  // to the variant of LR check.
   let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Size = 16,
       Uses = [SP] in {
     def AUTH_TCRETURN
diff --git a/llvm/lib/Target/AArch64/AArch64PointerAuth.cpp b/llvm/lib/Target/AArch64/AArch64PointerAuth.cpp
index 92ab4b5c3d251f..e966234296df57 100644
--- a/llvm/lib/Target/AArch64/AArch64PointerAuth.cpp
+++ b/llvm/lib/Target/AArch64/AArch64PointerAuth.cpp
@@ -12,7 +12,6 @@
 #include "AArch64InstrInfo.h"
 #include "AArch64MachineFunctionInfo.h"
 #include "AArch64Subtarget.h"
-#include "Utils/AArch64BaseInfo.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
@@ -35,15 +34,8 @@ class AArch64PointerAuth : public MachineFunctionPass {
   StringRef getPassName() const override { return AARCH64_POINTER_AUTH_NAME; }
 
 private:
-  /// An immediate operand passed to BRK instruction, if it is ever emitted.
-  static unsigned BrkOperandForKey(AArch64PACKey::ID KeyId) {
-    const unsigned BrkOperandBase = 0xc470;
-    return BrkOperandBase + KeyId;
-  }
-
   const AArch64Subtarget *Subtarget = nullptr;
   const AArch64InstrInfo *TII = nullptr;
-  const AArch64RegisterInfo *TRI = nullptr;
 
   void signLR(MachineFunction &MF, MachineBasicBlock::iterator MBBI) const;
 
@@ -230,97 +222,6 @@ void AArch64PointerAuth::authenticateLR(
   }
 }
 
-namespace {
-
-// Mark dummy LDR instruction as volatile to prevent removing it as dead code.
-MachineMemOperand *createCheckMemOperand(MachineFunction &MF,
-                                         const AArch64Subtarget &Subtarget) {
-  MachinePointerInfo PointerInfo(Subtarget.getAddressCheckPSV());
-  auto MOVolatileLoad =
-      MachineMemOperand::MOLoad | MachineMemOperand::MOVolatile;
-
-  return MF.getMachineMemOperand(PointerInfo, MOVolatileLoad, 4, Align(4));
-}
-
-} // namespace
-
-void llvm::AArch64PAuth::checkAuthenticatedRegister(
-    MachineBasicBlock::iterator MBBI, AuthCheckMethod Method,
-    Register AuthenticatedReg, Register TmpReg, bool UseIKey, unsigned BrkImm) {
-
-  MachineBasicBlock &MBB = *MBBI->getParent();
-  MachineFunction &MF = *MBB.getParent();
-  const AArch64Subtarget &Subtarget = MF.getSubtarget<AArch64Subtarget>();
-  const AArch64InstrInfo *TII = Subtarget.getInstrInfo();
-  DebugLoc DL = MBBI->getDebugLoc();
-
-  // All terminator instructions should be grouped at the end of the machine
-  // basic block, with no non-terminator instructions between them. Depending on
-  // the method requested, we will insert some regular instructions, maybe
-  // followed by a conditional branch instruction, which is a terminator, before
-  // MBBI. Thus, MBBI is expected to be the first terminator of its MBB.
-  assert(MBBI->isTerminator() && MBBI == MBB.getFirstTerminator() &&
-         "MBBI should be the first terminator in MBB");
-
-  // First, handle the methods not requiring creating extra MBBs.
-  switch (Method) {
-  default:
-    break;
-  case AuthCheckMethod::None:
-    return;
-  case AuthCheckMethod::DummyLoad:
-    BuildMI(MBB, MBBI, DL, TII->get(AArch64::LDRWui), getWRegFromXReg(TmpReg))
-        .addReg(AuthenticatedReg)
-        .addImm(0)
-        .addMemOperand(createCheckMemOperand(MF, Subtarget));
-    return;
-  }
-
-  // Control flow has to be changed, so arrange new MBBs.
-
-  // The block that explicitly generates a break-point exception on failure.
-  MachineBasicBlock *BreakBlock =
-      MF.CreateMachineBasicBlock(MBB.getBasicBlock());
-  MF.push_back(BreakBlock);
-  MBB.addSuccessor(BreakBlock);
-
-  BuildMI(BreakBlock, DL, TII->get(AArch64::BRK)).addImm(BrkImm);
-
-  switch (Method) {
-  case AuthCheckMethod::None:
-  case AuthCheckMethod::DummyLoad:
-    llvm_unreachable("Should be handled above");
-  case AuthCheckMethod::HighBitsNoTBI:
-    BuildMI(MBB, MBBI, DL, TII->get(AArch64::EORXrs), TmpReg)
-        .addReg(AuthenticatedReg)
-        .addReg(AuthenticatedReg)
-        .addImm(1);
-    BuildMI(MBB, MBBI, DL, TII->get(AArch64::TBNZX))
-        .addReg(TmpReg)
-        .addImm(62)
-        .addMBB(BreakBlock);
-    return;
-  case AuthCheckMethod::XPACHint:
-    assert(AuthenticatedReg == AArch64::LR &&
-           "XPACHint mode is only compatible with checking the LR register");
-    assert(UseIKey && "XPACHint mode is only compatible with I-keys");
-    BuildMI(MBB, MBBI, DL, TII->get(AArch64::ORRXrs), TmpReg)
-        .addReg(AArch64::XZR)
-        .addReg(AArch64::LR)
-        .addImm(0);
-    BuildMI(MBB, MBBI, DL, TII->get(AArch64::XPACLRI));
-    BuildMI(MBB, MBBI, DL, TII->get(AArch64::SUBSXrs), AArch64::XZR)
-        .addReg(TmpReg)
-        .addReg(AArch64::LR)
-        .addImm(0);
-    BuildMI(MBB, MBBI, DL, TII->get(AArch64::Bcc))
-        .addImm(AArch64CC::NE)
-        .addMBB(BreakBlock);
-    return;
-  }
-  llvm_unreachable("Unknown AuthCheckMethod enum");
-}
-
 unsigned llvm::AArch64PAuth::getCheckerSizeInBytes(AuthCheckMethod Method) {
   switch (Method) {
   case AuthCheckMethod::None:
@@ -330,63 +231,12 @@ unsigned llvm::AArch64PAuth::getCheckerSizeInBytes(AuthCheckMethod Method) {
   case AuthCheckMethod::HighBitsNoTBI:
     return 12;
   case AuthCheckMethod::XPACHint:
+  case AuthCheckMethod::XPAC:
     return 20;
   }
   llvm_unreachable("Unknown AuthCheckMethod enum");
 }
 
-bool AArch64PointerAuth::checkAuthenticatedLR(
-    MachineBasicBlock::iterator TI) const {
-  const AArch64FunctionInfo *MFnI = TI->getMF()->getInfo<AArch64FunctionInfo>();
-  AArch64PACKey::ID KeyId =
-      MFnI->shouldSignWithBKey() ? AArch64PACKey::IB : AArch64PACKey::IA;
-
-  AuthCheckMethod Method =
-      Subtarget->getAuthenticatedLRCheckMethod(*TI->getMF());
-
-  if (Method == AuthCheckMethod::None)
-    return false;
-
-  // FIXME If FEAT_FPAC is implemented by the CPU, this check can be skipped.
-
-  assert(!TI->getMF()->hasWinCFI() && "WinCFI is not yet supported");
-
-  // The following code may create a signing oracle:
-  //
-  //   <authenticate LR>
-  //   TCRETURN          ; the callee may sign and spill the LR in its prologue
-  //
-  // To avoid generating a signing oracle, check the authenticated value
-  // before possibly re-signing it in the callee, as follows:
-  //
-  //   <authenticate LR>
-  //   <check if LR contains a valid address>
-  //   b.<cond> break_block
-  // ret_block:
-  //   TCRETURN
-  // break_block:
-  //   brk <BrkOperand>
-  //
-  // or just
-  //
-  //   <authenticate LR>
-  //   ldr tmp, [lr]
-  //   TCRETURN
-
-  // TmpReg is chosen assuming X16 and X17 are dead after TI.
-  assert(AArch64InstrInfo::isTailCallReturnInst(*TI) &&
-         "Tail call is expected");
-  Register TmpReg =
-      TI->readsRegister(AArch64::X16, TRI) ? AArch64::X17 : AArch64::X16;
-  assert(!TI->readsRegister(TmpReg, TRI) &&
-         "More than a single register is used by TCRETURN");
-
-  checkAuthenticatedRegister(TI, Method, AArch64::LR, TmpReg, /*UseIKey=*/true,
-                             BrkOperandForKey(KeyId));
-
-  return true;
-}
-
 void AArch64PointerAuth::emitBlend(MachineBasicBlock::iterator MBBI,
                                    Register Result, Register AddrDisc,
                                    unsigned IntDisc) const {
@@ -414,38 +264,21 @@ void AArch64PointerAuth::expandPAuthBlend(
 }
 
 bool AArch64PointerAuth::runOnMachineFunction(MachineFunction &MF) {
-  const auto *MFnI = MF.getInfo<AArch64FunctionInfo>();
-
   Subtarget = &MF.getSubtarget<AArch64Subtarget>();
   TII = Subtarget->getInstrInfo();
-  TRI = Subtarget->getRegisterInfo();
 
   SmallVector<MachineBasicBlock::instr_iterator> PAuthPseudoInstrs;
-  SmallVector<MachineBasicBlock::instr_iterator> TailCallInstrs;
 
   bool Modified = false;
-  bool HasAuthenticationInstrs = false;
 
   for (auto &MBB : MF) {
-    // Using instr_iterator to catch unsupported bundled TCRETURN* instructions
-    // instead of just skipping them.
-    for (auto &MI : MBB.instrs()) {
+    for (auto &MI : MBB) {
       switch (MI.getOpcode()) {
       default:
-        // Bundled TCRETURN* instructions (such as created by KCFI)
-        // are not supported yet, but no support is required if no
-        // PAUTH_EPILOGUE instructions exist in the same function.
-        // Skip the BUNDLE instruction itself (actual bundled instructions
-        // follow it in the instruction list).
-        if (MI.isBundle())
-          continue;
-        if (AArch64InstrInfo::isTailCallReturnInst(MI))
-          TailCallInstrs.push_back(MI.getIterator());
         break;
       case AArch64::PAUTH_PROLOGUE:
       case AArch64::PAUTH_EPILOGUE:
       case AArch64::PAUTH_BLEND:
-        assert(!MI.isBundled());
         PAuthPseudoInstrs.push_back(MI.getIterator());
         break;
       }
@@ -459,7 +292,6 @@ bool AArch64PointerAuth::runOnMachineFunction(MachineFunction &MF) {
       break;
     case AArch64::PAUTH_EPILOGUE:
       authenticateLR(MF, It);
-      HasAuthenticationInstrs = true;
       break;
     case AArch64::PAUTH_BLEND:
       expandPAuthBlend(It);
@@ -471,15 +303,5 @@ bool AArch64PointerAuth::runOnMachineFunction(MachineFunction &MF) {
     Modified = true;
   }
 
-  // FIXME Do we need to emit any PAuth-related epilogue code at all
-  //       when SCS is enabled?
-  if (HasAuthenticationInstrs &&
-      !MFnI->needsShadowCallStackPrologueEpilogue(MF)) {
-    for (auto TailCall : TailCallInstrs) {
-      assert(!TailCall->isBundled() && "Not yet supported");
-      Modified |= checkAuthenticatedLR(TailCall);
-    }
-  }
-
   return Modified;
 }
diff --git a/llvm/lib/Target/AArch64/AArch64PointerAuth.h b/llvm/lib/Target/AArch64/AArch64PointerAuth.h
index 4ffda747822452..1e1d82fd50c7e0 100644
--- a/llvm/lib/Target/AArch64/AArch64PointerAuth.h
+++ b/llvm/lib/Target/AArch64/AArch64PointerAuth.h
@@ -35,27 +35,29 @@ namespace AArch64PAuth {
 /// ```
 ///   <authenticate LR>
 ///   <method-specific checker>
-/// ret_block:
+/// on_fail:
+///   brk <code>
+/// on_success:
 ///   <more instructions>
-///   ...
 ///
-/// break_block:
-///   brk <code>
 /// ```
 enum class AuthCheckMethod {
   /// Do not check the value at all
   None,
+
   /// Perform a load to a temporary register
   DummyLoad,
+
   /// Check by comparing bits 62 and 61 of the authenticated address.
   ///
   /// This method modifies control flow and inserts the following checker:
   ///
   /// ```
   ///   eor Xtmp, Xn, Xn, lsl #1
-  ///   tbnz Xtmp, #62, break_block
+  ///   tbz Xtmp, #62, on_success
   /// ```
   HighBitsNoTBI,
+
   /// Check by comparing the authenticated value with an XPAC-ed one without
   /// using PAuth instructions not encoded as HINT. Can only be applied to LR.
   ///
@@ -68,9 +70,19 @@ enum class AuthCheckMethod {
   ///   ; the authentication succeeded and the temporary register contains the
   ///   ; *real* result of authentication.
   ///   cmp Xtmp, LR
-  ///   b.ne break_block
+  ///   b.eq on_success
   /// ```
   XPACHint,
+
+  /// Similar to XPACHint but using the Armv8.3 XPAC(I|D) instruction (which
+  /// is not encoded as HINT), so it is not restricted to checking LR:
+  /// ```
+  ///   mov Xtmp, Xn
+  ///   xpac(i|d) Xtmp
+  ///   cmp Xn, Xtmp
+  ///   b.eq on_success
+  /// ```
+  XPAC,
 };
 
 #define AUTH_CHECK_METHOD_CL_VALUES_COMMON                                     \
@@ -87,22 +99,6 @@ enum class AuthCheckMethod {
       clEnumValN(AArch64PAuth::AuthCheckMethod::XPACHint, "xpac-hint",         \
                  "Compare with the result of XPACLRI")
 
-/// Explicitly checks that pointer authentication succeeded.
-///
-/// Assuming AuthenticatedReg contains a value returned by one of the AUT*
-/// instructions, check the value using Method just before the instruction
-/// pointed to by MBBI. If the check succeeds, execution proceeds to the
-/// instruction pointed to by MBBI, otherwise a CPU exception is generated.
-///
-/// Some of the methods may need to know if the pointer was authenticated
-/// using an I-key or D-key and which register can be used as temporary.
-/// If an explicit BRK instruction is used to generate an exception, BrkImm
-/// specifies its immediate operand.
-void checkAuthenticatedRegister(MachineBasicBlock::iterator MBBI,
-                                AuthCheckMethod Method,
-                                Register AuthenticatedReg, Register TmpReg,
-                                bool UseIKey, unsigned BrkImm);
-
 /// Returns the number of bytes added by checkAuthenticatedRegister.
 unsigned getCheckerSizeInBytes(AuthCheckMethod Method);
 
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
index 32db1e8c2477a8..486c34410d4232 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
@@ -367,8 +367,6 @@ AArch64Subtarget::AArch64Subtarget(const Triple &TT, StringRef CPU,
   // X29 is named FP, so we can't use TRI->getName to check X29.
   if (ReservedRegNames.count("X29") || ReservedRegNames.count("FP"))
     ReserveXRegisterForRA.set(29);
-
-  AddressCheckPSV.reset(new AddressCheckPseudoSourceValue(TM));
 }
 
 const CallLowering *AArch64Subtarget::getCallLowering() const {
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.h b/llvm/lib/Target/AArch64/AArch64Subtarget.h
index 9856415361e50d..fa7915e10be4e6 100644
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.h
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.h
@@ -434,29 +434,6 @@ class AArch64Subtarget final : public AArch64GenSubtargetInfo {
   /// a function.
   std::optional<uint16_t>
   getPtrAuthBlockAddressDiscriminatorIfEnabled(const Function &ParentFn) const;
-
-  const PseudoSourceValue *getAddressCheckPSV() const {
-    return AddressCheckPSV.get();
-  }
-
-private:
-  /// Pseudo value representing memory load performed to check an address.
-  ///
-  /// This load operation is solely used for its side-effects: if the address
-  /// is not mapped (or not readable), it triggers CPU exception, otherwise
-  /// execution proceeds and the value is not used.
-  class AddressCheckPseudoSourceValue : public PseudoSourceValue {
-  public:
-    AddressCheckPseudoSourceValue(const TargetMachine &TM)
-        : PseudoSourceValue(TargetCustom, TM) {}
-
-    bool isConstant(const MachineFrameInfo *) const override { return false; }
-    bool isAliased(const MachineFrameInfo *) const override { return true; }
-    bool mayAlias(const MachineFrameInfo *) const override { return true; }
-    void printCustom(raw_ostream &OS) const override { OS << "AddressCheck"; }
-  };
-
-  std::unique_ptr<AddressCheckPseudoSourceValue> AddressCheckPSV;
 };
 } // End llvm namespace
 
diff --git a/llvm/test/CodeGen/AArch64/ptrauth-ret-trap.ll b/llvm/test/CodeGen/AArch64/ptrauth-ret-trap.ll
index 42a3050eda1127..4821b3c0f274be 100644
--- a/llvm/test/CodeGen/AArch64/ptrauth-ret-trap.ll
+++ b/llvm/test/CodeGen/AArch64/ptrauth-ret-trap.ll
@@ -7,10 +7,10 @@
 ; CHECK-NEXT:   ldr x30, [sp], #16
 ; CHECK-NEXT:   autibsp
 ; CHECK-NEXT:   eor x16, x30, x30, lsl #1
-; CHECK-NEXT:   tbnz x16, #62, [[BAD:.L.*]]
-; CHECK-NEXT:   b bar
-; CHECK-NEXT:   [[BAD]]:
+; CHECK-NEXT:   tbz x16, #62, [[GOOD:.L.*]]
 ; CHECK-NEXT:   brk #0xc471
+; CHECK-NEXT:   [[GOOD]]:
+; CHECK-NEXT:   b bar
 define i32 @test_tailcall() #0 {
   call i32 @bar()
   %c = tail call i32 @bar()
@@ -27,10 +27,10 @@ define i32 @test_tailcall_noframe() #0 {
 ; CHECK-LABEL: test_tailcall_indirect:
 ; CHECK:         autibsp
 ; CHECK:         eor     x16, x30, x30, lsl #1
-; CHECK:         tbnz    x16, #62, [[BAD:.L.*]]
-; CHECK:         br      x0
-; CHECK: [[BAD]]:
+; CHECK:         tbz     x16, #62, [[GOOD:.L.*]]
 ; CHECK:         brk     #0xc471
+; CHECK: [[GOOD]]:
+; CHECK:         br      x0
 define void @test_tailcall_indirect(ptr %fptr) #0 {
   call i32 @test_tailcall()
   tail call void %fptr()
@@ -40,10 +40,10 @@ define void @test_tailcall_indirect(ptr %fptr) #0 {
 ; CHECK-LABEL: test_tailcall_indirect_in_x9:
 ; CHECK:         autibsp
 ; CHECK:         eor     x16, x30, x30, lsl #1
-; CHECK:         tbnz    x16, #62, [[BAD:.L.*]]
-; CHECK:         br      x9
-; CHECK: [[BAD]]:
+; CHECK:         tbz     x16, #62, [[GOOD:.L.*]]
 ; CHECK:         brk     #0xc471
+; CHECK: [[GOOD]]:
+; CHECK:         br      x9
 define void @test_tailcall_indirect_in_x9(ptr sret(i64) %ret, [8 x i64] %in, ptr %fptr) #0 {
   %ptr = alloca i8, i32 16
   call i32 @test_tailcall()
@@ -54,11 +54,11 @@ define void @test_tailcall_indirect_in_x9(ptr sret(i64) %ret, [8 x i64] %in, ptr
 ; CHECK-LABEL: test_auth_tailcall_indirect:
 ; CHECK:         autibsp
 ; CHECK:         eor     x16, x30, x30, lsl #1
-; CHECK:         tbnz    x16, #62, [[BAD:.L.*]]
+; CHECK:         tbz     x16, #62, [[GOOD:.L.*]]
+; CHECK:         brk     #0xc471
+; CHECK: [[GOOD]]:
 ; CHECK:         mov x16, #42
 ; CHECK:         braa      x0, x16
-; CHECK: [[BAD]]:
-; CHECK:         brk     #0xc471
 define void @test_auth_tailcall_indirect(ptr %fptr) #0 {
   call i32 @test_tailcall()
   tail call void %fptr() [ "ptrauth"(i32 0, i64 42) ]
@@ -68,10 +68,10 @@ define void @test_auth_tailcall_indirect(ptr %fptr) #0 {
 ; CHECK-LABEL: test_auth_tailcall_indirect_in_x9:
 ; CHECK:         autibsp
 ; CHECK:         eor     x16, x30, x30, lsl #1
-; CHECK:         tbnz    x16, #62, [[BAD:.L.*]]
-; CHECK:         brabz      x9
-; CHECK: [[BAD]]:
+; CHECK:         tbz     x16, #62, [[GOOD:.L.*]]
 ; CHECK:         brk     #0xc471
+; CHECK: [[GOOD]]:
+; CHECK:         brabz      x9
 define void @test_auth_tailcall_indirect_in_x9(ptr sret(i64) %ret, [8 x i64] %in, ptr %fptr) #0 {
   %ptr = alloca i8, i32 16
   call i32 @test_tailcall()
@@ -82,10 +82,10 @@ define void @test_auth_tailcall_indirect_in_x9(ptr sret(i64) %ret, [8 x i64] %in
 ; CHECK-LABEL: test_auth_tailcall_indirect_bti:
 ; CHECK:         autibsp
 ; CHECK:         eor     x17, x30, x30, lsl #1
-; CHECK:         tbnz    x17, #62, [[BAD:.L.*]]
-; CHECK:         brabz      x16
-; CHECK: [[BAD]]:
+; CHECK:         tbz     x17, #62, [[GOOD:.L.*]]
 ; CHECK:         brk     #0xc471
+; CHECK: [[GOOD]]:
+; CHECK:         brabz      x16
 define void @test_auth_tailcall_indirect_bti(ptr sret(i64) %ret, [8 x i64] %in, ptr %fptr) #0 "branch-target-enforcement"="true" {
   %ptr = alloca i8, i32 16
   call i32 @test_tailcall()
diff --git a/llvm/test/CodeGen/AArch64/sign-return-address-tailcall.ll b/llvm/test/CodeGen/AArch64/sign-return-address-tailcall.ll
index 0cc707298e4582..3e5c5c1695b899 100644
--- a/llvm/test/CodeGen/AArch64/sign-return-address-tailcall.ll
+++ b/llvm/test/CodeGen/AArch64/sign-return-address-tailcall.ll
@@ -14,16 +14,16 @@ define i32 @tailcall_direct() "sign-return-address"="non-leaf" {
 ; LDR-NEXT:       ldr w16, [x30]
 ;
 ; BITS-NOTBI-NEXT: eor x16, x30, x30, lsl #1
-; BITS-NOTBI-NEXT: tbnz x16, #62, .[[FAIL:LBB[_0-9]+]]
+; BITS-NOTBI-NEXT: tbz x16, #62, .[[GOOD:Lauth_success[_0-9]+]]
 ;
 ; XPAC-NEXT:      mov x16, x30
 ; XPAC-NEXT:      [[XPACLRI]]
-; XPAC-NEXT:      cmp x16, x30
-; XPAC-NEXT:      b.ne .[[FAIL:LBB[_0-9]+]]
+; XPAC-NEXT:      cmp x30, x16
+; XPAC-NEXT:      b.eq .[[GOOD:Lauth_success[_0-9]+]]
 ;
-; COMMON-NEXT:    b callee
-; BRK-NEXT:     .[[FAIL]]:
 ; BRK-NEXT:       brk #0xc470
+; BRK-NEXT:     .[[GOOD]]:
+; COMMON-NEXT:    b callee
   tail call void asm sideeffect "", "~{lr}"()
   %call = tail call i32 @callee()
   ret i32 %call
@@ -39,16 +39,16 @@ define i32 @tailcall_indirect(ptr %fptr) "sign-return-address"="non-leaf" {
 ; LDR-NEXT:       ldr w16, [x30]
 ;
 ; BITS-NOTBI-NEXT: eor x16, x30, x30, lsl #1
-; BITS-NOTBI-NEXT: tbnz x16, #62, .[[FAIL:LBB[_0-9]+]]
+; BITS-NOTBI-NEXT: tbz x16, #62, .[[GOOD:Lauth_success[_0-9]+]]
 ;
 ; XPAC-NEXT:      mov x16, x30
 ; XPAC-NEXT:      [[XPACLRI]]
-; XPAC-NEXT:      cmp x16, x30
-; XPAC-NEXT:      b.ne .[[FAIL:LBB[_0-9]+]]
+; XPAC-NEXT:      cmp x30, x16
+; XPAC-NEXT:      b.eq .[[GOOD:Lauth_success[_0-9]+]]
 ;
-; COMMON-NEXT:    br x0
-; BRK-NEXT:     .[[FAIL]]:
 ; BRK-NEXT:       brk #0xc470
+; BRK-NEXT:     .[[GOOD]]:
+; COMMON-NEXT:    br x0
   tail call void asm sideeffect "", "~{lr}"()
   %call = tail call i32 %fptr()
   ret i32 %call
@@ -80,16 +80,16 @@ define i32 @tailcall_direct_noframe_sign_all() "sign-return-address"="all" {
 ; LDR-NEXT:       ldr w16, [x30]
 ;
 ; BITS-NOTBI-NEXT: eor x16, x30, x30, lsl #1
-; BITS-NOTBI-NEXT: tbnz x16, #62, .[[FAIL:LBB[_0-9]+]]
+; BITS-NOTBI-NEXT: tbz x16, #62, .[[GOOD:Lauth_success[_0-9]+]]
 ;
 ; XPAC-NEXT:      mov x16, x30
 ; XPAC-NEXT:      [[XPACLRI]]
-; XPAC-NEXT:      cmp x16, x30
-; XPAC-NEXT:      b.ne .[[FAIL:LBB[_0-9]+]]
+; XPAC-NEXT:      cmp x30, x16
+; XPAC-NEXT:      b.eq .[[GOOD:Lauth_success[_0-9]+]]
 ;
-; COMMON-NEXT:    b callee
-; BRK-NEXT:     .[[FAIL]]:
 ; BRK-NEXT:       brk #0xc470
+; BRK-NEXT:     .[[GOOD]]:
+; COMMON-NEXT:    b callee
   %call = tail call i32 @callee()
   ret i32 %call
 }
@@ -104,16 +104,16 @@ define i32 @tailcall_indirect_noframe_sign_all(ptr %fptr) "sign-return-address"=
 ; LDR-NEXT:       ldr w16, [x30]
 ;
 ; BITS-NOTBI-NEXT: eor x16, x30, x30, lsl #1
-; BITS-NOTBI-NEXT: tbnz x16, #62, .[[FAIL:LBB[_0-9]+]]
+; BITS-NOTBI-NEXT: tbz x16, #62, .[[GOOD:Lauth_success[_0-9]+]]
 ;
 ; XPAC-NEXT:      mov x16, x30
 ; XPAC-NEXT:      [[XPACLRI]]
-; XPAC-NEXT:      cmp x16, x30
-; XPAC-NEXT:      b.ne .[[FAIL:LBB[_0-9]+]]
+; XPAC-NEXT:      cmp x30, x16
+; XPAC-NEXT:      b.eq .[[GOOD:Lauth_success[_0-9]+]]
 ;
-; COMMON-NEXT:    br x0
-; BRK-NEXT:     .[[FAIL]]:
 ; BRK-NEXT:       brk #0xc470
+; BRK-NEXT:     .[[GOOD]]:
+; COMMON-NEXT:    br x0
   %call = tail call i32 %fptr()
   ret i32 %call
 }
@@ -121,9 +121,9 @@ define i32 @tailcall_indirect_noframe_sign_all(ptr %fptr) "sign-return-address"=
 define i32 @tailcall_ib_key() "sign-return-address"="all" "sign-return-address-key"="b_key" {
 ; COMMON-LABEL: tailcall_ib_key:
 ;
+; BRK:            brk #0xc471
+; BRK-NEXT:     .{{Lauth_success.*}}:
 ; COMMON:         b callee
-; BRK-NEXT:     .{{LBB.*}}:
-; BRK-NEXT:       brk #0xc471
   tail call void asm sideeffect "", "~{lr}"()
   %call = tail call i32 @callee()
   ret i32 %call
@@ -141,16 +141,16 @@ define i32 @tailcall_two_branches(i1 %0) "sign-return-address"="all" {
 ; LDR-NEXT:          ldr w16, [x30]
 ;
 ; BITS-NOTBI-NEXT:   eor x16, x30, x30, lsl #1
-; BITS-NOTBI-NEXT:   tbnz x16, #62, .[[FAIL:LBB[_0-9]+]]
+; BITS-NOTBI-NEXT:   tbz x16, #62, .[[GOOD:Lauth_success[_0-9]+]]
 ;
 ; XPAC-NEXT:         mov x16, x30
 ; XPAC-NEXT:         [[XPACLRI]]
-; XPAC-NEXT:         cmp x16, x30
-; XPAC-NEXT:         b.ne .[[FAIL:LBB[_0-9]+]]
+; XPAC-NEXT:         cmp x30, x16
+; XPAC-NEXT:         b.eq .[[GOOD:Lauth_success[_0-9]+]]
 ;
-; COMMON-NEXT:       b callee
-; BRK-NEXT:        .[[FAIL]]:
 ; BRK-NEXT:          brk #0xc470
+; BRK-NEXT:        .[[GOOD]]:
+; COMMON-NEXT:       b callee
   br i1 %0, label %2, label %3
 2:
   call void @callee2()

>From 248ab08f126b1c295586df2dba704e4b568ba3b9 Mon Sep 17 00:00:00 2001
From: Anatoly Trosinenko <atrosinenko at accesssoftek.com>
Date: Mon, 7 Oct 2024 17:49:52 +0300
Subject: [PATCH 2/3] Check both register operands of AUTH_TCRETURN*

---
 llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp | 54 ++++++++++---------
 llvm/lib/Target/AArch64/AArch64InstrInfo.td   |  8 ++-
 .../lib/Target/AArch64/AArch64RegisterInfo.td |  4 ++
 llvm/test/CodeGen/AArch64/ptrauth-call.ll     |  4 +-
 4 files changed, 42 insertions(+), 28 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index 50502477706ccf..0a7156ef471119 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -157,6 +157,9 @@ class AArch64AsmPrinter : public AsmPrinter {
                                           bool ShouldTrap,
                                           const MCSymbol *OnFailure);
 
+  // Check authenticated LR before tail calling.
+  void emitPtrauthTailCallHardening(const MachineInstr *TC);
+
   // Emit the sequence for AUT or AUTPAC.
   void emitPtrauthAuthResign(const MachineInstr *MI);
 
@@ -1870,6 +1873,30 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
   OutStreamer->emitLabel(SuccessSym);
 }
 
+// With Pointer Authentication, it may be necessary to explicitly check the
+// authenticated value in LR before performing a tail call.
+// Otherwise, the callee may re-sign the invalid return address,
+// introducing a signing oracle.
+void AArch64AsmPrinter::emitPtrauthTailCallHardening(const MachineInstr *TC) {
+  if (!AArch64FI->shouldSignReturnAddress(*MF))
+    return;
+
+  auto LRCheckMethod = STI->getAuthenticatedLRCheckMethod(*MF);
+  if (LRCheckMethod == AArch64PAuth::AuthCheckMethod::None)
+    return;
+
+  const AArch64RegisterInfo *TRI = STI->getRegisterInfo();
+  Register ScratchReg =
+      TC->readsRegister(AArch64::X16, TRI) ? AArch64::X17 : AArch64::X16;
+  assert(!TC->readsRegister(ScratchReg, TRI) &&
+         "Neither x16 nor x17 is available as a scratch register");
+  AArch64PACKey::ID Key =
+      AArch64FI->shouldSignWithBKey() ? AArch64PACKey::IB : AArch64PACKey::IA;
+  emitPtrauthCheckAuthenticatedValue(
+      AArch64::LR, ScratchReg, Key, LRCheckMethod,
+      /*ShouldTrap=*/true, /*OnFailure=*/nullptr);
+}
+
 void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
   const bool IsAUTPAC = MI->getOpcode() == AArch64::AUTPAC;
 
@@ -2312,27 +2339,6 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
     OutStreamer->emitLabel(LOHLabel);
   }
 
-  // With Pointer Authentication, it may be needed to explicitly check the
-  // authenticated value in LR when performing a tail call.
-  // Otherwise, the callee may re-sign the invalid return address,
-  // introducing a signing oracle.
-  auto CheckLRInTailCall = [this](Register CallDestinationReg) {
-    if (!AArch64FI->shouldSignReturnAddress(*MF))
-      return;
-
-    auto LRCheckMethod = STI->getAuthenticatedLRCheckMethod(*MF);
-    if (LRCheckMethod == AArch64PAuth::AuthCheckMethod::None)
-      return;
-
-    Register ScratchReg =
-        CallDestinationReg == AArch64::X16 ? AArch64::X17 : AArch64::X16;
-    AArch64PACKey::ID Key =
-        AArch64FI->shouldSignWithBKey() ? AArch64PACKey::IB : AArch64PACKey::IA;
-    emitPtrauthCheckAuthenticatedValue(
-        AArch64::LR, ScratchReg, Key, LRCheckMethod,
-        /*ShouldTrap=*/true, /*OnFailure=*/nullptr);
-  };
-
   AArch64TargetStreamer *TS =
     static_cast<AArch64TargetStreamer *>(OutStreamer->getTargetStreamer());
   // Do any manual lowerings.
@@ -2479,7 +2485,7 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
                               ? AArch64::X17
                               : AArch64::X16;
 
-    CheckLRInTailCall(MI->getOperand(0).getReg());
+    emitPtrauthTailCallHardening(MI);
 
     unsigned DiscReg = AddrDisc;
     if (Disc) {
@@ -2511,7 +2517,7 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
   case AArch64::TCRETURNrix17:
   case AArch64::TCRETURNrinotx16:
   case AArch64::TCRETURNriALL: {
-    CheckLRInTailCall(MI->getOperand(0).getReg());
+    emitPtrauthTailCallHardening(MI);
 
     MCInst TmpInst;
     TmpInst.setOpcode(AArch64::BR);
@@ -2520,7 +2526,7 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
     return;
   }
   case AArch64::TCRETURNdi: {
-    CheckLRInTailCall(AArch64::NoRegister);
+    emitPtrauthTailCallHardening(MI);
 
     MCOperand Dest;
     MCInstLowering.lowerOperand(MI->getOperand(0), Dest);
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index d2d132dd421a10..c7615f3a751ced 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1954,15 +1954,19 @@ let Predicates = [HasPAuth] in {
   // Size 16: 4 fixed + 8 variable, to compute discriminator.
   // The size returned by getInstSizeInBytes() is incremented according
   // to the variant of LR check.
+  // As the check requires either x16 or x17 as a scratch register and
+  // authenticated tail call instructions have two register operands,
+  // make sure at least one of x16 and x17 stays free for scratch use - to
+  // that end, the $AddrDisc operand uses the tcGPRnotx16x17 register class.
   let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Size = 16,
       Uses = [SP] in {
     def AUTH_TCRETURN
       : Pseudo<(outs), (ins tcGPR64:$dst, i32imm:$FPDiff, i32imm:$Key,
-                            i64imm:$Disc, tcGPR64:$AddrDisc),
+                            i64imm:$Disc, tcGPRnotx16x17:$AddrDisc),
                []>, Sched<[WriteBrReg]>;
     def AUTH_TCRETURN_BTI
       : Pseudo<(outs), (ins tcGPRx16x17:$dst, i32imm:$FPDiff, i32imm:$Key,
-                            i64imm:$Disc, tcGPR64:$AddrDisc),
+                            i64imm:$Disc, tcGPRnotx16x17:$AddrDisc),
                []>, Sched<[WriteBrReg]>;
   }
 
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
index 8516ab2c7dd71c..46ec32b402e35d 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
@@ -238,6 +238,10 @@ def tcGPR64 : RegisterClass<"AArch64", [i64], 64, (sub GPR64common, X19, X20, X2
 def tcGPRx17 : RegisterClass<"AArch64", [i64], 64, (add X17)>;
 def tcGPRx16x17 : RegisterClass<"AArch64", [i64], 64, (add X16, X17)>;
 def tcGPRnotx16 : RegisterClass<"AArch64", [i64], 64, (sub tcGPR64, X16)>;
+// LR checking code expects either x16 or x17 to be available as a scratch
+// register; for that reason, one of the two register operands of the
+// AUTH_TCRETURN* pseudos ($AddrDisc) is restricted to this class.
+def tcGPRnotx16x17 : RegisterClass<"AArch64", [i64], 64, (sub tcGPR64, X16, X17)>;
 
 // Register set that excludes registers that are reserved for procedure calls.
 // This is used for pseudo-instructions that are actually implemented using a
diff --git a/llvm/test/CodeGen/AArch64/ptrauth-call.ll b/llvm/test/CodeGen/AArch64/ptrauth-call.ll
index 5fd6116285122f..4d96287c971465 100644
--- a/llvm/test/CodeGen/AArch64/ptrauth-call.ll
+++ b/llvm/test/CodeGen/AArch64/ptrauth-call.ll
@@ -173,9 +173,9 @@ define void @test_tailcall_omit_mov_x16_x16(ptr %objptr) #0 {
 ; CHECK:         mov     x17, x0
 ; CHECK:         movk    x17, #6503, lsl #48
 ; CHECK:         autda   x16, x17
-; CHECK:         ldr     x1, [x16]
+; CHECK:         ldr     x2, [x16]
 ; CHECK:         movk    x16, #54167, lsl #48
-; CHECK:         braa    x1, x16
+; CHECK:         braa    x2, x16
   %vtable.signed = load ptr, ptr %objptr, align 8
   %objptr.int = ptrtoint ptr %objptr to i64
   %vtable.discr = tail call i64 @llvm.ptrauth.blend(i64 %objptr.int, i64 6503)

>From 8c2472ca7f9f41b3174621fb91a2a6f861f0beac Mon Sep 17 00:00:00 2001
From: Anatoly Trosinenko <atrosinenko at accesssoftek.com>
Date: Thu, 10 Oct 2024 17:16:54 +0300
Subject: [PATCH 3/3] Misc improvements

---
 llvm/lib/Target/AArch64/AArch64InstrInfo.cpp  |  2 +-
 llvm/lib/Target/AArch64/AArch64InstrInfo.td   |  8 +++---
 llvm/lib/Target/AArch64/AArch64PointerAuth.h  | 12 +++++----
 .../AArch64/sign-return-address-tailcall.ll   | 26 +++++++++++++++++++
 4 files changed, 38 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index d54582df819604..7d7d3324400cbd 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -114,7 +114,7 @@ unsigned AArch64InstrInfo::getInstSizeInBytes(const MachineInstr &MI) const {
     if (!MFI->shouldSignReturnAddress(MF))
       return NumBytes;
 
-    auto &STI = MF->getSubtarget<AArch64Subtarget>();
+    const auto &STI = MF->getSubtarget<AArch64Subtarget>();
     auto Method = STI.getAuthenticatedLRCheckMethod(*MF);
     NumBytes += AArch64PAuth::getCheckerSizeInBytes(Method);
     return NumBytes;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index c7615f3a751ced..6f6f430c6ae6ba 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1972,16 +1972,16 @@ let Predicates = [HasPAuth] in {
 
   let Predicates = [TailCallAny] in
     def : Pat<(AArch64authtcret tcGPR64:$dst, (i32 timm:$FPDiff), (i32 timm:$Key),
-                                (i64 timm:$Disc), tcGPR64:$AddrDisc),
+                                (i64 timm:$Disc), tcGPRnotx16x17:$AddrDisc),
               (AUTH_TCRETURN tcGPR64:$dst, imm:$FPDiff, imm:$Key, imm:$Disc,
-                             tcGPR64:$AddrDisc)>;
+                             tcGPRnotx16x17:$AddrDisc)>;
 
   let Predicates = [TailCallX16X17] in
     def : Pat<(AArch64authtcret tcGPRx16x17:$dst, (i32 timm:$FPDiff),
                                 (i32 timm:$Key), (i64 timm:$Disc),
-                                tcGPR64:$AddrDisc),
+                                tcGPRnotx16x17:$AddrDisc),
               (AUTH_TCRETURN_BTI tcGPRx16x17:$dst, imm:$FPDiff, imm:$Key,
-                                 imm:$Disc, tcGPR64:$AddrDisc)>;
+                                 imm:$Disc, tcGPRnotx16x17:$AddrDisc)>;
 }
 
 // v9.5-A pointer authentication extensions
diff --git a/llvm/lib/Target/AArch64/AArch64PointerAuth.h b/llvm/lib/Target/AArch64/AArch64PointerAuth.h
index 1e1d82fd50c7e0..d6947f0d22aacb 100644
--- a/llvm/lib/Target/AArch64/AArch64PointerAuth.h
+++ b/llvm/lib/Target/AArch64/AArch64PointerAuth.h
@@ -86,13 +86,15 @@ enum class AuthCheckMethod {
 };
 
 #define AUTH_CHECK_METHOD_CL_VALUES_COMMON                                     \
-      clEnumValN(AArch64PAuth::AuthCheckMethod::None, "none",                  \
-                 "Do not check authenticated address"),                        \
+  clEnumValN(AArch64PAuth::AuthCheckMethod::None, "none",                      \
+             "Do not check authenticated address"),                            \
       clEnumValN(AArch64PAuth::AuthCheckMethod::DummyLoad, "load",             \
                  "Perform dummy load from authenticated address"),             \
-      clEnumValN(AArch64PAuth::AuthCheckMethod::HighBitsNoTBI,                 \
-                 "high-bits-notbi",                                            \
-                 "Compare bits 62 and 61 of address (TBI should be disabled)")
+      clEnumValN(                                                              \
+          AArch64PAuth::AuthCheckMethod::HighBitsNoTBI, "high-bits-notbi",     \
+          "Compare bits 62 and 61 of address (TBI should be disabled)"),       \
+      clEnumValN(AArch64PAuth::AuthCheckMethod::XPAC, "xpac",                  \
+                 "Compare with the result of XPAC (requires Armv8.3-a)")
 
 #define AUTH_CHECK_METHOD_CL_VALUES_LR                                         \
       AUTH_CHECK_METHOD_CL_VALUES_COMMON,                                      \
diff --git a/llvm/test/CodeGen/AArch64/sign-return-address-tailcall.ll b/llvm/test/CodeGen/AArch64/sign-return-address-tailcall.ll
index 3e5c5c1695b899..032d3cc05961f9 100644
--- a/llvm/test/CodeGen/AArch64/sign-return-address-tailcall.ll
+++ b/llvm/test/CodeGen/AArch64/sign-return-address-tailcall.ll
@@ -3,6 +3,7 @@
 ; RUN: llc -mtriple=aarch64 -asm-verbose=0 -aarch64-authenticated-lr-check-method=high-bits-notbi        < %s | FileCheck -DAUTIASP="hint #29" --check-prefixes=COMMON,BITS-NOTBI,BRK %s
 ; RUN: llc -mtriple=aarch64 -asm-verbose=0 -aarch64-authenticated-lr-check-method=xpac-hint              < %s | FileCheck -DAUTIASP="hint #29" -DXPACLRI="hint #7" --check-prefixes=COMMON,XPAC,BRK %s
 ; RUN: llc -mtriple=aarch64 -asm-verbose=0 -aarch64-authenticated-lr-check-method=xpac-hint -mattr=v8.3a < %s | FileCheck -DAUTIASP="autiasp"  -DXPACLRI="xpaclri" --check-prefixes=COMMON,XPAC,BRK %s
+; RUN: llc -mtriple=aarch64 -asm-verbose=0 -aarch64-authenticated-lr-check-method=xpac      -mattr=v8.3a < %s | FileCheck -DAUTIASP="autiasp"  --check-prefixes=COMMON,XPAC83,BRK %s
 
 define i32 @tailcall_direct() "sign-return-address"="non-leaf" {
 ; COMMON-LABEL: tailcall_direct:
@@ -21,6 +22,11 @@ define i32 @tailcall_direct() "sign-return-address"="non-leaf" {
 ; XPAC-NEXT:      cmp x30, x16
 ; XPAC-NEXT:      b.eq .[[GOOD:Lauth_success[_0-9]+]]
 ;
+; XPAC83-NEXT:    mov x16, x30
+; XPAC83-NEXT:    xpaci x16
+; XPAC83-NEXT:    cmp x30, x16
+; XPAC83-NEXT:    b.eq .[[GOOD:Lauth_success[_0-9]+]]
+;
 ; BRK-NEXT:       brk #0xc470
 ; BRK-NEXT:     .[[GOOD]]:
 ; COMMON-NEXT:    b callee
@@ -46,6 +52,11 @@ define i32 @tailcall_indirect(ptr %fptr) "sign-return-address"="non-leaf" {
 ; XPAC-NEXT:      cmp x30, x16
 ; XPAC-NEXT:      b.eq .[[GOOD:Lauth_success[_0-9]+]]
 ;
+; XPAC83-NEXT:    mov x16, x30
+; XPAC83-NEXT:    xpaci x16
+; XPAC83-NEXT:    cmp x30, x16
+; XPAC83-NEXT:    b.eq .[[GOOD:Lauth_success[_0-9]+]]
+;
 ; BRK-NEXT:       brk #0xc470
 ; BRK-NEXT:     .[[GOOD]]:
 ; COMMON-NEXT:    br x0
@@ -87,6 +98,11 @@ define i32 @tailcall_direct_noframe_sign_all() "sign-return-address"="all" {
 ; XPAC-NEXT:      cmp x30, x16
 ; XPAC-NEXT:      b.eq .[[GOOD:Lauth_success[_0-9]+]]
 ;
+; XPAC83-NEXT:    mov x16, x30
+; XPAC83-NEXT:    xpaci x16
+; XPAC83-NEXT:    cmp x30, x16
+; XPAC83-NEXT:    b.eq .[[GOOD:Lauth_success[_0-9]+]]
+;
 ; BRK-NEXT:       brk #0xc470
 ; BRK-NEXT:     .[[GOOD]]:
 ; COMMON-NEXT:    b callee
@@ -111,6 +127,11 @@ define i32 @tailcall_indirect_noframe_sign_all(ptr %fptr) "sign-return-address"=
 ; XPAC-NEXT:      cmp x30, x16
 ; XPAC-NEXT:      b.eq .[[GOOD:Lauth_success[_0-9]+]]
 ;
+; XPAC83-NEXT:    mov x16, x30
+; XPAC83-NEXT:    xpaci x16
+; XPAC83-NEXT:    cmp x30, x16
+; XPAC83-NEXT:    b.eq .[[GOOD:Lauth_success[_0-9]+]]
+;
 ; BRK-NEXT:       brk #0xc470
 ; BRK-NEXT:     .[[GOOD]]:
 ; COMMON-NEXT:    br x0
@@ -148,6 +169,11 @@ define i32 @tailcall_two_branches(i1 %0) "sign-return-address"="all" {
 ; XPAC-NEXT:         cmp x30, x16
 ; XPAC-NEXT:         b.eq .[[GOOD:Lauth_success[_0-9]+]]
 ;
+; XPAC83-NEXT:       mov x16, x30
+; XPAC83-NEXT:       xpaci x16
+; XPAC83-NEXT:       cmp x30, x16
+; XPAC83-NEXT:       b.eq .[[GOOD:Lauth_success[_0-9]+]]
+;
 ; BRK-NEXT:          brk #0xc470
 ; BRK-NEXT:        .[[GOOD]]:
 ; COMMON-NEXT:       b callee



More information about the llvm-branch-commits mailing list