[llvm] [AArch64][PAC] Select auth+load into LDRAA/LDRAB/LDRA[pre]. (PR #123769)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 21 07:46:19 PST 2025


llvmbot wrote:



@llvm/pr-subscribers-backend-aarch64

Author: Ahmed Bougacha (ahmedbougacha)

<details>
<summary>Changes</summary>

This lowers loads of a ptrauth.auth base into a fixed instruction sequence that doesn't allow the raw intermediate (authenticated) value to be exposed.

It's based on the AArch64 LDRAA/LDRAB instructions, but since those have limited encodings (in particular, small immediate offsets and only zero discriminators), it generalizes them with an LDRA pseudo.
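
For illustration, here's a minimal hypothetical input for the simplest case (not taken from the patch's test file; assuming an aarch64 target with pauth): a zero-discriminator, DA-key authentication feeding a plain load, which should be selectable directly as LDRAA:

```llvm
; Hypothetical example: key 2 is DA, discriminator is 0.
define i64 @auth_load(i64 %signed.ptr) {
  %authed = call i64 @llvm.ptrauth.auth(i64 %signed.ptr, i32 2, i64 0)
  %addr = inttoptr i64 %authed to ptr
  %val = load i64, ptr %addr
  ret i64 %val
}
declare i64 @llvm.ptrauth.auth(i64, i32, i64)
; Expected to select to roughly:  ldraa x0, [x0]
```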

It handles arbitrary ptrauth schemas on the authentication, materializing the integer constant discriminator and blending it with an address discriminator if needed.

It handles arbitrary offsets (applied after the authentication).

It also handles pre-indexing with writeback, writing back either the authentication result alone when the offset is 0, or the result of both the authentication and the offset addition otherwise.
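
A sketch of the pre-indexed case (hypothetical example; whether the DAG actually forms a pre-indexed load depends on the usual combiner heuristics): the offsetted pointer is reused, so both the authentication and the offset addition get written back:

```llvm
define i64 @auth_load_preinc(i64 %signed.ptr, ptr %out) {
  %authed = call i64 @llvm.ptrauth.auth(i64 %signed.ptr, i32 2, i64 0)
  %base = inttoptr i64 %authed to ptr
  %elt = getelementptr i8, ptr %base, i64 8
  %val = load i64, ptr %elt
  store ptr %elt, ptr %out   ; reuse of base+8 is what makes pre-indexing interesting
  ret i64 %val
}
declare i64 @llvm.ptrauth.auth(i64, i32, i64)
; Could select to roughly:
;   ldraa x8, [x0, #8]!   ; authenticate, add 8, load; x0 now holds the offsetted base
;   str   x0, [x1]
;   mov   x0, x8
```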

At ISel time, the real LDRAA family of instructions is selected when possible, to avoid needlessly constraining regalloc to X16/X17. After ISel, the LDRA pseudos are expanded in AsmPrinter into one of the following (see the sketch after this list):
- writeback, 0 offset (we already wrote the AUT result): LDRXui
- no wb, uimm12s8 offset (including 0): LDRXui
- no wb, simm9 offset: LDURXi
- pre-indexed wb, simm9 offset: LDRXpre
- no wb, any offset: expanded MOVImm + LDRXroX
- pre-indexed wb, any offset: expanded MOVImm + ADD + LDRXui
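
For instance, a hypothetical case hitting the general path (non-zero blended discriminator plus an offset too large for the immediate forms) might expand along these lines (registers and constants illustrative):

```llvm
define i64 @auth_load_blend_offset(i64 %signed.ptr, i64 %addr.disc) {
  %disc = call i64 @llvm.ptrauth.blend(i64 %addr.disc, i64 1234)
  %authed = call i64 @llvm.ptrauth.auth(i64 %signed.ptr, i32 2, i64 %disc)
  %base = inttoptr i64 %authed to ptr
  %elt = getelementptr i8, ptr %base, i64 32768
  %val = load i64, ptr %elt
  ret i64 %val
}
declare i64 @llvm.ptrauth.auth(i64, i32, i64)
declare i64 @llvm.ptrauth.blend(i64, i64)
; Roughly:
;   mov   x17, x1               ; address discriminator
;   movk  x17, #1234, lsl #48   ; blend in the constant discriminator
;   autda x16, x17              ; authenticate (the signed base was copied into x16)
;   mov   x17, #32768           ; offset doesn't fit the LDR/LDUR immediate forms
;   ldr   x0, [x16, x17]        ; no-wb general case: expanded MOVImm + LDRXroX
```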

The main intended optimization target is vtable-like codegen, where both the base vtable pointer and its entries (at small fixed offsets) are signed. That pattern does benefit from writeback, hence the ISel complexity to support it; writeback is otherwise unlikely to be worthwhile.
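
As a sketch of that (hypothetical, and ignoring the exact signing schema real vtable ABIs use), a two-slot read shows how writeback lets a single LDRAA both authenticate the base and keep it live for the second load:

```llvm
define i64 @two_vtable_slots(i64 %signed.vtable) {
  %vtable.int = call i64 @llvm.ptrauth.auth(i64 %signed.vtable, i32 2, i64 0)
  %vtable = inttoptr i64 %vtable.int to ptr
  %slot0 = load i64, ptr %vtable
  %slot1.p = getelementptr i8, ptr %vtable, i64 8
  %slot1 = load i64, ptr %slot1.p
  %sum = add i64 %slot0, %slot1
  ret i64 %sum
}
declare i64 @llvm.ptrauth.auth(i64, i32, i64)
; Roughly, with the authentication folded into the first load and written back:
;   ldraa x8, [x0, #0]!
;   ldr   x9, [x0, #8]
;   add   x0, x8, x9
```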

GlobalISel would benefit from further optimization, as this lowering conflicts with the generic indexed lowering there.

I did a pass to refresh the old patch, but it's been a while, please let me know if I missed a spot!

---

Patch is 47.55 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123769.diff


8 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp (+114) 
- (modified) llvm/lib/Target/AArch64/AArch64Combine.td (+9) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp (+163-2) 
- (modified) llvm/lib/Target/AArch64/AArch64InstrGISel.td (+17) 
- (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+36) 
- (modified) llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp (+71) 
- (modified) llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp (+115) 
- (added) llvm/test/CodeGen/AArch64/ptrauth-load.ll (+716) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index 27e65d60122fd7..b3876ff4862e1d 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "AArch64.h"
+#include "AArch64ExpandImm.h"
 #include "AArch64MCInstLower.h"
 #include "AArch64MachineFunctionInfo.h"
 #include "AArch64RegisterInfo.h"
@@ -204,6 +205,9 @@ class AArch64AsmPrinter : public AsmPrinter {
   // authenticating)
   void LowerLOADgotAUTH(const MachineInstr &MI);
 
+  // Emit the sequence for LDRA (auth + load from authenticated base).
+  void LowerPtrauthAuthLoad(const MachineInstr &MI);
+
   /// tblgen'erated driver function for lowering simple MI->MC
   /// pseudo instructions.
   bool lowerPseudoInstExpansion(const MachineInstr *MI, MCInst &Inst);
@@ -2159,6 +2163,111 @@ void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
   EmitToStreamer(*OutStreamer, BRInst);
 }
 
+void AArch64AsmPrinter::LowerPtrauthAuthLoad(const MachineInstr &MI) {
+  const bool IsPreWB = MI.getOpcode() == AArch64::LDRApre;
+
+  const unsigned DstReg = MI.getOperand(0).getReg();
+  const int64_t Offset = MI.getOperand(1).getImm();
+  const auto Key = (AArch64PACKey::ID)MI.getOperand(2).getImm();
+  const uint64_t Disc = MI.getOperand(3).getImm();
+  const unsigned AddrDisc = MI.getOperand(4).getReg();
+
+  Register DiscReg = emitPtrauthDiscriminator(Disc, AddrDisc, AArch64::X17);
+
+  unsigned AUTOpc = getAUTOpcodeForKey(Key, DiscReg == AArch64::XZR);
+  auto MIB = MCInstBuilder(AUTOpc).addReg(AArch64::X16).addReg(AArch64::X16);
+  if (DiscReg != AArch64::XZR)
+    MIB.addReg(DiscReg);
+
+  EmitToStreamer(MIB);
+
+  // We have a few options for offset folding:
+  // - writeback, 0 offset (we already wrote the AUT result): LDRXui
+  // - no wb, uimm12s8 offset (including 0): LDRXui
+  if (!Offset || (!IsPreWB && isShiftedUInt<12, 3>(Offset))) {
+    EmitToStreamer(MCInstBuilder(AArch64::LDRXui)
+                       .addReg(DstReg)
+                       .addReg(AArch64::X16)
+                       .addImm(Offset / 8));
+    return;
+  }
+
+  // - no wb, simm9 offset: LDURXi
+  if (!IsPreWB && isInt<9>(Offset)) {
+    EmitToStreamer(MCInstBuilder(AArch64::LDURXi)
+                       .addReg(DstReg)
+                       .addReg(AArch64::X16)
+                       .addImm(Offset));
+    return;
+  }
+
+  // - pre-indexed wb, simm9 offset: LDRXpre
+  if (IsPreWB && isInt<9>(Offset)) {
+    EmitToStreamer(MCInstBuilder(AArch64::LDRXpre)
+                       .addReg(AArch64::X16)
+                       .addReg(DstReg)
+                       .addReg(AArch64::X16)
+                       .addImm(Offset));
+    return;
+  }
+
+  // Finally, in the general case, we need a MOVimm either way.
+  SmallVector<AArch64_IMM::ImmInsnModel, 4> ImmInsns;
+  AArch64_IMM::expandMOVImm(Offset, 64, ImmInsns);
+
+  // X17 is dead at this point, use it as the offset register
+  for (auto &ImmI : ImmInsns) {
+    switch (ImmI.Opcode) {
+    default:
+      llvm_unreachable("invalid ldra imm expansion opc!");
+      break;
+
+    case AArch64::ORRXri:
+      EmitToStreamer(MCInstBuilder(ImmI.Opcode)
+                         .addReg(AArch64::X17)
+                         .addReg(AArch64::XZR)
+                         .addImm(ImmI.Op2));
+      break;
+    case AArch64::MOVNXi:
+    case AArch64::MOVZXi:
+      EmitToStreamer(MCInstBuilder(ImmI.Opcode)
+                         .addReg(AArch64::X17)
+                         .addImm(ImmI.Op1)
+                         .addImm(ImmI.Op2));
+      break;
+    case AArch64::MOVKXi:
+      EmitToStreamer(MCInstBuilder(ImmI.Opcode)
+                         .addReg(AArch64::X17)
+                         .addReg(AArch64::X17)
+                         .addImm(ImmI.Op1)
+                         .addImm(ImmI.Op2));
+      break;
+    }
+  }
+
+  // - no wb, any offset: expanded MOVImm + LDRXroX
+  if (!IsPreWB) {
+    EmitToStreamer(MCInstBuilder(AArch64::LDRXroX)
+                       .addReg(DstReg)
+                       .addReg(AArch64::X16)
+                       .addReg(AArch64::X17)
+                       .addImm(0)
+                       .addImm(0));
+    return;
+  }
+
+  // - pre-indexed wb, any offset: expanded MOVImm + ADD + LDRXui
+  EmitToStreamer(MCInstBuilder(AArch64::ADDXrs)
+                     .addReg(AArch64::X16)
+                     .addReg(AArch64::X16)
+                     .addReg(AArch64::X17)
+                     .addImm(0));
+  EmitToStreamer(MCInstBuilder(AArch64::LDRXui)
+                     .addReg(DstReg)
+                     .addReg(AArch64::X16)
+                     .addImm(0));
+}
+
 const MCExpr *
 AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) {
   MCContext &Ctx = OutContext;
@@ -2698,6 +2807,11 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
     LowerLOADgotAUTH(*MI);
     return;
 
+  case AArch64::LDRA:
+  case AArch64::LDRApre:
+    LowerPtrauthAuthLoad(*MI);
+    return;
+
   case AArch64::BRA:
   case AArch64::BLRA:
     emitPtrauthBranch(MI);
diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td
index ce1980697abbbb..f1853f2b8cc5ad 100644
--- a/llvm/lib/Target/AArch64/AArch64Combine.td
+++ b/llvm/lib/Target/AArch64/AArch64Combine.td
@@ -255,6 +255,14 @@ def form_truncstore : GICombineRule<
   (apply [{ applyFormTruncstore(*${root}, MRI, B, Observer, ${matchinfo}); }])
 >;
 
+def form_auth_load_matchdata : GIDefMatchData<"AuthLoadMatchInfo">;
+def form_auth_load : GICombineRule<
+  (defs root:$root, form_auth_load_matchdata:$matchinfo),
+  (match (wip_match_opcode G_LOAD):$root,
+         [{ return matchFormAuthLoad(*${root}, MRI, Helper, ${matchinfo}); }]),
+  (apply [{ applyFormAuthLoad(*${root}, MRI, B, Helper, Observer, ${matchinfo}); }])
+>;
+
 def fold_merge_to_zext : GICombineRule<
   (defs root:$d),
   (match (wip_match_opcode G_MERGE_VALUES):$d,
@@ -315,6 +323,7 @@ def AArch64PostLegalizerLowering
                        [shuffle_vector_lowering, vashr_vlshr_imm,
                         icmp_lowering, build_vector_lowering,
                         lower_vector_fcmp, form_truncstore,
+                        form_auth_load,
                         vector_sext_inreg_to_shift,
                         unmerge_ext_to_unmerge, lower_mull,
                         vector_unmerge_lowering, insertelt_nonconst]> {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 6aa8cd4f0232ac..8660b2d0bc8e6f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -361,6 +361,8 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
 
   bool tryIndexedLoad(SDNode *N);
 
+  bool tryAuthLoad(SDNode *N);
+
   void SelectPtrauthAuth(SDNode *N);
   void SelectPtrauthResign(SDNode *N);
 
@@ -1671,6 +1673,163 @@ bool AArch64DAGToDAGISel::tryIndexedLoad(SDNode *N) {
   return true;
 }
 
+bool AArch64DAGToDAGISel::tryAuthLoad(SDNode *N) {
+  LoadSDNode *LD = cast<LoadSDNode>(N);
+  EVT VT = LD->getMemoryVT();
+  if (VT != MVT::i64)
+    return false;
+
+  assert(LD->getExtensionType() == ISD::NON_EXTLOAD && "invalid 64bit extload");
+
+  ISD::MemIndexedMode AM = LD->getAddressingMode();
+  if (AM != ISD::PRE_INC && AM != ISD::UNINDEXED)
+    return false;
+  bool IsPre = AM == ISD::PRE_INC;
+
+  SDValue Chain = LD->getChain();
+  SDValue Ptr = LD->getBasePtr();
+
+  SDValue Base = Ptr;
+
+  int64_t OffsetVal = 0;
+  if (IsPre) {
+    OffsetVal = cast<ConstantSDNode>(LD->getOffset())->getSExtValue();
+  } else if (CurDAG->isBaseWithConstantOffset(Base)) {
+    // We support both 'base' and 'base + constant offset' modes.
+    ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(Base.getOperand(1));
+    if (!RHS)
+      return false;
+    OffsetVal = RHS->getSExtValue();
+    Base = Base.getOperand(0);
+  }
+
+  // The base must be of the form:
+  //   (int_ptrauth_auth <signedbase>, <key>, <disc>)
+  // with disc being either a constant int, or:
+  //   (int_ptrauth_blend <addrdisc>, <const int disc>)
+  if (Base.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
+    return false;
+
+  unsigned IntID = cast<ConstantSDNode>(Base.getOperand(0))->getZExtValue();
+  if (IntID != Intrinsic::ptrauth_auth)
+    return false;
+
+  unsigned KeyC = cast<ConstantSDNode>(Base.getOperand(2))->getZExtValue();
+  bool IsDKey = KeyC == AArch64PACKey::DA || KeyC == AArch64PACKey::DB;
+  SDValue Disc = Base.getOperand(3);
+
+  Base = Base.getOperand(1);
+
+  bool ZeroDisc = isNullConstant(Disc);
+  SDValue IntDisc, AddrDisc;
+  std::tie(IntDisc, AddrDisc) = extractPtrauthBlendDiscriminators(Disc, CurDAG);
+
+  // If this is an indexed pre-inc load, we obviously need the writeback form.
+  bool needsWriteback = IsPre;
+  // If not, but the base authenticated pointer has any other use, it's
+  // beneficial to use the writeback form, to "writeback" the auth, even if
+  // there is no base+offset addition.
+  if (!Ptr.hasOneUse()) {
+    needsWriteback = true;
+
+    // However, we can only do that if we don't introduce cycles between the
+    // load node and any other user of the pointer computation nodes.  That can
+    // happen if the load node uses any of said other users.
+    // In other words: we can only do this transformation if none of the other
+    // uses of the pointer computation to be folded are predecessors of the load
+    // we're folding into.
+    //
+    // Visited is a cache containing nodes that are known predecessors of N.
+    // Worklist is the set of nodes we're looking for predecessors of.
+    // For the first lookup, that only contains the load node N.  Each call to
+    // hasPredecessorHelper adds any of the potential predecessors of N to the
+    // Worklist.
+    SmallPtrSet<const SDNode *, 32> Visited;
+    SmallVector<const SDNode *, 16> Worklist;
+    Worklist.push_back(N);
+    for (SDNode *U : Ptr.getNode()->users())
+      if (SDNode::hasPredecessorHelper(U, Visited, Worklist, /*Max=*/32,
+                                       /*TopologicalPrune=*/true))
+        return false;
+  }
+
+  // We have 2 main isel alternatives:
+  // - LDRAA/LDRAB, writeback or indexed.  Zero disc, small offsets, D key.
+  // - LDRA/LDRApre.  Pointer needs to be in X16.
+  SDLoc DL(N);
+  MachineSDNode *Res = nullptr;
+  SDValue Writeback, ResVal, OutChain;
+
+  // If the discriminator is zero and the offset fits, we can use LDRAA/LDRAB.
+  // Do that here to avoid needlessly constraining regalloc into using X16.
+  if (ZeroDisc && isShiftedInt<10, 3>(OffsetVal) && IsDKey) {
+    unsigned Opc = 0;
+    switch (KeyC) {
+    case AArch64PACKey::DA:
+      Opc = needsWriteback ? AArch64::LDRAAwriteback : AArch64::LDRAAindexed;
+      break;
+    case AArch64PACKey::DB:
+      Opc = needsWriteback ? AArch64::LDRABwriteback : AArch64::LDRABindexed;
+      break;
+    default:
+      llvm_unreachable("Invalid key for LDRAA/LDRAB");
+    }
+    // The offset is encoded as scaled, for an element size of 8 bytes.
+    SDValue Offset = CurDAG->getTargetConstant(OffsetVal / 8, DL, MVT::i64);
+    SDValue Ops[] = {Base, Offset, Chain};
+    Res = needsWriteback
+              ? CurDAG->getMachineNode(Opc, DL, MVT::i64, MVT::i64, MVT::Other,
+                                       Ops)
+              : CurDAG->getMachineNode(Opc, DL, MVT::i64, MVT::Other, Ops);
+    if (needsWriteback) {
+      Writeback = SDValue(Res, 0);
+      ResVal = SDValue(Res, 1);
+      OutChain = SDValue(Res, 2);
+    } else {
+      ResVal = SDValue(Res, 0);
+      OutChain = SDValue(Res, 1);
+    }
+  } else {
+    // Otherwise, use the generalized LDRA pseudos.
+    unsigned Opc = needsWriteback ? AArch64::LDRApre : AArch64::LDRA;
+
+    SDValue X16Copy =
+        CurDAG->getCopyToReg(Chain, DL, AArch64::X16, Base, SDValue());
+    SDValue Offset = CurDAG->getTargetConstant(OffsetVal, DL, MVT::i64);
+    SDValue Key = CurDAG->getTargetConstant(KeyC, DL, MVT::i32);
+    SDValue Ops[] = {Offset, Key, IntDisc, AddrDisc, X16Copy.getValue(1)};
+    Res = CurDAG->getMachineNode(Opc, DL, MVT::i64, MVT::Other, MVT::Glue, Ops);
+    if (needsWriteback)
+      Writeback = CurDAG->getCopyFromReg(SDValue(Res, 1), DL, AArch64::X16,
+                                         MVT::i64, SDValue(Res, 2));
+    ResVal = SDValue(Res, 0);
+    OutChain = SDValue(Res, 1);
+  }
+
+  if (IsPre) {
+    // If the original load was pre-inc, the resulting LDRA is writeback.
+    assert(needsWriteback && "preinc loads can't be selected into non-wb ldra");
+    ReplaceUses(SDValue(N, 1), Writeback); // writeback
+    ReplaceUses(SDValue(N, 0), ResVal);    // loaded value
+    ReplaceUses(SDValue(N, 2), OutChain);  // chain
+  } else if (needsWriteback) {
+    // If the original load was unindexed, but we emitted a writeback form,
+    // we need to replace the uses of the original auth(signedbase)[+offset]
+    // computation.
+    ReplaceUses(Ptr, Writeback);          // writeback
+    ReplaceUses(SDValue(N, 0), ResVal);   // loaded value
+    ReplaceUses(SDValue(N, 1), OutChain); // chain
+  } else {
+    // Otherwise, we selected a simple load to a simple non-wb ldra.
+    assert(Ptr.hasOneUse() && "reused auth ptr should be folded into ldra");
+    ReplaceUses(SDValue(N, 0), ResVal);   // loaded value
+    ReplaceUses(SDValue(N, 1), OutChain); // chain
+  }
+
+  CurDAG->RemoveDeadNode(N);
+  return true;
+}
+
 void AArch64DAGToDAGISel::SelectLoad(SDNode *N, unsigned NumVecs, unsigned Opc,
                                      unsigned SubRegIdx) {
   SDLoc dl(N);
@@ -4643,8 +4802,10 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
     break;
 
   case ISD::LOAD: {
-    // Try to select as an indexed load. Fall through to normal processing
-    // if we can't.
+    // Try to select as an indexed or authenticating load. Fall through to
+    // normal processing if we can't.
+    if (tryAuthLoad(Node))
+      return;
     if (tryIndexedLoad(Node))
       return;
     break;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrGISel.td b/llvm/lib/Target/AArch64/AArch64InstrGISel.td
index 2d2b2bee99ec41..1b544c4f8c19a6 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrGISel.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrGISel.td
@@ -25,6 +25,23 @@ def G_ADD_LOW : AArch64GenericInstruction {
   let hasSideEffects = 0;
 }
 
+// Represents an auth-load instruction.  Produced post-legalization from
+// G_LOADs of ptrauth_auth intrinsics, with variants for keys/discriminators.
+def G_LDRA : AArch64GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type1:$addr, i64imm:$offset, i32imm:$key, i64imm:$disc, type0:$addrdisc);
+  let hasSideEffects = 0;
+  let mayLoad = 1;
+}
+
+// Represents a pre-inc writeback auth-load instruction.  Similar to G_LDRA.
+def G_LDRApre : AArch64GenericInstruction {
+  let OutOperandList = (outs type0:$dst, ptype1:$newaddr);
+  let InOperandList = (ins ptype1:$addr, i64imm:$offset, i32imm:$key, i64imm:$disc, type0:$addrdisc);
+  let hasSideEffects = 0;
+  let mayLoad = 1;
+}
+
 // Pseudo for a rev16 instruction. Produced post-legalization from
 // G_SHUFFLE_VECTORs with appropriate masks.
 def G_REV16 : AArch64GenericInstruction {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 8e575abf83d449..44be2fe00f0aa7 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1973,6 +1973,42 @@ let Predicates = [HasPAuth] in {
     let Size = 8;
   }
 
+  // LDRA pseudo: generalized LDRAA/Bindexed, allowing arbitrary discriminators,
+  // and wider offsets.
+  // This directly manipulates x16/x17, which are the only registers the OS
+  // guarantees are safe to use for sensitive operations.
+  // The loaded value is in $Rt.  The signed pointer is in X16.
+  // $Rt could be GPR64 but is GPR64noip to help out regalloc: we imp-def 2/3rds
+  // of the difference between the two, and the 3rd reg (LR) is often reserved.
+  def LDRA : Pseudo<(outs GPR64noip:$Rt),
+                    (ins i64imm:$Offset, i32imm:$Key, i64imm:$Disc,
+                         GPR64noip:$AddrDisc),
+                    []>, Sched<[]> {
+    let isCodeGenOnly = 1;
+    let hasSideEffects = 1;
+    let mayStore = 0;
+    let mayLoad = 1;
+    let Size = 48;
+    let Defs = [X16,X17];
+    let Uses = [X16];
+  }
+
+  // Pre-indexed + writeback variant of LDRA.
+  // The signed pointer is in X16, and is written back, after being
+  // authenticated and offset, into X16.
+  def LDRApre : Pseudo<(outs GPR64noip:$Rt),
+                       (ins i64imm:$Offset, i32imm:$Key, i64imm:$Disc,
+                            GPR64noip:$AddrDisc),
+                    []>, Sched<[]> {
+    let isCodeGenOnly = 1;
+    let hasSideEffects = 1;
+    let mayStore = 0;
+    let mayLoad = 1;
+    let Size = 48;
+    let Defs = [X16,X17];
+    let Uses = [X16];
+  }
+
   // Size 16: 4 fixed + 8 variable, to compute discriminator.
   // The size returned by getInstSizeInBytes() is incremented according
   // to the variant of LR check.
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index 07f03644336cdd..87158df0b75c2a 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -225,6 +225,7 @@ class AArch64InstructionSelector : public InstructionSelector {
   bool selectTLSGlobalValue(MachineInstr &I, MachineRegisterInfo &MRI);
   bool selectPtrAuthGlobalValue(MachineInstr &I,
                                 MachineRegisterInfo &MRI) const;
+  bool selectAuthLoad(MachineInstr &I, MachineRegisterInfo &MRI);
   bool selectReduction(MachineInstr &I, MachineRegisterInfo &MRI);
   bool selectMOPS(MachineInstr &I, MachineRegisterInfo &MRI);
   bool selectUSMovFromExtend(MachineInstr &I, MachineRegisterInfo &MRI);
@@ -2992,6 +2993,10 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
   case TargetOpcode::G_PTRAUTH_GLOBAL_VALUE:
     return selectPtrAuthGlobalValue(I, MRI);
 
+  case AArch64::G_LDRA:
+  case AArch64::G_LDRApre:
+    return selectAuthLoad(I, MRI);
+
   case TargetOpcode::G_ZEXTLOAD:
   case TargetOpcode::G_LOAD:
   case TargetOpcode::G_STORE: {
@@ -6976,6 +6981,72 @@ bool AArch64InstructionSelector::selectPtrAuthGlobalValue(
   return true;
 }
 
+bool AArch64InstructionSelector::selectAuthLoad(MachineInstr &I,
+                                                MachineRegisterInfo &MRI) {
+  bool Writeback = I.getOpcode() == AArch64::G_LDRApre;
+
+  Register ValReg = I.getOperand(0).getReg();
+  Register PtrReg = I.getOperand(1 + Writeback).getReg();
+  int64_t Offset = I.getOperand(2 + Writeback).getImm();
+  auto Key =
+      static_cast<AArch64PACKey::ID>(I.getOperand(3 + Writeback).getImm());
+  uint64_t DiscImm = I.getOperand(4 + Writeback).getImm();
+  Register AddrDisc = I.getOperand(5 + Writeback).getReg();
+
+  bool IsDKey = Key == AArch64PACKey::DA || Key == AArch64PACKey::DB;
+  bool ZeroDisc = AddrDisc == AArch64::NoRegister && !DiscImm;
+
+  // If the discriminator is zero and the offset fits, we can use LDRAA/LDRAB.
+  // Do that here to avoid needlessly constraining regalloc into using X16.
+  if (ZeroDisc && isShiftedInt<10, 3>(Offset) && IsDKey) {
+    unsigned Opc = 0;
+    switch (Key) {
+    case AArch64PACKey::DA:
+      Opc = Writeback ? AArch64::LDRAAwriteback : AArch64::LDRAAindexed;
+      break;
+    case AArch64PACKey::DB:
+      Opc = Writeback ? AArch64::LDRABwriteback : AArch64::LDRABindexed;
+      break;
+    default:
+      llvm_unreachable("Invalid key for LDRAA/LDRAB");
+    }
+    // The LDRAA/LDRAB offset immediate is scaled.
+    Offset /= 8;
+    if (Writeback) {
+      MIB.buildInstr(Opc, {I.getOperand(1).getReg(), ValReg}, {PtrReg, Offset})
+          .constrainAllUses(TII, TRI, RBI);
+      RBI.constrainGenericRegister(I.getOperand(1).getReg(),
+                                   AArch64::GPR64spRegClass, MRI);
+    } else {
+      MIB.buildInstr(Opc, {ValReg}, {PtrReg, Offset})
+          .constrainAllUses(TII, TRI, RBI);
+    }
+    I.eraseFromParent();
+    return...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/123769


More information about the llvm-commits mailing list