[llvm] [AArch64] Indirect tail-calls cannot use x16 with pac-ret+pc (PR #81020)

via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 7 10:28:55 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: None (ostannard)

<details>
<summary>Changes</summary>

When using -mbranch-protection=pac-ret+pc, x16 is used in the function epilogue to hold the address of the signing instruction. That address is consumed by a HINT instruction which can only read x16, so a different register cannot be substituted. This means that x16 cannot also be used to hold the function pointer for an indirect tail-call.

There is existing code to force indirect tail-calls to use x16 or x17 when BTI is enabled, so there are now 4 combinations:

```
bti  pac-ret+pc  Valid function pointer registers
off  off         Any non callee-saved register
on   off         x16 or x17
off  on          Any non callee-saved register except x16
on   on          x17
```

---
Full diff: https://github.com/llvm/llvm-project/pull/81020.diff


9 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp (+3-1) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+3-1) 
- (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+3-1) 
- (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+42-9) 
- (modified) llvm/lib/Target/AArch64/AArch64RegisterInfo.td (+10-5) 
- (modified) llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp (+11-4) 
- (modified) llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp (+3-1) 
- (modified) llvm/test/CodeGen/AArch64/branch-target-enforcement-indirect-calls.ll (+65) 
- (modified) llvm/test/CodeGen/AArch64/kcfi-bti.ll (+2-2) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index de247253eb18a5..5b5ffd7b2feb06 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -1602,7 +1602,9 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
   // attributes (isCall, isReturn, etc.). We lower them to the real
   // instruction here.
   case AArch64::TCRETURNri:
-  case AArch64::TCRETURNriBTI:
+  case AArch64::TCRETURNrix16x17:
+  case AArch64::TCRETURNrix17:
+  case AArch64::TCRETURNrinotx16:
   case AArch64::TCRETURNriALL: {
     MCInst TmpInst;
     TmpInst.setOpcode(AArch64::BR);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8573939b04389f..20290c958a70e9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -25700,7 +25700,9 @@ AArch64TargetLowering::EmitKCFICheck(MachineBasicBlock &MBB,
   case AArch64::BLR:
   case AArch64::BLRNoIP:
   case AArch64::TCRETURNri:
-  case AArch64::TCRETURNriBTI:
+  case AArch64::TCRETURNrix16x17:
+  case AArch64::TCRETURNrix17:
+  case AArch64::TCRETURNrinotx16:
     break;
   default:
     llvm_unreachable("Unexpected CFI call opcode");
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 9add7d87017a73..39c96092f10319 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -2503,7 +2503,9 @@ bool AArch64InstrInfo::isTailCallReturnInst(const MachineInstr &MI) {
     return false;
   case AArch64::TCRETURNdi:
   case AArch64::TCRETURNri:
-  case AArch64::TCRETURNriBTI:
+  case AArch64::TCRETURNrix16x17:
+  case AArch64::TCRETURNrix17:
+  case AArch64::TCRETURNrinotx16:
   case AArch64::TCRETURNriALL:
     return true;
   }
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 77fdb688d0422e..0eff7cbb516094 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -928,8 +928,29 @@ let RecomputePerFunction = 1 in {
   // Avoid generating STRQro if it is slow, unless we're optimizing for code size.
   def UseSTRQro : Predicate<"!Subtarget->isSTRQroSlow() || shouldOptForSize(MF)">;
 
-  def UseBTI : Predicate<[{ MF->getInfo<AArch64FunctionInfo>()->branchTargetEnforcement() }]>;
-  def NotUseBTI : Predicate<[{ !MF->getInfo<AArch64FunctionInfo>()->branchTargetEnforcement() }]>;
+  // Register restrictions for indirect tail-calls:
+  // - If branch target enforcement is enabled, indirect calls must use x16 or
+  //   x17, because these are the only registers which can target the BTI C
+  //   instruction.
+  // - If PAuthLR is enabled, x16 is used in the epilogue to hold the address
+  //   of the signing instruction. This can't be changed because it is used by a
+  //   HINT instruction which only accepts x16. We can't load anything from the
+  //   stack after this because the authentication instruction checks that SP is
+  //   the same as it was at function entry, so we can't have anything on the
+  //   stack.
+
+  // BTI on, PAuthLR off: x16 or x17
+  def TailCallX16X17 : Predicate<[{  MF->getInfo<AArch64FunctionInfo>()->branchTargetEnforcement() &&
+                                    !MF->getInfo<AArch64FunctionInfo>()->branchProtectionPAuthLR() }]>;
+  // BTI on, PAuthLR on: x17 only
+  def TailCallX17 : Predicate<[{ MF->getInfo<AArch64FunctionInfo>()->branchTargetEnforcement() &&
+                                 MF->getInfo<AArch64FunctionInfo>()->branchProtectionPAuthLR() }]>;
+  // BTI off, PAuthLR on: Any non-callee-saved register except x16
+  def TailCallNotX16 : Predicate<[{ !MF->getInfo<AArch64FunctionInfo>()->branchTargetEnforcement() &&
+                                    MF->getInfo<AArch64FunctionInfo>()->branchProtectionPAuthLR() }]>;
+  // BTI off, PAuthLR off: Any non-callee-saved register
+  def TailCallAny : Predicate<[{ !MF->getInfo<AArch64FunctionInfo>()->branchTargetEnforcement() &&
+                                 !MF->getInfo<AArch64FunctionInfo>()->branchProtectionPAuthLR() }]>;
 
   def SLSBLRMitigation : Predicate<[{ MF->getSubtarget<AArch64Subtarget>().hardenSlsBlr() }]>;
   def NoSLSBLRMitigation : Predicate<[{ !MF->getSubtarget<AArch64Subtarget>().hardenSlsBlr() }]>;
@@ -9121,18 +9142,30 @@ let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Uses = [SP] in {
   // some verifier checks for outlined functions.
   def TCRETURNriALL : Pseudo<(outs), (ins GPR64:$dst, i32imm:$FPDiff), []>,
                       Sched<[WriteBrReg]>;
-  // Indirect tail-call limited to only use registers (x16 and x17) which are
-  // allowed to tail-call a "BTI c" instruction.
-  def TCRETURNriBTI : Pseudo<(outs), (ins rtcGPR64:$dst, i32imm:$FPDiff), []>,
+
+  // Indirect tail-calls with reduced register classes, needed for BTI and
+  // PAuthLR.
+  def TCRETURNrix16x17 : Pseudo<(outs), (ins tcGPRx16x17:$dst, i32imm:$FPDiff), []>,
+                      Sched<[WriteBrReg]>;
+  def TCRETURNrix17 : Pseudo<(outs), (ins tcGPRx17:$dst, i32imm:$FPDiff), []>,
+                      Sched<[WriteBrReg]>;
+  def TCRETURNrinotx16 : Pseudo<(outs), (ins tcGPRnotx16:$dst, i32imm:$FPDiff), []>,
                       Sched<[WriteBrReg]>;
 }
 
 def : Pat<(AArch64tcret tcGPR64:$dst, (i32 timm:$FPDiff)),
           (TCRETURNri tcGPR64:$dst, imm:$FPDiff)>,
-      Requires<[NotUseBTI]>;
-def : Pat<(AArch64tcret rtcGPR64:$dst, (i32 timm:$FPDiff)),
-          (TCRETURNriBTI rtcGPR64:$dst, imm:$FPDiff)>,
-      Requires<[UseBTI]>;
+      Requires<[TailCallAny]>;
+def : Pat<(AArch64tcret tcGPRx16x17:$dst, (i32 timm:$FPDiff)),
+          (TCRETURNrix16x17 tcGPRx16x17:$dst, imm:$FPDiff)>,
+      Requires<[TailCallX16X17]>;
+def : Pat<(AArch64tcret tcGPRx17:$dst, (i32 timm:$FPDiff)),
+          (TCRETURNrix17 tcGPRx17:$dst, imm:$FPDiff)>,
+      Requires<[TailCallX17]>;
+def : Pat<(AArch64tcret tcGPRnotx16:$dst, (i32 timm:$FPDiff)),
+          (TCRETURNrinotx16 tcGPRnotx16:$dst, imm:$FPDiff)>,
+      Requires<[TailCallNotX16]>;
+
 def : Pat<(AArch64tcret tglobaladdr:$dst, (i32 timm:$FPDiff)),
           (TCRETURNdi texternalsym:$dst, imm:$FPDiff)>;
 def : Pat<(AArch64tcret texternalsym:$dst, (i32 timm:$FPDiff)),
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
index b70ab856888478..569944e0e660b7 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
@@ -217,11 +217,16 @@ def tcGPR64 : RegisterClass<"AArch64", [i64], 64, (sub GPR64common, X19, X20, X2
                                                      X22, X23, X24, X25, X26,
                                                      X27, X28, FP, LR)>;
 
-// Restricted set of tail call registers, for use when branch target
-// enforcement is enabled. These are the only registers which can be used to
-// indirectly branch (not call) to the "BTI c" instruction at the start of a
-// BTI-protected function.
-def rtcGPR64 : RegisterClass<"AArch64", [i64], 64, (add X16, X17)>;
+// Restricted sets of tail call registers, for use when branch target
+// enforcement or PAuthLR are enabled.
+// For BTI, x16 and x17 are the only registers which can be used to indirectly
+// branch (not call) to the "BTI c" instruction at the start of a BTI-protected
+// function.
+// For PAuthLR, x16 must be used in the function epilogue for other purposes,
+// so cannot hold the function pointer.
+def tcGPRx17 : RegisterClass<"AArch64", [i64], 64, (add X17)>;
+def tcGPRx16x17 : RegisterClass<"AArch64", [i64], 64, (add X16, X17)>;
+def tcGPRnotx16 : RegisterClass<"AArch64", [i64], 64, (sub tcGPR64, X16)>;
 
 // Register set that excludes registers that are reserved for procedure calls.
 // This is used for pseudo-instructions that are actually implemented using a
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
index 55cad848393aed..491040ee0b7639 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
@@ -1012,16 +1012,23 @@ bool AArch64CallLowering::isEligibleForTailCallOptimization(
 
 static unsigned getCallOpcode(const MachineFunction &CallerF, bool IsIndirect,
                               bool IsTailCall) {
+  const AArch64FunctionInfo *FuncInfo = CallerF.getInfo<AArch64FunctionInfo>();
+
   if (!IsTailCall)
     return IsIndirect ? getBLRCallOpcode(CallerF) : (unsigned)AArch64::BL;
 
   if (!IsIndirect)
     return AArch64::TCRETURNdi;
 
-  // When BTI is enabled, we need to use TCRETURNriBTI to make sure that we use
-  // x16 or x17.
-  if (CallerF.getInfo<AArch64FunctionInfo>()->branchTargetEnforcement())
-    return AArch64::TCRETURNriBTI;
+  // When BTI or PAuthLR are enabled, there are restrictions on using x16 and
+  // x17 to hold the function pointer.
+  if (FuncInfo->branchTargetEnforcement()) {
+    if (FuncInfo->branchProtectionPAuthLR())
+      return AArch64::TCRETURNrix17;
+    else
+      return AArch64::TCRETURNrix16x17;
+  } else if (FuncInfo->branchProtectionPAuthLR())
+      return AArch64::TCRETURNrinotx16;
 
   return AArch64::TCRETURNri;
 }
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
index b8e5e7bbdaba77..0fc4d7f1991061 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
@@ -273,7 +273,9 @@ AArch64RegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
   case AArch64::GPR64common_and_GPR64noipRegClassID:
   case AArch64::GPR64noip_and_tcGPR64RegClassID:
   case AArch64::tcGPR64RegClassID:
-  case AArch64::rtcGPR64RegClassID:
+  case AArch64::tcGPRx16x17RegClassID:
+  case AArch64::tcGPRx17RegClassID:
+  case AArch64::tcGPRnotx16RegClassID:
   case AArch64::WSeqPairsClassRegClassID:
   case AArch64::XSeqPairsClassRegClassID:
   case AArch64::MatrixIndexGPR32_8_11RegClassID:
diff --git a/llvm/test/CodeGen/AArch64/branch-target-enforcement-indirect-calls.ll b/llvm/test/CodeGen/AArch64/branch-target-enforcement-indirect-calls.ll
index de543f4e4d9424..833a6d5b1d1da0 100644
--- a/llvm/test/CodeGen/AArch64/branch-target-enforcement-indirect-calls.ll
+++ b/llvm/test/CodeGen/AArch64/branch-target-enforcement-indirect-calls.ll
@@ -26,3 +26,68 @@ entry:
 ; CHECK: br {{x16|x17}}
   ret void
 }
+define void @bti_enabled_force_x10(ptr %p) "branch-target-enforcement"="true" {
+entry:
+  %p_x10 = tail call ptr asm "", "={x10},{x10},~{lr}"(ptr %p)
+  tail call void %p_x10()
+; CHECK: br {{x16|x17}}
+  ret void
+}
+
+; sign-return-address places no further restrictions on the tail-call register.
+
+define void @bti_enabled_pac_enabled(ptr %p) "branch-target-enforcement"="true" "sign-return-address"="all" {
+entry:
+  tail call void %p()
+; CHECK: br {{x16|x17}}
+  ret void
+}
+define void @bti_enabled_pac_enabled_force_x10(ptr %p) "branch-target-enforcement"="true" "sign-return-address"="all" {
+entry:
+  %p_x10 = tail call ptr asm "", "={x10},{x10},~{lr}"(ptr %p)
+  tail call void %p_x10()
+; CHECK: br {{x16|x17}}
+  ret void
+}
+
+; PAuthLR needs to use x16 to hold the address of the signing instruction. That
+; can't be changed because the hint instruction only uses that register, so the
+; only choice for the tail-call function pointer is x17.
+
+define void @bti_enabled_pac_pc_enabled(ptr %p) "branch-target-enforcement"="true" "sign-return-address"="all" "branch-protection-pauth-lr"="true" {
+entry:
+  tail call void %p()
+; CHECK: br x17
+  ret void
+}
+define void @bti_enabled_pac_pc_enabled_force_x16(ptr %p) "branch-target-enforcement"="true" "sign-return-address"="all" "branch-protection-pauth-lr"="true" {
+entry:
+  %p_x16 = tail call ptr asm "", "={x16},{x16},~{lr}"(ptr %p)
+  tail call void %p_x16()
+; CHECK: br x17
+  ret void
+}
+
+; PAuthLR by itself prevents x16 from being used, but any other
+; non-callee-saved register can be used.
+
+define void @pac_pc_enabled(ptr %p) "sign-return-address"="all" "branch-protection-pauth-lr"="true" {
+entry:
+  tail call void %p()
+; CHECK: br {{(x[0-9]|x1[0-578])$}}
+  ret void
+}
+define void @pac_pc_enabled_force_x16(ptr %p) "sign-return-address"="all" "branch-protection-pauth-lr"="true" {
+entry:
+  %p_x16 = tail call ptr asm "", "={x16},{x16},~{lr}"(ptr %p)
+  tail call void %p_x16()
+; CHECK: br {{(x[0-9]|x1[0-578])$}}
+  ret void
+}
+define void @pac_pc_enabled_force_x17(ptr %p) "sign-return-address"="all" "branch-protection-pauth-lr"="true" {
+entry:
+  %p_x17 = tail call ptr asm "", "={x17},{x17},~{lr}"(ptr %p)
+  tail call void %p_x17()
+; CHECK: br x17
+  ret void
+}
diff --git a/llvm/test/CodeGen/AArch64/kcfi-bti.ll b/llvm/test/CodeGen/AArch64/kcfi-bti.ll
index 12cde4371e15b1..d3febb536824e3 100644
--- a/llvm/test/CodeGen/AArch64/kcfi-bti.ll
+++ b/llvm/test/CodeGen/AArch64/kcfi-bti.ll
@@ -73,11 +73,11 @@ define void @f3(ptr noundef %x) {
 ; MIR-LABEL: name: f3
 ; MIR: body:
 
-; ISEL: TCRETURNriBTI %1, 0, csr_aarch64_aapcs, implicit $sp, cfi-type 12345678
+; ISEL: TCRETURNrix16x17 %1, 0, csr_aarch64_aapcs, implicit $sp, cfi-type 12345678
 
 ; KCFI:       BUNDLE{{.*}} {
 ; KCFI-NEXT:    KCFI_CHECK $x16, 12345678, implicit-def $x9, implicit-def $x16, implicit-def $x17, implicit-def $nzcv
-; KCFI-NEXT:    TCRETURNriBTI internal killed $x16, 0, csr_aarch64_aapcs, implicit $sp
+; KCFI-NEXT:    TCRETURNrix16x17 internal killed $x16, 0, csr_aarch64_aapcs, implicit $sp
 ; KCFI-NEXT:  }
 
   tail call void %x() [ "kcfi"(i32 12345678) ]

``````````

</details>


https://github.com/llvm/llvm-project/pull/81020


More information about the llvm-commits mailing list