[llvm-branch-commits] [llvm] [AArch64][PAC] Lower authenticated calls with ptrauth bundles. (PR #85736)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Mar 18 22:30:54 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-llvm-globalisel

Author: Ahmed Bougacha (ahmedbougacha)

<details>
<summary>Changes</summary>

This adds codegen support for the "ptrauth" operand bundles, which can
be used to augment indirect calls with the equivalent of an
`@llvm.ptrauth.auth` intrinsic call on the call target (possibly
preceded by an `@llvm.ptrauth.blend` on the auth discriminator, if
applicable).

This allows the generation of combined authenticating calls
on AArch64 (in the BLRA* PAuth instructions), while preventing
the raw just-authenticated function pointer from being
exposed to attackers.

This is done by threading a PtrAuthInfo descriptor through
the call lowering infrastructure.

Note that this also applies to the other forms of indirect calls,
notably invokes, rvmarker, and tail calls. Tail calls in particular
bring some additional complexity, due to the intersecting register
constraints of BTI and PAC discriminator computation.

This also adopts an x8+ allocation order for GPR64noip, matching GPR64.

---

Patch is 72.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85736.diff


21 Files Affected:

- (modified) llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h (+7) 
- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+18) 
- (modified) llvm/lib/CodeGen/GlobalISel/CallLowering.cpp (+2) 
- (modified) llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp (+15-1) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+46-5) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h (+5-1) 
- (modified) llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp (+90) 
- (modified) llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp (+39-4) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+79-24) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+12) 
- (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+2) 
- (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+80) 
- (modified) llvm/lib/Target/AArch64/AArch64RegisterInfo.td (+4-1) 
- (modified) llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp (+68-20) 
- (modified) llvm/lib/Target/AArch64/GISel/AArch64GlobalISelUtils.cpp (+1-1) 
- (added) llvm/test/CodeGen/AArch64/GlobalISel/ptrauth-invoke.ll (+183) 
- (modified) llvm/test/CodeGen/AArch64/branch-target-enforcement-indirect-calls.ll (+1-3) 
- (added) llvm/test/CodeGen/AArch64/ptrauth-bti-call.ll (+105) 
- (added) llvm/test/CodeGen/AArch64/ptrauth-call-rv-marker.ll (+154) 
- (added) llvm/test/CodeGen/AArch64/ptrauth-call.ll (+195) 
- (added) llvm/test/CodeGen/AArch64/ptrauth-invoke.ll (+189) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h b/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h
index 4c187a3068d823..e02c7435a1b6d2 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h
@@ -46,6 +46,10 @@ class CallLowering {
 
   virtual void anchor();
 public:
+  struct PointerAuthInfo {
+    Register Discriminator;
+    uint64_t Key;
+  };
   struct BaseArgInfo {
     Type *Ty;
     SmallVector<ISD::ArgFlagsTy, 4> Flags;
@@ -125,6 +129,8 @@ class CallLowering {
 
     MDNode *KnownCallees = nullptr;
 
+    std::optional<PointerAuthInfo> PAI;
+
     /// True if the call must be tail call optimized.
     bool IsMustTailCall = false;
 
@@ -587,6 +593,7 @@ class CallLowering {
   bool lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &Call,
                  ArrayRef<Register> ResRegs,
                  ArrayRef<ArrayRef<Register>> ArgRegs, Register SwiftErrorVReg,
+                 std::optional<PointerAuthInfo> PAI,
                  Register ConvergenceCtrlToken,
                  std::function<unsigned()> GetCalleeReg) const;
 
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 2f164a460db843..e47a4e12d33cbc 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4290,6 +4290,9 @@ class TargetLowering : public TargetLoweringBase {
   /// Return true if the target supports kcfi operand bundles.
   virtual bool supportKCFIBundles() const { return false; }
 
+  /// Return true if the target supports ptrauth operand bundles.
+  virtual bool supportPtrAuthBundles() const { return false; }
+
   /// Perform necessary initialization to handle a subset of CSRs explicitly
   /// via copies. This function is called at the beginning of instruction
   /// selection.
@@ -4401,6 +4404,14 @@ class TargetLowering : public TargetLoweringBase {
     llvm_unreachable("Not Implemented");
   }
 
+  /// This structure contains the information necessary for lowering
+  /// pointer-authenticating indirect calls.  It is equivalent to the "ptrauth"
+  /// operand bundle found on the call instruction, if any.
+  struct PtrAuthInfo {
+    uint64_t Key;
+    SDValue Discriminator;
+  };
+
   /// This structure contains all information that is necessary for lowering
   /// calls. It is passed to TLI::LowerCallTo when the SelectionDAG builder
   /// needs to lower a call, and targets will see this struct in their LowerCall
@@ -4440,6 +4451,8 @@ class TargetLowering : public TargetLoweringBase {
     const ConstantInt *CFIType = nullptr;
     SDValue ConvergenceControlToken;
 
+    std::optional<PtrAuthInfo> PAI;
+
     CallLoweringInfo(SelectionDAG &DAG)
         : RetSExt(false), RetZExt(false), IsVarArg(false), IsInReg(false),
           DoesNotReturn(false), IsReturnValueUsed(true), IsConvergent(false),
@@ -4562,6 +4575,11 @@ class TargetLowering : public TargetLoweringBase {
       return *this;
     }
 
+    CallLoweringInfo &setPtrAuth(PtrAuthInfo Value) {
+      PAI = Value;
+      return *this;
+    }
+
     CallLoweringInfo &setIsPostTypeLegalization(bool Value=true) {
       IsPostTypeLegalization = Value;
       return *this;
diff --git a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
index 363fad53b76c35..740a00d8afdd49 100644
--- a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
@@ -92,6 +92,7 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
                              ArrayRef<Register> ResRegs,
                              ArrayRef<ArrayRef<Register>> ArgRegs,
                              Register SwiftErrorVReg,
+                             std::optional<PointerAuthInfo> PAI,
                              Register ConvergenceCtrlToken,
                              std::function<unsigned()> GetCalleeReg) const {
   CallLoweringInfo Info;
@@ -188,6 +189,7 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
   Info.KnownCallees = CB.getMetadata(LLVMContext::MD_callees);
   Info.CallConv = CallConv;
   Info.SwiftErrorVReg = SwiftErrorVReg;
+  Info.PAI = PAI;
   Info.ConvergenceCtrlToken = ConvergenceCtrlToken;
   Info.IsMustTailCall = CB.isMustTailCall();
   Info.IsTailCall = CanBeTailCalled;
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 47ee2ee507137e..0be2597bf7c54b 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2615,6 +2615,20 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
     }
   }
 
+  std::optional<CallLowering::PointerAuthInfo> PAI;
+  if (CB.countOperandBundlesOfType(LLVMContext::OB_ptrauth)) {
+    // Functions should never be ptrauth-called directly.
+    assert(!CB.getCalledFunction() && "invalid direct ptrauth call");
+
+    auto PAB = CB.getOperandBundle("ptrauth");
+    Value *Key = PAB->Inputs[0];
+    Value *Discriminator = PAB->Inputs[1];
+
+    Register DiscReg = getOrCreateVReg(*Discriminator);
+    PAI = CallLowering::PointerAuthInfo{DiscReg,
+                                        cast<ConstantInt>(Key)->getZExtValue()};
+  }
+
   Register ConvergenceCtrlToken = 0;
   if (auto Bundle = CB.getOperandBundle(LLVMContext::OB_convergencectrl)) {
     const auto &Token = *Bundle->Inputs[0].get();
@@ -2625,7 +2639,7 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
   // optimize into tail calls. Instead, we defer that to selection where a final
   // scan is done to check if any instructions are calls.
   bool Success = CLI->lowerCall(
-      MIRBuilder, CB, Res, Args, SwiftErrorVReg, ConvergenceCtrlToken,
+      MIRBuilder, CB, Res, Args, SwiftErrorVReg, PAI, ConvergenceCtrlToken,
       [&]() { return getOrCreateVReg(*CB.getCalledOperand()); });
 
   // Check if we just inserted a tail call.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index f1923a64368f4f..b60be1bb77212e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -3300,12 +3300,12 @@ void SelectionDAGBuilder::visitInvoke(const InvokeInst &I) {
   const BasicBlock *EHPadBB = I.getSuccessor(1);
   MachineBasicBlock *EHPadMBB = FuncInfo.MBBMap[EHPadBB];
 
-  // Deopt bundles are lowered in LowerCallSiteWithDeoptBundle, and we don't
+  // Deopt and ptrauth bundles are lowered in helper functions, and we don't
   // have to do anything here to lower funclet bundles.
   assert(!I.hasOperandBundlesOtherThan(
              {LLVMContext::OB_deopt, LLVMContext::OB_gc_transition,
               LLVMContext::OB_gc_live, LLVMContext::OB_funclet,
-              LLVMContext::OB_cfguardtarget,
+              LLVMContext::OB_cfguardtarget, LLVMContext::OB_ptrauth,
               LLVMContext::OB_clang_arc_attachedcall}) &&
          "Cannot lower invokes with arbitrary operand bundles yet!");
 
@@ -3356,6 +3356,8 @@ void SelectionDAGBuilder::visitInvoke(const InvokeInst &I) {
     // intrinsic, and right now there are no plans to support other intrinsics
     // with deopt state.
     LowerCallSiteWithDeoptBundle(&I, getValue(Callee), EHPadBB);
+  } else if (I.countOperandBundlesOfType(LLVMContext::OB_ptrauth)) {
+    LowerCallSiteWithPtrAuthBundle(cast<CallBase>(I), EHPadBB);
   } else {
     LowerCallTo(I, getValue(Callee), false, false, EHPadBB);
   }
@@ -8508,9 +8510,9 @@ SelectionDAGBuilder::lowerInvokable(TargetLowering::CallLoweringInfo &CLI,
 }
 
 void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
-                                      bool isTailCall,
-                                      bool isMustTailCall,
-                                      const BasicBlock *EHPadBB) {
+                                      bool isTailCall, bool isMustTailCall,
+                                      const BasicBlock *EHPadBB,
+                                      const TargetLowering::PtrAuthInfo *PAI) {
   auto &DL = DAG.getDataLayout();
   FunctionType *FTy = CB.getFunctionType();
   Type *RetTy = CB.getType();
@@ -8619,6 +8621,15 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
           CB.countOperandBundlesOfType(LLVMContext::OB_preallocated) != 0)
       .setCFIType(CFIType)
       .setConvergenceControlToken(ConvControlToken);
+
+  // Set the pointer authentication info if we have it.
+  if (PAI) {
+    if (!TLI.supportPtrAuthBundles())
+      report_fatal_error(
+          "This target doesn't support calls with ptrauth operand bundles.");
+    CLI.setPtrAuth(*PAI);
+  }
+
   std::pair<SDValue, SDValue> Result = lowerInvokable(CLI, EHPadBB);
 
   if (Result.first.getNode()) {
@@ -9164,6 +9175,11 @@ void SelectionDAGBuilder::visitCall(const CallInst &I) {
     }
   }
 
+  if (I.countOperandBundlesOfType(LLVMContext::OB_ptrauth)) {
+    LowerCallSiteWithPtrAuthBundle(cast<CallBase>(I), /*EHPadBB=*/nullptr);
+    return;
+  }
+
   // Deopt bundles are lowered in LowerCallSiteWithDeoptBundle, and we don't
   // have to do anything here to lower funclet bundles.
   // CFGuardTarget bundles are lowered in LowerCallTo.
@@ -9185,6 +9201,31 @@ void SelectionDAGBuilder::visitCall(const CallInst &I) {
     LowerCallTo(I, Callee, I.isTailCall(), I.isMustTailCall());
 }
 
+void SelectionDAGBuilder::LowerCallSiteWithPtrAuthBundle(
+    const CallBase &CB, const BasicBlock *EHPadBB) {
+  auto PAB = CB.getOperandBundle("ptrauth");
+  auto *CalleeV = CB.getCalledOperand();
+
+  // Gather the call ptrauth data from the operand bundle:
+  //   [ i32 <key>, i64 <discriminator> ]
+  auto *Key = cast<ConstantInt>(PAB->Inputs[0]);
+  Value *Discriminator = PAB->Inputs[1];
+
+  assert(Key->getType()->isIntegerTy(32) && "Invalid ptrauth key");
+  assert(Discriminator->getType()->isIntegerTy(64) &&
+         "Invalid ptrauth discriminator");
+
+  // Functions should never be ptrauth-called directly.
+  assert(!isa<Function>(CalleeV) && "invalid direct ptrauth call");
+
+  // Otherwise, do an authenticated indirect call.
+  TargetLowering::PtrAuthInfo PAI = {Key->getZExtValue(),
+                                     getValue(Discriminator)};
+
+  LowerCallTo(CB, getValue(CalleeV), CB.isTailCall(), CB.isMustTailCall(),
+              EHPadBB, &PAI);
+}
+
 namespace {
 
 /// AsmOperandInfo - This contains information for each constraint that we are
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
index dcf46e0563ff9d..6a3d0276cfd456 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
@@ -406,7 +406,8 @@ class SelectionDAGBuilder {
   void CopyToExportRegsIfNeeded(const Value *V);
   void ExportFromCurrentBlock(const Value *V);
   void LowerCallTo(const CallBase &CB, SDValue Callee, bool IsTailCall,
-                   bool IsMustTailCall, const BasicBlock *EHPadBB = nullptr);
+                   bool IsMustTailCall, const BasicBlock *EHPadBB = nullptr,
+                   const TargetLowering::PtrAuthInfo *PAI = nullptr);
 
   // Lower range metadata from 0 to N to assert zext to an integer of nearest
   // floor power of two.
@@ -490,6 +491,9 @@ class SelectionDAGBuilder {
                                         bool VarArgDisallowed,
                                         bool ForceVoidReturnTy);
 
+  void LowerCallSiteWithPtrAuthBundle(const CallBase &CB,
+                                      const BasicBlock *EHPadBB);
+
   /// Returns the type of FrameIndex and TargetFrameIndex nodes.
   MVT getFrameIndexTy() {
     return DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout());
diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index 6d34e16fc43401..e4c9503b42e8ef 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -133,6 +133,8 @@ class AArch64AsmPrinter : public AsmPrinter {
 
   void emitSled(const MachineInstr &MI, SledKind Kind);
 
+  // Emit the sequence for BLRA (authenticate + branch).
+  void emitPtrauthBranch(const MachineInstr *MI);
   // Emit the sequence for AUT or AUTPAC.
   void emitPtrauthAuthResign(const MachineInstr *MI);
   // Emit the sequence to compute a discriminator into x17, or reuse AddrDisc.
@@ -1482,6 +1484,10 @@ void AArch64AsmPrinter::emitFMov0(const MachineInstr &MI) {
 unsigned AArch64AsmPrinter::emitPtrauthDiscriminator(uint16_t Disc,
                                                      unsigned AddrDisc,
                                                      unsigned &InstsEmitted) {
+  // So far we've used NoRegister in pseudos.  Now we need real encodings.
+  if (AddrDisc == AArch64::NoRegister)
+    AddrDisc = AArch64::XZR;
+
   // If there is no constant discriminator, there's no blend involved:
   // just use the address discriminator register as-is (XZR or not).
   if (!Disc)
@@ -1729,6 +1735,39 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
     OutStreamer->emitLabel(EndSym);
 }
 
+void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
+  unsigned InstsEmitted = 0;
+
+  unsigned BrTarget = MI->getOperand(0).getReg();
+  auto Key = (AArch64PACKey::ID)MI->getOperand(1).getImm();
+  uint64_t Disc = MI->getOperand(2).getImm();
+  unsigned AddrDisc = MI->getOperand(3).getReg();
+
+  // Compute discriminator into x17
+  assert(isUInt<16>(Disc));
+  unsigned DiscReg = emitPtrauthDiscriminator(Disc, AddrDisc, InstsEmitted);
+  bool IsZeroDisc = DiscReg == AArch64::XZR;
+
+  assert((Key == AArch64PACKey::IA || Key == AArch64PACKey::IB) &&
+         "Invalid auth call key");
+
+  unsigned Opc;
+  if (Key == AArch64PACKey::IA)
+    Opc = IsZeroDisc ? AArch64::BLRAAZ : AArch64::BLRAA;
+  else
+    Opc = IsZeroDisc ? AArch64::BLRABZ : AArch64::BLRAB;
+
+  MCInst BRInst;
+  BRInst.setOpcode(Opc);
+  BRInst.addOperand(MCOperand::createReg(BrTarget));
+  if (!IsZeroDisc)
+    BRInst.addOperand(MCOperand::createReg(DiscReg));
+  EmitToStreamer(*OutStreamer, BRInst);
+  ++InstsEmitted;
+
+  assert(STI->getInstrInfo()->getInstSizeInBytes(*MI) >= InstsEmitted * 4);
+}
+
 // Simple pseudo-instructions have their lowering (with expansion to real
 // instructions) auto-generated.
 #include "AArch64GenMCPseudoLowering.inc"
@@ -1869,9 +1908,60 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
     emitPtrauthAuthResign(MI);
     return;
 
+  case AArch64::BLRA:
+    emitPtrauthBranch(MI);
+    return;
+
   // Tail calls use pseudo instructions so they have the proper code-gen
   // attributes (isCall, isReturn, etc.). We lower them to the real
   // instruction here.
+  case AArch64::AUTH_TCRETURN:
+  case AArch64::AUTH_TCRETURN_BTI: {
+    const uint64_t Key = MI->getOperand(2).getImm();
+    assert(Key < 2 && "Unknown key kind for authenticating tail-call return");
+    const uint64_t Disc = MI->getOperand(3).getImm();
+    Register AddrDisc = MI->getOperand(4).getReg();
+
+    Register ScratchReg = MI->getOperand(0).getReg() == AArch64::X16
+                              ? AArch64::X17
+                              : AArch64::X16;
+
+    unsigned DiscReg = AddrDisc;
+    if (Disc) {
+      assert(isUInt<16>(Disc) && "Integer discriminator is too wide");
+
+      if (AddrDisc != AArch64::NoRegister) {
+        EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ORRXrs)
+                                         .addReg(ScratchReg)
+                                         .addReg(AArch64::XZR)
+                                         .addReg(AddrDisc)
+                                         .addImm(0));
+        EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVKXi)
+                                         .addReg(ScratchReg)
+                                         .addReg(ScratchReg)
+                                         .addImm(Disc)
+                                         .addImm(/*shift=*/48));
+      } else {
+        EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVZXi)
+                                         .addReg(ScratchReg)
+                                         .addImm(Disc)
+                                         .addImm(/*shift=*/0));
+      }
+      DiscReg = ScratchReg;
+    }
+
+    const bool isZero = DiscReg == AArch64::NoRegister;
+    const unsigned Opcodes[2][2] = {{AArch64::BRAA, AArch64::BRAAZ},
+                                    {AArch64::BRAB, AArch64::BRABZ}};
+
+    MCInst TmpInst;
+    TmpInst.setOpcode(Opcodes[Key][isZero]);
+    TmpInst.addOperand(MCOperand::createReg(MI->getOperand(0).getReg()));
+    if (!isZero)
+      TmpInst.addOperand(MCOperand::createReg(DiscReg));
+    EmitToStreamer(*OutStreamer, TmpInst);
+    return;
+  }
   case AArch64::TCRETURNri:
   case AArch64::TCRETURNrix16x17:
   case AArch64::TCRETURNrix17:
diff --git a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
index 03f0778bae59d5..657324d2307c58 100644
--- a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
@@ -817,10 +817,44 @@ bool AArch64ExpandPseudo::expandCALL_RVMARKER(
   MachineInstr &MI = *MBBI;
   MachineOperand &RVTarget = MI.getOperand(0);
   assert(RVTarget.isGlobal() && "invalid operand for attached call");
-  MachineInstr *OriginalCall =
-      createCall(MBB, MBBI, TII, MI.getOperand(1),
-                 // Regmask starts after the RV and call targets.
-                 /*RegMaskStartIdx=*/2);
+
+  MachineInstr *OriginalCall = nullptr;
+
+  if (MI.getOpcode() == AArch64::BLRA_RVMARKER) {
+    // Pointer auth call.
+    MachineOperand &Key = MI.getOperand(2);
+    assert((Key.getImm() == 0 || Key.getImm() == 1) &&
+           "invalid key for ptrauth call");
+    MachineOperand &IntDisc = MI.getOperand(3);
+    MachineOperand &AddrDisc = MI.getOperand(4);
+
+    OriginalCall = BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(AArch64::BLRA))
+                       .getInstr();
+    OriginalCall->addOperand(MI.getOperand(1));
+    OriginalCall->addOperand(Key);
+    OriginalCall->addOperand(IntDisc);
+    OriginalCall->addOperand(AddrDisc);
+
+    unsigned RegMaskStartIdx = 5;
+    // Skip register arguments. Those are added during ISel, but are not
+    // needed for the concrete branch.
+    while (!MI.getOperand(RegMaskStartIdx).isRegMask()) {
+      auto MOP = MI.getOperand(RegMaskStartIdx);
+      assert(MOP.isReg() && "can only add register operands");
+      OriginalCall->addOperand(MachineOperand::CreateReg(
+          MOP.getReg(), /*Def=*/false, /*Implicit=*/true, /*isKill=*/false,
+          /*isDead=*/false, /*isUndef=*/MOP.isUndef()));
+      RegMaskStartIdx++;
+    }
+    for (const MachineOperand &MO :
+         llvm::drop_begin(MI.operands(), RegMaskStartIdx))
+      OriginalCall->addOperand(MO);
+  } else {
+    assert(MI.getOpcode() == AArch64::BLR_RVMARKER && "unknown rvmarker MI");
+    OriginalCall = createCall(MBB, MBBI, TII, MI.getOperand(1),
+                              // Regmask starts after the RV and call targets.
+                              /*RegMaskStartIdx=*/2);
+  }
 
   BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(AArch64::ORRXrs))
                      .addReg(AArch64::FP, RegState::Define)
@@ -1529,6 +1563,7 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
    case AArch64::LDR_PPXI:
      return expandSVESpillFill(MBB, MBBI, AArch64::LDR_PXI, 2);
    case AArch64::BLR_RVMARKER:
+   case AArch64::BLRA_RVMARKER:
      return expandCALL_RVMARKER(MBB, MBBI);
    case AArch64::BLR_BTI:
      return expandCALL_BTI(MBB, MBBI);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 92f59dbf1e9e52..cb3dd5cf0b22f9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -329,6 +329,40 @@ static bool isZeroingInactiveLanes(SDValue Op) {
   }
 }
 
+static std::tuple<SDValue, SDValue>
+extractPtrauthBlendDiscriminators(SDValue Disc, SelectionDAG *DAG) {
+  SDLoc DL(Disc);
+  SDValue AddrDisc;
+  SDValue ConstDisc;
+
+  // If this is a blend, remember the cons...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/85736


More information about the llvm-branch-commits mailing list