[llvm] [AArch64][PAC] Support BLRA* instructions in SLS Hardening pass (PR #98062)

Anatoly Trosinenko via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 9 04:23:29 PDT 2024


https://github.com/atrosinenko updated https://github.com/llvm/llvm-project/pull/98062

>From b894e66946a6e8f6c4f2800c36b249615d78e263 Mon Sep 17 00:00:00 2001
From: Anatoly Trosinenko <atrosinenko at accesssoftek.com>
Date: Sat, 6 Jul 2024 13:36:02 +0300
Subject: [PATCH] [AArch64][PAC] Support BLRA* instructions in SLS Hardening
 pass

Make SLS Hardening pass handle BLRA* instructions the same way it
handles BLR. The thunk names have the form

    __llvm_slsblr_thunk_xN            for BLR thunks
    __llvm_slsblr_thunk_(aaz|abz)_xN  for BLRAAZ and BLRABZ thunks
    __llvm_slsblr_thunk_(aa|ab)_xN_xM for BLRAA and BLRAB thunks

Now there are about 1800 possible thunk names, so do not rely on a linear
lookup of the thunk function's name; parse the name instead.

This patch reapplies llvm/llvm-project#97605.
---
 .../Target/AArch64/AArch64SLSHardening.cpp    | 377 ++++++++++++------
 .../speculation-hardening-sls-blra.mir        | 210 ++++++++++
 2 files changed, 469 insertions(+), 118 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/speculation-hardening-sls-blra.mir

diff --git a/llvm/lib/Target/AArch64/AArch64SLSHardening.cpp b/llvm/lib/Target/AArch64/AArch64SLSHardening.cpp
index 00ba31b3e500d..e648b7fbab566 100644
--- a/llvm/lib/Target/AArch64/AArch64SLSHardening.cpp
+++ b/llvm/lib/Target/AArch64/AArch64SLSHardening.cpp
@@ -13,6 +13,7 @@
 
 #include "AArch64InstrInfo.h"
 #include "AArch64Subtarget.h"
+#include "llvm/ADT/StringSwitch.h"
 #include "llvm/CodeGen/IndirectThunks.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
 #include "llvm/CodeGen/MachineFunction.h"
@@ -23,8 +24,11 @@
 #include "llvm/IR/DebugLoc.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "llvm/Target/TargetMachine.h"
 #include <cassert>
+#include <climits>
+#include <tuple>
 
 using namespace llvm;
 
@@ -32,17 +36,107 @@ using namespace llvm;
 
 #define AARCH64_SLS_HARDENING_NAME "AArch64 sls hardening pass"
 
-static const char SLSBLRNamePrefix[] = "__llvm_slsblr_thunk_";
+// Common name prefix of all thunks generated by this pass.
+//
+// The generic form is
+// __llvm_slsblr_thunk_xN            for BLR thunks
+// __llvm_slsblr_thunk_(aaz|abz)_xN  for BLRAAZ and BLRABZ thunks
+// __llvm_slsblr_thunk_(aa|ab)_xN_xM for BLRAA and BLRAB thunks
+static constexpr StringRef CommonNamePrefix = "__llvm_slsblr_thunk_";
 
 namespace {
 
-// Set of inserted thunks: bitmask with bits corresponding to
-// indexes in SLSBLRThunks array.
-typedef uint32_t ThunksSet;
+struct ThunkKind {
+  enum ThunkKindId {
+    ThunkBR,
+    ThunkBRAA,
+    ThunkBRAB,
+    ThunkBRAAZ,
+    ThunkBRABZ,
+  };
+
+  ThunkKindId Id;
+  StringRef NameInfix;
+  bool HasXmOperand;
+  bool NeedsPAuth;
+
+  // Opcode to perform indirect jump from inside the thunk.
+  unsigned BROpcode;
+
+  static const ThunkKind BR;
+  static const ThunkKind BRAA;
+  static const ThunkKind BRAB;
+  static const ThunkKind BRAAZ;
+  static const ThunkKind BRABZ;
+};
+
+// Set of inserted thunks.
+class ThunksSet {
+public:
+  static constexpr unsigned NumXRegisters = 32;
+
+  // Given Xn register, returns n.
+  static unsigned indexOfXReg(Register Xn);
+  // Given n, returns Xn register.
+  static Register xRegByIndex(unsigned N);
+
+  ThunksSet &operator|=(const ThunksSet &Other) {
+    BLRThunks |= Other.BLRThunks;
+    BLRAAZThunks |= Other.BLRAAZThunks;
+    BLRABZThunks |= Other.BLRABZThunks;
+    for (unsigned I = 0; I < NumXRegisters; ++I)
+      BLRAAThunks[I] |= Other.BLRAAThunks[I];
+    for (unsigned I = 0; I < NumXRegisters; ++I)
+      BLRABThunks[I] |= Other.BLRABThunks[I];
+
+    return *this;
+  }
+
+  bool get(ThunkKind::ThunkKindId Kind, Register Xn, Register Xm) {
+    reg_bitmask_t XnBit = reg_bitmask_t(1) << indexOfXReg(Xn);
+    return getBitmask(Kind, Xm) & XnBit;
+  }
+
+  void set(ThunkKind::ThunkKindId Kind, Register Xn, Register Xm) {
+    reg_bitmask_t XnBit = reg_bitmask_t(1) << indexOfXReg(Xn);
+    getBitmask(Kind, Xm) |= XnBit;
+  }
+
+private:
+  typedef uint32_t reg_bitmask_t;
+  static_assert(NumXRegisters <= sizeof(reg_bitmask_t) * CHAR_BIT,
+                "Bitmask is not wide enough to hold all Xn registers");
+
+  // Bitmasks representing operands used, with n-th bit corresponding to Xn
+  // register operand. If the instruction has a second operand (Xm), an array
+  // of bitmasks is used, indexed by m.
+  // Indexes corresponding to the forbidden x16, x17 and x30 registers are
+  // always unset; for simplicity, they are not skipped (there are no holes).
+  reg_bitmask_t BLRThunks = 0;
+  reg_bitmask_t BLRAAZThunks = 0;
+  reg_bitmask_t BLRABZThunks = 0;
+  reg_bitmask_t BLRAAThunks[NumXRegisters] = {};
+  reg_bitmask_t BLRABThunks[NumXRegisters] = {};
+
+  reg_bitmask_t &getBitmask(ThunkKind::ThunkKindId Kind, Register Xm) {
+    switch (Kind) {
+    case ThunkKind::ThunkBR:
+      return BLRThunks;
+    case ThunkKind::ThunkBRAAZ:
+      return BLRAAZThunks;
+    case ThunkKind::ThunkBRABZ:
+      return BLRABZThunks;
+    case ThunkKind::ThunkBRAA:
+      return BLRAAThunks[indexOfXReg(Xm)];
+    case ThunkKind::ThunkBRAB:
+      return BLRABThunks[indexOfXReg(Xm)];
+    }
+  }
+};
 
 struct SLSHardeningInserter : ThunkInserter<SLSHardeningInserter, ThunksSet> {
 public:
-  const char *getThunkPrefix() { return SLSBLRNamePrefix; }
+  const char *getThunkPrefix() { return CommonNamePrefix.data(); }
   bool mayUseThunk(const MachineFunction &MF) {
     ComdatThunks &= !MF.getSubtarget<AArch64Subtarget>().hardenSlsNoComdat();
     // We are inserting barriers aside from thunk calls, so
@@ -68,6 +162,61 @@ struct SLSHardeningInserter : ThunkInserter<SLSHardeningInserter, ThunksSet> {
 
 } // end anonymous namespace
 
+const ThunkKind ThunkKind::BR = {ThunkBR, "", /*HasXmOperand=*/false,
+                                 /*NeedsPAuth=*/false, AArch64::BR};
+const ThunkKind ThunkKind::BRAA = {ThunkBRAA, "aa_", /*HasXmOperand=*/true,
+                                   /*NeedsPAuth=*/true, AArch64::BRAA};
+const ThunkKind ThunkKind::BRAB = {ThunkBRAB, "ab_", /*HasXmOperand=*/true,
+                                   /*NeedsPAuth=*/true, AArch64::BRAB};
+const ThunkKind ThunkKind::BRAAZ = {ThunkBRAAZ, "aaz_", /*HasXmOperand=*/false,
+                                    /*NeedsPAuth=*/true, AArch64::BRAAZ};
+const ThunkKind ThunkKind::BRABZ = {ThunkBRABZ, "abz_", /*HasXmOperand=*/false,
+                                    /*NeedsPAuth=*/true, AArch64::BRABZ};
+
+// Returns thunk kind to emit, or nullptr if not a BLR* instruction.
+static const ThunkKind *getThunkKind(unsigned OriginalOpcode) {
+  switch (OriginalOpcode) {
+  case AArch64::BLR:
+  case AArch64::BLRNoIP:
+    return &ThunkKind::BR;
+  case AArch64::BLRAA:
+    return &ThunkKind::BRAA;
+  case AArch64::BLRAB:
+    return &ThunkKind::BRAB;
+  case AArch64::BLRAAZ:
+    return &ThunkKind::BRAAZ;
+  case AArch64::BLRABZ:
+    return &ThunkKind::BRABZ;
+  }
+  return nullptr;
+}
+
+static bool isBLR(const MachineInstr &MI) {
+  return getThunkKind(MI.getOpcode()) != nullptr;
+}
+
+unsigned ThunksSet::indexOfXReg(Register Reg) {
+  assert(AArch64::GPR64RegClass.contains(Reg));
+  assert(Reg != AArch64::X16 && Reg != AArch64::X17 && Reg != AArch64::LR);
+
+  // Most Xn registers have consecutive ids, except for FP and XZR.
+  unsigned Result = (unsigned)Reg - (unsigned)AArch64::X0;
+  if (Reg == AArch64::FP)
+    Result = 29;
+  else if (Reg == AArch64::XZR)
+    Result = 31;
+
+  assert(Result < NumXRegisters && "Internal register numbering changed");
+  assert(AArch64::GPR64RegClass.getRegister(Result).id() == Reg &&
+         "Internal register numbering changed");
+
+  return Result;
+}
+
+Register ThunksSet::xRegByIndex(unsigned N) {
+  return AArch64::GPR64RegClass.getRegister(N);
+}
+
 static void insertSpeculationBarrier(const AArch64Subtarget *ST,
                                      MachineBasicBlock &MBB,
                                      MachineBasicBlock::iterator MBBI,
@@ -104,22 +253,6 @@ ThunksSet SLSHardeningInserter::insertThunks(MachineModuleInfo &MMI,
   return ExistingThunks;
 }
 
-static bool isBLR(const MachineInstr &MI) {
-  switch (MI.getOpcode()) {
-  case AArch64::BLR:
-  case AArch64::BLRNoIP:
-    return true;
-  case AArch64::BLRAA:
-  case AArch64::BLRAB:
-  case AArch64::BLRAAZ:
-  case AArch64::BLRABZ:
-    llvm_unreachable("Currently, LLVM's code generator does not support "
-                     "producing BLRA* instructions. Therefore, there's no "
-                     "support in this pass for those instructions.");
-  }
-  return false;
-}
-
 bool SLSHardeningInserter::hardenReturnsAndBRs(MachineModuleInfo &MMI,
                                                MachineBasicBlock &MBB) {
   const AArch64Subtarget *ST =
@@ -139,64 +272,64 @@ bool SLSHardeningInserter::hardenReturnsAndBRs(MachineModuleInfo &MMI,
   return Modified;
 }
 
-static const unsigned NumPermittedRegs = 29;
-static const struct ThunkNameAndReg {
-  const char* Name;
-  Register Reg;
-} SLSBLRThunks[NumPermittedRegs] = {
-    {"__llvm_slsblr_thunk_x0", AArch64::X0},
-    {"__llvm_slsblr_thunk_x1", AArch64::X1},
-    {"__llvm_slsblr_thunk_x2", AArch64::X2},
-    {"__llvm_slsblr_thunk_x3", AArch64::X3},
-    {"__llvm_slsblr_thunk_x4", AArch64::X4},
-    {"__llvm_slsblr_thunk_x5", AArch64::X5},
-    {"__llvm_slsblr_thunk_x6", AArch64::X6},
-    {"__llvm_slsblr_thunk_x7", AArch64::X7},
-    {"__llvm_slsblr_thunk_x8", AArch64::X8},
-    {"__llvm_slsblr_thunk_x9", AArch64::X9},
-    {"__llvm_slsblr_thunk_x10", AArch64::X10},
-    {"__llvm_slsblr_thunk_x11", AArch64::X11},
-    {"__llvm_slsblr_thunk_x12", AArch64::X12},
-    {"__llvm_slsblr_thunk_x13", AArch64::X13},
-    {"__llvm_slsblr_thunk_x14", AArch64::X14},
-    {"__llvm_slsblr_thunk_x15", AArch64::X15},
-    // X16 and X17 are deliberately missing, as the mitigation requires those
-    // register to not be used in BLR. See comment in ConvertBLRToBL for more
-    // details.
-    {"__llvm_slsblr_thunk_x18", AArch64::X18},
-    {"__llvm_slsblr_thunk_x19", AArch64::X19},
-    {"__llvm_slsblr_thunk_x20", AArch64::X20},
-    {"__llvm_slsblr_thunk_x21", AArch64::X21},
-    {"__llvm_slsblr_thunk_x22", AArch64::X22},
-    {"__llvm_slsblr_thunk_x23", AArch64::X23},
-    {"__llvm_slsblr_thunk_x24", AArch64::X24},
-    {"__llvm_slsblr_thunk_x25", AArch64::X25},
-    {"__llvm_slsblr_thunk_x26", AArch64::X26},
-    {"__llvm_slsblr_thunk_x27", AArch64::X27},
-    {"__llvm_slsblr_thunk_x28", AArch64::X28},
-    {"__llvm_slsblr_thunk_x29", AArch64::FP},
-    // X30 is deliberately missing, for similar reasons as X16 and X17 are
-    // missing.
-    {"__llvm_slsblr_thunk_x31", AArch64::XZR},
-};
+// Currently, the longest possible thunk name is
+//   __llvm_slsblr_thunk_aa_xNN_xMM
+// which is 30 characters (without the '\0' character).
+static SmallString<32> createThunkName(const ThunkKind &Kind, Register Xn,
+                                       Register Xm) {
+  unsigned N = ThunksSet::indexOfXReg(Xn);
+  if (!Kind.HasXmOperand)
+    return formatv("{0}{1}x{2}", CommonNamePrefix, Kind.NameInfix, N);
+
+  unsigned M = ThunksSet::indexOfXReg(Xm);
+  return formatv("{0}{1}x{2}_x{3}", CommonNamePrefix, Kind.NameInfix, N, M);
+}
 
-unsigned getThunkIndex(Register Reg) {
-  for (unsigned I = 0; I < NumPermittedRegs; ++I)
-    if (SLSBLRThunks[I].Reg == Reg)
-      return I;
-  llvm_unreachable("Unexpected register");
+static std::tuple<const ThunkKind &, Register, Register>
+parseThunkName(StringRef ThunkName) {
+  assert(ThunkName.starts_with(CommonNamePrefix) &&
+         "Should be filtered out by ThunkInserter");
+  // Thunk name suffix, such as "x1" or "aa_x2_x3".
+  StringRef NameSuffix = ThunkName.drop_front(CommonNamePrefix.size());
+
+  // Parse thunk kind based on thunk name infix.
+  const ThunkKind &Kind = *StringSwitch<const ThunkKind *>(NameSuffix)
+                               .StartsWith("aa_", &ThunkKind::BRAA)
+                               .StartsWith("ab_", &ThunkKind::BRAB)
+                               .StartsWith("aaz_", &ThunkKind::BRAAZ)
+                               .StartsWith("abz_", &ThunkKind::BRABZ)
+                               .Default(&ThunkKind::BR);
+
+  auto ParseRegName = [](StringRef Name) {
+    unsigned N;
+
+    assert(Name.starts_with("x") && "xN register name expected");
+    bool Fail = Name.drop_front(1).getAsInteger(/*Radix=*/10, N);
+    assert(!Fail && N < ThunksSet::NumXRegisters && "Unexpected register");
+    (void)Fail;
+
+    return ThunksSet::xRegByIndex(N);
+  };
+
+  // For example, "x1" or "x2_x3".
+  StringRef RegsStr = NameSuffix.drop_front(Kind.NameInfix.size());
+  StringRef XnStr, XmStr;
+  std::tie(XnStr, XmStr) = RegsStr.split('_');
+
+  // Parse register operands.
+  Register Xn = ParseRegName(XnStr);
+  Register Xm = Kind.HasXmOperand ? ParseRegName(XmStr) : AArch64::NoRegister;
+
+  return std::make_tuple(std::ref(Kind), Xn, Xm);
 }
 
 void SLSHardeningInserter::populateThunk(MachineFunction &MF) {
   assert(MF.getFunction().hasComdat() == ComdatThunks &&
          "ComdatThunks value changed since MF creation");
-  // FIXME: How to better communicate Register number, rather than through
-  // name and lookup table?
-  assert(MF.getName().starts_with(getThunkPrefix()));
-  auto ThunkIt = llvm::find_if(
-      SLSBLRThunks, [&MF](auto T) { return T.Name == MF.getName(); });
-  assert(ThunkIt != std::end(SLSBLRThunks));
-  Register ThunkReg = ThunkIt->Reg;
+  Register Xn, Xm;
+  auto KindAndRegs = parseThunkName(MF.getName());
+  const ThunkKind &Kind = std::get<0>(KindAndRegs);
+  std::tie(std::ignore, Xn, Xm) = KindAndRegs;
 
   const TargetInstrInfo *TII =
       MF.getSubtarget<AArch64Subtarget>().getInstrInfo();
@@ -218,16 +351,26 @@ void SLSHardeningInserter::populateThunk(MachineFunction &MF) {
   Entry->clear();
 
   //  These thunks need to consist of the following instructions:
-  //  __llvm_slsblr_thunk_xN:
-  //      BR xN
+  //  __llvm_slsblr_thunk_...:
+  //      MOV x16, xN     ; BR* instructions are not compatible with "BTI c"
+  //                      ; branch target unless xN is x16 or x17.
+  //      BR* ...         ; One of: BR        x16
+  //                      ;         BRA(A|B)  x16, xM
+  //                      ;         BRA(A|B)Z x16
   //      barrierInsts
-  Entry->addLiveIn(ThunkReg);
-  // MOV X16, ThunkReg == ORR X16, XZR, ThunkReg, LSL #0
+  Entry->addLiveIn(Xn);
+  // MOV X16, Reg == ORR X16, XZR, Reg, LSL #0
   BuildMI(Entry, DebugLoc(), TII->get(AArch64::ORRXrs), AArch64::X16)
       .addReg(AArch64::XZR)
-      .addReg(ThunkReg)
+      .addReg(Xn)
       .addImm(0);
-  BuildMI(Entry, DebugLoc(), TII->get(AArch64::BR)).addReg(AArch64::X16);
+  MachineInstrBuilder Builder =
+      BuildMI(Entry, DebugLoc(), TII->get(Kind.BROpcode)).addReg(AArch64::X16);
+  if (Xm != AArch64::NoRegister) {
+    Entry->addLiveIn(Xm);
+    Builder.addReg(Xm);
+  }
+
   // Make sure the thunks do not make use of the SB extension in case there is
   // a function somewhere that will call to it that for some reason disabled
   // the SB extension locally on that function, even though it's enabled for
@@ -239,12 +382,14 @@ void SLSHardeningInserter::populateThunk(MachineFunction &MF) {
 void SLSHardeningInserter::convertBLRToBL(
     MachineModuleInfo &MMI, MachineBasicBlock &MBB,
     MachineBasicBlock::instr_iterator MBBI, ThunksSet &Thunks) {
-  // Transform a BLR to a BL as follows:
+  // Transform a BLR* instruction (one of BLR, BLRAA/BLRAB or BLRAAZ/BLRABZ) to
+  // a BL to the thunk containing BR, BRAA/BRAB or BRAAZ/BRABZ, respectively.
+  //
   // Before:
   //   |-----------------------------|
   //   |      ...                    |
   //   |  instI                      |
-  //   |  BLR xN                     |
+  //   |  BLR* xN or BLR* xN, xM     |
   //   |  instJ                      |
   //   |      ...                    |
   //   |-----------------------------|
@@ -253,61 +398,53 @@ void SLSHardeningInserter::convertBLRToBL(
   //   |-----------------------------|
   //   |      ...                    |
   //   |  instI                      |
-  //   |  BL __llvm_slsblr_thunk_xN  |
+  //   |  BL __llvm_slsblr_thunk_... |
   //   |  instJ                      |
   //   |      ...                    |
   //   |-----------------------------|
   //
-  //   __llvm_slsblr_thunk_xN:
+  //   __llvm_slsblr_thunk_...:
   //   |-----------------------------|
-  //   |  BR xN                      |
+  //   |  MOV x16, xN                |
+  //   |  BR* x16 or BR* x16, xM     |
   //   |  barrierInsts               |
   //   |-----------------------------|
   //
-  // This function merely needs to transform BLR xN into BL
-  // __llvm_slsblr_thunk_xN.
+  // This function needs to transform BLR* instruction into BL with the correct
+  // thunk name and lazily create the thunk if it does not exist yet.
   //
   // Since linkers are allowed to clobber X16 and X17 on function calls, the
-  // above mitigation only works if the original BLR instruction was not
-  // BLR X16 nor BLR X17. Code generation before must make sure that no BLR
-  // X16|X17 was produced if the mitigation is enabled.
+  // above mitigation only works if the original BLR* instruction had neither
+  // X16 nor X17 as one of its operands. Code generation before must make sure
+  // that no such BLR* instruction was produced if the mitigation is enabled.
 
   MachineInstr &BLR = *MBBI;
   assert(isBLR(BLR));
-  unsigned BLOpcode;
-  Register Reg;
-  bool RegIsKilled;
-  switch (BLR.getOpcode()) {
-  case AArch64::BLR:
-  case AArch64::BLRNoIP:
-    BLOpcode = AArch64::BL;
-    Reg = BLR.getOperand(0).getReg();
-    assert(Reg != AArch64::X16 && Reg != AArch64::X17 && Reg != AArch64::LR);
-    RegIsKilled = BLR.getOperand(0).isKill();
-    break;
-  case AArch64::BLRAA:
-  case AArch64::BLRAB:
-  case AArch64::BLRAAZ:
-  case AArch64::BLRABZ:
-    llvm_unreachable("BLRA instructions cannot yet be produced by LLVM, "
-                     "therefore there is no need to support them for now.");
-  default:
-    llvm_unreachable("unhandled BLR");
-  }
+  const ThunkKind &Kind = *getThunkKind(BLR.getOpcode());
+
+  unsigned NumRegOperands = Kind.HasXmOperand ? 2 : 1;
+  assert(BLR.getNumExplicitOperands() == NumRegOperands &&
+         "Expected one or two register inputs");
+  Register Xn = BLR.getOperand(0).getReg();
+  Register Xm =
+      Kind.HasXmOperand ? BLR.getOperand(1).getReg() : AArch64::NoRegister;
+
   DebugLoc DL = BLR.getDebugLoc();
 
   MachineFunction &MF = *MBBI->getMF();
   MCContext &Context = MBB.getParent()->getContext();
   const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
-  unsigned ThunkIndex = getThunkIndex(Reg);
-  StringRef ThunkName = SLSBLRThunks[ThunkIndex].Name;
+
+  auto ThunkName = createThunkName(Kind, Xn, Xm);
   MCSymbol *Sym = Context.getOrCreateSymbol(ThunkName);
-  if (!(Thunks & (1u << ThunkIndex))) {
-    Thunks |= 1u << ThunkIndex;
-    createThunkFunction(MMI, ThunkName, ComdatThunks);
+
+  if (!Thunks.get(Kind.Id, Xn, Xm)) {
+    StringRef TargetAttrs = Kind.NeedsPAuth ? "+pauth" : "";
+    Thunks.set(Kind.Id, Xn, Xm);
+    createThunkFunction(MMI, ThunkName, ComdatThunks, TargetAttrs);
   }
 
-  MachineInstr *BL = BuildMI(MBB, MBBI, DL, TII->get(BLOpcode)).addSym(Sym);
+  MachineInstr *BL = BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL)).addSym(Sym);
 
   // Now copy the implicit operands from BLR to BL and copy other necessary
   // info.
@@ -338,9 +475,13 @@ void SLSHardeningInserter::convertBLRToBL(
   // Now copy over the implicit operands from the original BLR
   BL->copyImplicitOps(MF, BLR);
   MF.moveCallSiteInfo(&BLR, BL);
-  // Also add the register called in the BLR as being used in the called thunk.
-  BL->addOperand(MachineOperand::CreateReg(Reg, false /*isDef*/, true /*isImp*/,
-                                           RegIsKilled /*isKill*/));
+  // Also add the register operands of the original BLR* instruction
+  // as being used in the called thunk.
+  for (unsigned OpIdx = 0; OpIdx < NumRegOperands; ++OpIdx) {
+    MachineOperand &Op = BLR.getOperand(OpIdx);
+    BL->addOperand(MachineOperand::CreateReg(Op.getReg(), /*isDef=*/false,
+                                             /*isImp=*/true, Op.isKill()));
+  }
   // Remove BLR instruction
   MBB.erase(MBBI);
 }
diff --git a/llvm/test/CodeGen/AArch64/speculation-hardening-sls-blra.mir b/llvm/test/CodeGen/AArch64/speculation-hardening-sls-blra.mir
new file mode 100644
index 0000000000000..06669a6d6aae2
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/speculation-hardening-sls-blra.mir
@@ -0,0 +1,210 @@
+# RUN: llc -verify-machineinstrs -mtriple=aarch64-none-linux-gnu \
+# RUN:     -start-before aarch64-sls-hardening -o - %s \
+# RUN:     -asm-verbose=0 \
+# RUN: | FileCheck %s \
+# RUN:     --implicit-check-not=__llvm_slsblr_thunk_aa_x5_x8 \
+# RUN:     --implicit-check-not=__llvm_slsblr_thunk_ab_x5_x8 \
+# RUN:     --implicit-check-not=__llvm_slsblr_thunk_aaz_x5 \
+# RUN:     --implicit-check-not=__llvm_slsblr_thunk_abz_x5
+
+# Pointer Authentication extension introduces more branch-with-link-to-register
+# instructions for the BLR SLS hardening to handle, namely BLRAA, BLRAB, BLRAAZ
+# and BLRABZ. Unlike the non-authenticating BLR instruction, BLRAA and BLRAB
+# accept two register operands (almost 900 combinations for each instruction).
+# For that reason, it is not practical to create all possible thunks.
+
+# Check that the BLR SLS hardening transforms BLRA* instructions into
+# unconditional BL calls to the correct thunk functions.
+# Check that only relevant thunk functions are generated.
+--- |
+  define void @test_instructions() #0 {
+  entry:
+    ret void
+  }
+
+  define void @test_no_redef() #0 {
+  entry:
+    ret void
+  }
+
+  define void @test_regs() #0 {
+  entry:
+    ret void
+  }
+
+  attributes #0 = { "target-features"="+pauth,+harden-sls-blr" }
+...
+
+# Test that all BLRA* instructions are handled.
+---
+name:            test_instructions
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $lr, $x0, $x1, $x2, $x3
+
+    BLRAA $x0, $x1, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+    BLRAB $x1, $x2, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+    BLRAAZ $x2, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+    BLRABZ $x3, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+    RET undef $lr
+...
+
+# Test that the same thunk function is not created twice.
+---
+name:            test_no_redef
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $lr, $x0, $x1, $x2, $x3, $x4
+
+    ; thunk used by @test_instructions
+    BLRAB $x1, $x2, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+
+    ; thunk used by this function twice
+    BLRAB $x3, $x4, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+    BLRAB $x3, $x4, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+
+    RET undef $lr
+...
+
+# Test that all xN registers (except x16, x17, x30 and xzr) are handled.
+---
+name:            test_regs
+tracksRegLiveness: true
+body:             |
+  bb.0.entry:
+    liveins: $lr, $x0, $x1, $x2, $x3, $x4, $x5, $x6, $x7, $x8, $x9, $x10, $x11, $x12, $x13, $x14, $x15, $x16, $x17, $x18, $x19, $x20, $x21, $x22, $x23, $x24, $x25, $x26, $x27, $x28, $fp
+
+    BLRAA $x0, $x1, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+    BLRAA $x2, $x3, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+    BLRAA $x4, $x5, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+    BLRAA $x6, $x7, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+    BLRAA $x8, $x9, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+    BLRAA $x10, $x11, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+    BLRAA $x12, $x13, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+    BLRAA $x14, $x15, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+    ; skipping x16 and x17
+    BLRAA $x18, $x19, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+    BLRAA $x20, $x21, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+    BLRAA $x22, $x23, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+    BLRAA $x24, $x25, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+    BLRAA $x26, $x27, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+    BLRAA $x28, $fp, implicit-def $lr, implicit $sp, implicit-def $sp, implicit-def $w0
+    RET undef $lr
+...
+
+# CHECK-LABEL: test_instructions:
+# CHECK-NEXT:    .cfi_startproc
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_aa_x0_x1
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_ab_x1_x2
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_aaz_x2
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_abz_x3
+# CHECK-NEXT:    ret
+
+# CHECK-LABEL: test_no_redef:
+# CHECK-NEXT:    .cfi_startproc
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_ab_x1_x2
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_ab_x3_x4
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_ab_x3_x4
+# CHECK-NEXT:    ret
+
+# CHECK-LABEL: test_regs:
+# CHECK-NEXT:    .cfi_startproc
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_aa_x0_x1
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_aa_x2_x3
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_aa_x4_x5
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_aa_x6_x7
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_aa_x8_x9
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_aa_x10_x11
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_aa_x12_x13
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_aa_x14_x15
+# skipping x16 and x17
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_aa_x18_x19
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_aa_x20_x21
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_aa_x22_x23
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_aa_x24_x25
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_aa_x26_x27
+# CHECK-NEXT:    bl      __llvm_slsblr_thunk_aa_x28_x29
+# CHECK-NEXT:    ret
+
+# CHECK-LABEL: __llvm_slsblr_thunk_aa_x0_x1:
+# CHECK-NEXT:    mov     x16, x0
+# CHECK-NEXT:    braa    x16, x1
+# CHECK-NEXT:    dsb     sy
+# CHECK-NEXT:    isb
+
+# CHECK-LABEL: __llvm_slsblr_thunk_ab_x1_x2:
+# CHECK-NEXT:    mov     x16, x1
+# CHECK-NEXT:    brab    x16, x2
+# CHECK-NEXT:    dsb     sy
+# CHECK-NEXT:    isb
+
+# CHECK-LABEL: __llvm_slsblr_thunk_aaz_x2:
+# CHECK-NEXT:    mov     x16, x2
+# CHECK-NEXT:    braaz   x16
+# CHECK-NEXT:    dsb     sy
+# CHECK-NEXT:    isb
+
+# CHECK-LABEL: __llvm_slsblr_thunk_abz_x3:
+# CHECK-NEXT:    mov     x16, x3
+# CHECK-NEXT:    brabz   x16
+# CHECK-NEXT:    dsb     sy
+# CHECK-NEXT:    isb
+
+# The instruction *operands* should correspond to the thunk function *name*
+# (check that the name is parsed correctly when populating the thunk).
+
+# CHECK-LABEL: __llvm_slsblr_thunk_aa_x2_x3:
+# CHECK-NEXT:    mov     x16, x2
+# CHECK:         braa    x16, x3
+
+# CHECK-LABEL: __llvm_slsblr_thunk_aa_x4_x5:
+# CHECK-NEXT:    mov     x16, x4
+# CHECK:         braa    x16, x5
+
+# CHECK-LABEL: __llvm_slsblr_thunk_aa_x6_x7:
+# CHECK-NEXT:    mov     x16, x6
+# CHECK:         braa    x16, x7
+
+# CHECK-LABEL: __llvm_slsblr_thunk_aa_x8_x9:
+# CHECK-NEXT:    mov     x16, x8
+# CHECK:         braa    x16, x9
+
+# CHECK-LABEL: __llvm_slsblr_thunk_aa_x10_x11:
+# CHECK-NEXT:    mov     x16, x10
+# CHECK:         braa    x16, x11
+
+# CHECK-LABEL: __llvm_slsblr_thunk_aa_x12_x13:
+# CHECK-NEXT:    mov     x16, x12
+# CHECK:         braa    x16, x13
+
+# CHECK-LABEL: __llvm_slsblr_thunk_aa_x14_x15:
+# CHECK-NEXT:    mov     x16, x14
+# CHECK:         braa    x16, x15
+
+# skipping x16 and x17
+
+# CHECK-LABEL: __llvm_slsblr_thunk_aa_x18_x19:
+# CHECK-NEXT:    mov     x16, x18
+# CHECK:         braa    x16, x19
+
+# CHECK-LABEL: __llvm_slsblr_thunk_aa_x20_x21:
+# CHECK-NEXT:    mov     x16, x20
+# CHECK:         braa    x16, x21
+
+# CHECK-LABEL: __llvm_slsblr_thunk_aa_x22_x23:
+# CHECK-NEXT:    mov     x16, x22
+# CHECK:         braa    x16, x23
+
+# CHECK-LABEL: __llvm_slsblr_thunk_aa_x24_x25:
+# CHECK-NEXT:    mov     x16, x24
+# CHECK:         braa    x16, x25
+
+# CHECK-LABEL: __llvm_slsblr_thunk_aa_x26_x27:
+# CHECK-NEXT:    mov     x16, x26
+# CHECK:         braa    x16, x27
+
+# CHECK-LABEL: __llvm_slsblr_thunk_aa_x28_x29:
+# CHECK-NEXT:    mov     x16, x28
+# CHECK:         braa    x16, x29



More information about the llvm-commits mailing list