[llvm] Do not use R12 for indirect tail calls with PACBTI (PR #82661)
Eleanor Bonnici via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 30 02:50:03 PDT 2024
https://github.com/eleanor-arm updated https://github.com/llvm/llvm-project/pull/82661
>From 5e77391493eef0606f85525e12102a19e1d3e198 Mon Sep 17 00:00:00 2001
From: Eleanor Bonnici <eleanor.bonnici at arm.com>
Date: Tue, 30 Apr 2024 10:46:38 +0100
Subject: [PATCH] WIP
---
llvm/lib/Target/ARM/ARMBaseInstrInfo.h | 1 +
llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp | 6 ++-
llvm/lib/Target/ARM/ARMFrameLowering.cpp | 8 +++-
llvm/lib/Target/ARM/ARMInstrInfo.td | 11 +++++-
llvm/lib/Target/ARM/ARMPredicates.td | 4 ++
llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp | 9 +----
llvm/lib/Target/ARM/ARMRegisterInfo.td | 10 +++++
llvm/lib/Target/ARM/Thumb1FrameLowering.cpp | 6 +--
...cbti-indirect-tail-calls-function-flags.ll | 37 +++++++++++++++++++
...acbti-indirect-tail-calls-module-flags1.ll | 17 +++++++++
...acbti-indirect-tail-calls-module-flags2.ll | 18 +++++++++
11 files changed, 112 insertions(+), 15 deletions(-)
create mode 100644 llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-function-flags.ll
create mode 100644 llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags1.ll
create mode 100644 llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags2.ll
diff --git a/llvm/lib/Target/ARM/ARMBaseInstrInfo.h b/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
index 30f0730774b78c..a3c2684ac1fb97 100644
--- a/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
+++ b/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
@@ -683,6 +683,7 @@ static inline bool isIndirectCall(const MachineInstr &MI) {
case ARM::BX_CALL:
case ARM::BMOVPCRX_CALL:
case ARM::TCRETURNri:
+ case ARM::TCRETURNrinotr12:
case ARM::TAILJMPr:
case ARM::TAILJMPr4:
case ARM::tBLXr:
diff --git a/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp b/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp
index 0f7858a3be9f9f..df10613fcc7c93 100644
--- a/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp
@@ -2197,7 +2197,8 @@ bool ARMExpandPseudo::ExpandMI(MachineBasicBlock &MBB,
}
case ARM::TCRETURNdi:
- case ARM::TCRETURNri: {
+ case ARM::TCRETURNri:
+ case ARM::TCRETURNrinotr12: {
MachineBasicBlock::iterator MBBI = MBB.getLastNonDebugInstr();
if (MBBI->getOpcode() == ARM::SEH_EpilogEnd)
MBBI--;
@@ -2241,7 +2242,8 @@ bool ARMExpandPseudo::ExpandMI(MachineBasicBlock &MBB,
// Add the default predicate in Thumb mode.
if (STI->isThumb())
MIB.add(predOps(ARMCC::AL));
- } else if (RetOpcode == ARM::TCRETURNri) {
+ } else if (RetOpcode == ARM::TCRETURNri ||
+ RetOpcode == ARM::TCRETURNrinotr12) {
unsigned Opcode =
STI->isThumb() ? ARM::tTAILJMPr
: (STI->hasV4TOps() ? ARM::TAILJMPr : ARM::TAILJMPr4);
diff --git a/llvm/lib/Target/ARM/ARMFrameLowering.cpp b/llvm/lib/Target/ARM/ARMFrameLowering.cpp
index 9b54dd4e4e618d..425a5535b4e750 100644
--- a/llvm/lib/Target/ARM/ARMFrameLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMFrameLowering.cpp
@@ -257,7 +257,8 @@ static int getArgumentStackToRestore(MachineFunction &MF,
if (MBB.end() != MBBI) {
unsigned RetOpcode = MBBI->getOpcode();
IsTailCallReturn = RetOpcode == ARM::TCRETURNdi ||
- RetOpcode == ARM::TCRETURNri;
+ RetOpcode == ARM::TCRETURNri ||
+ RetOpcode == ARM::TCRETURNrinotr12;
}
ARMFunctionInfo *AFI = MF.getInfo<ARMFunctionInfo>();
@@ -486,6 +487,7 @@ static MachineBasicBlock::iterator insertSEH(MachineBasicBlock::iterator MBBI,
case ARM::tBX_RET:
case ARM::TCRETURNri:
+ case ARM::TCRETURNrinotr12:
MIB = BuildMI(MF, DL, TII.get(ARM::SEH_Nop_Ret))
.addImm(/*Wide=*/0)
.setMIFlags(Flags);
@@ -1615,7 +1617,9 @@ void ARMFrameLowering::emitPopInst(MachineBasicBlock &MBB,
if (MBB.end() != MI) {
DL = MI->getDebugLoc();
unsigned RetOpcode = MI->getOpcode();
- isTailCall = (RetOpcode == ARM::TCRETURNdi || RetOpcode == ARM::TCRETURNri);
+ isTailCall =
+ (RetOpcode == ARM::TCRETURNdi || RetOpcode == ARM::TCRETURNri ||
+ RetOpcode == ARM::TCRETURNrinotr12);
isInterrupt =
RetOpcode == ARM::SUBS_PC_LR || RetOpcode == ARM::t2SUBS_PC_LR;
isTrap =
diff --git a/llvm/lib/Target/ARM/ARMInstrInfo.td b/llvm/lib/Target/ARM/ARMInstrInfo.td
index 08b519e4d5cbf5..1f7bd8dd3121d8 100644
--- a/llvm/lib/Target/ARM/ARMInstrInfo.td
+++ b/llvm/lib/Target/ARM/ARMInstrInfo.td
@@ -2677,6 +2677,9 @@ let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Uses = [SP] in {
def TCRETURNri : PseudoInst<(outs), (ins tcGPR:$dst, i32imm:$SPDiff), IIC_Br, []>,
Sched<[WriteBr]>;
+ def TCRETURNrinotr12 : PseudoInst<(outs), (ins tcGPRnotr12:$dst, i32imm:$SPDiff), IIC_Br, []>,
+ Sched<[WriteBr]>;
+
def TAILJMPd : ARMPseudoExpand<(outs), (ins arm_br_target:$dst),
4, IIC_Br, [],
(Bcc arm_br_target:$dst, (ops 14, zero_reg))>,
@@ -6081,8 +6084,14 @@ def : ARMPat<(ARMWrapperJT tjumptable:$dst),
// TODO: add,sub,and, 3-instr forms?
// Tail calls. These patterns also apply to Thumb mode.
+// Regular indirect tail call
def : Pat<(ARMtcret tcGPR:$dst, (i32 timm:$SPDiff)),
- (TCRETURNri tcGPR:$dst, timm:$SPDiff)>;
+ (TCRETURNri tcGPR:$dst, timm:$SPDiff)>,
+ Requires<[NoSignRetAddr]>;
+// Indirect tail call when PACBTI is enabled
+def : Pat<(ARMtcret tcGPRnotr12:$dst, (i32 timm:$SPDiff)),
+ (TCRETURNrinotr12 tcGPRnotr12:$dst, timm:$SPDiff)>,
+ Requires<[SignRetAddr]>;
def : Pat<(ARMtcret (i32 tglobaladdr:$dst), (i32 timm:$SPDiff)),
(TCRETURNdi texternalsym:$dst, (i32 timm:$SPDiff))>;
def : Pat<(ARMtcret (i32 texternalsym:$dst), (i32 timm:$SPDiff)),
diff --git a/llvm/lib/Target/ARM/ARMPredicates.td b/llvm/lib/Target/ARM/ARMPredicates.td
index aca970d900a8dc..ddc5ad8754eee1 100644
--- a/llvm/lib/Target/ARM/ARMPredicates.td
+++ b/llvm/lib/Target/ARM/ARMPredicates.td
@@ -228,6 +228,10 @@ def DontGenExecuteOnly : Predicate<"!Subtarget->genExecuteOnly()">;
def GenT1ExecuteOnly : Predicate<"Subtarget->genExecuteOnly() && "
"Subtarget->isThumb1Only() && "
"!Subtarget->hasV8MBaselineOps()">;
+let RecomputePerFunction = 1 in {
+ def SignRetAddr : Predicate<[{ MF->getInfo<ARMFunctionInfo>()->shouldSignReturnAddress(true) }]>;
+ def NoSignRetAddr : Predicate<[{ !MF->getInfo<ARMFunctionInfo>()->shouldSignReturnAddress(true) }]>;
+}
// Armv8.5-A extensions
def HasSB : Predicate<"Subtarget->hasSB()">,
diff --git a/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp b/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp
index 5d4ae9a7648e69..a6fdece10ba47c 100644
--- a/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp
@@ -156,11 +156,6 @@ ARMRegisterBankInfo::ARMRegisterBankInfo(const TargetRegisterInfo &TRI) {
"Subclass not added?");
assert(RBGPR.covers(*TRI.getRegClass(ARM::tcGPRRegClassID)) &&
"Subclass not added?");
- assert(RBGPR.covers(*TRI.getRegClass(ARM::GPRnoip_and_tcGPRRegClassID)) &&
- "Subclass not added?");
- assert(RBGPR.covers(*TRI.getRegClass(
- ARM::tGPREven_and_GPRnoip_and_tcGPRRegClassID)) &&
- "Subclass not added?");
assert(RBGPR.covers(*TRI.getRegClass(ARM::tGPROdd_and_tcGPRRegClassID)) &&
"Subclass not added?");
assert(getMaximumSize(RBGPR.getID()) == 32 &&
@@ -188,16 +183,16 @@ ARMRegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
case GPRnoip_and_GPRnopcRegClassID:
case rGPRRegClassID:
case GPRspRegClassID:
- case GPRnoip_and_tcGPRRegClassID:
case tcGPRRegClassID:
+ case tcGPRnotr12RegClassID:
case tGPRRegClassID:
case tGPREvenRegClassID:
case tGPROddRegClassID:
case tGPR_and_tGPREvenRegClassID:
case tGPR_and_tGPROddRegClassID:
case tGPREven_and_tcGPRRegClassID:
- case tGPREven_and_GPRnoip_and_tcGPRRegClassID:
case tGPROdd_and_tcGPRRegClassID:
+ case tGPREven_and_tcGPRnotr12RegClassID:
return getRegBank(ARM::GPRRegBankID);
case HPRRegClassID:
case SPR_8RegClassID:
diff --git a/llvm/lib/Target/ARM/ARMRegisterInfo.td b/llvm/lib/Target/ARM/ARMRegisterInfo.td
index 194d65cad8d170..212f22651f9f94 100644
--- a/llvm/lib/Target/ARM/ARMRegisterInfo.td
+++ b/llvm/lib/Target/ARM/ARMRegisterInfo.td
@@ -373,6 +373,16 @@ def tcGPR : RegisterClass<"ARM", [i32], 32, (add R0, R1, R2, R3, R12)> {
}];
}
+// Some pointer authentication instructions require the use of R12. When return
+// address signing is enabled, authentication of the caller's return address
+// must be performed before a tail call is made. Therefore, an indirect tail
+// call cannot branch through R12.
+// FIXME: All PACBTI instructions currently implemented in the compiler
+// implicitly use R12. When instructions that allow the PAC to be placed in a
+// specific register are implemented, the restriction needs to be updated to
+// make sure that the PAC and the indirect tail call use different registers.
+def tcGPRnotr12 : RegisterClass<"ARM", [i32], 32, (add R0, R1, R2, R3)>;
+
def tGPROdd : RegisterClass<"ARM", [i32], 32, (add R1, R3, R5, R7, R9, R11)> {
let AltOrders = [(and tGPROdd, tGPR)];
let AltOrderSelect = [{
diff --git a/llvm/lib/Target/ARM/Thumb1FrameLowering.cpp b/llvm/lib/Target/ARM/Thumb1FrameLowering.cpp
index 0f4ece64bff532..366173cc65a5bc 100644
--- a/llvm/lib/Target/ARM/Thumb1FrameLowering.cpp
+++ b/llvm/lib/Target/ARM/Thumb1FrameLowering.cpp
@@ -1049,9 +1049,9 @@ static void popRegsFromStack(MachineBasicBlock &MBB,
continue;
if (Reg == ARM::LR) {
- if (!MBB.succ_empty() ||
- MI->getOpcode() == ARM::TCRETURNdi ||
- MI->getOpcode() == ARM::TCRETURNri)
+ if (!MBB.succ_empty() || MI->getOpcode() == ARM::TCRETURNdi ||
+ MI->getOpcode() == ARM::TCRETURNri ||
+ MI->getOpcode() == ARM::TCRETURNrinotr12)
// LR may only be popped into PC, as part of return sequence.
// If this isn't the return sequence, we'll need emitPopSpecialFixUp
// to restore LR the hard way.
diff --git a/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-function-flags.ll b/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-function-flags.ll
new file mode 100644
index 00000000000000..7362b63a0ad62f
--- /dev/null
+++ b/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-function-flags.ll
@@ -0,0 +1,37 @@
+; RUN: llc -mtriple=thumbv8.1m.main-none-none-eabi -mattr=+pacbti< %s | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64"
+target triple = "thumbv8.1m.main-m.main-unknown"
+
+; When PACBTI is enabled, indirect tail calls must not use R12, which holds the
+; authentication code. The patterns reject a trailing digit so "r1" cannot match "r12".
+
+define void @pacbti_disabled(ptr %p) "sign-return-address"="none" {
+entry:
+  tail call void %p()
+; CHECK: bx {{r0|r1|r2|r3|r12}}
+  ret void
+}
+
+define void @pacbti_enabled(ptr %p) "sign-return-address"="all" {
+entry:
+  tail call void %p()
+; CHECK: bx {{r[0-3]([^0-9]|$)}}
+  ret void
+}
+
+define void @pacbti_disabled_force_r12(ptr %p) "sign-return-address"="none" {
+entry:
+  %p_r12 = tail call ptr asm "", "={r12},{r12},~{lr}"(ptr %p)
+  tail call void %p_r12()
+; CHECK: bx r12
+  ret void
+}
+
+define void @pacbti_enabled_force_r12(ptr %p) "sign-return-address"="all" {
+entry:
+  %p_r12 = tail call ptr asm "", "={r12},{r12},~{lr}"(ptr %p)
+  tail call void %p_r12()
+; CHECK: bx {{r[0-3]([^0-9]|$)}}
+  ret void
+}
diff --git a/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags1.ll b/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags1.ll
new file mode 100644
index 00000000000000..59499a240fb371
--- /dev/null
+++ b/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags1.ll
@@ -0,0 +1,17 @@
+; RUN: llc -mtriple=thumbv8.1m.main-none-none-eabi -mattr=+pacbti< %s | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64"
+target triple = "thumbv8.1m.main-m.main-unknown"
+
+define dso_local void @sgign_return_address(ptr noundef readonly %fptr_arg) local_unnamed_addr #0 {
+entry:
+  %0 = tail call ptr asm "", "={r12},{r12},~{lr}"(ptr %fptr_arg)
+  tail call void %0()
+; CHECK: bx {{r[0-3]([^0-9]|$)}}
+  ret void
+}
+
+!llvm.module.flags = !{!1}
+
+!1 = !{i32 8, !"sign-return-address", i32 1}
+
diff --git a/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags2.ll b/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags2.ll
new file mode 100644
index 00000000000000..b2ae55c43c3390
--- /dev/null
+++ b/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags2.ll
@@ -0,0 +1,18 @@
+; RUN: llc -mtriple=thumbv8.1m.main-none-none-eabi -mattr=+pacbti< %s | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64"
+target triple = "thumbv8.1m.main-m.main-unknown"
+
+define dso_local void @sgign_return_address_all(ptr noundef readonly %fptr_arg) local_unnamed_addr #0 {
+entry:
+  %0 = tail call ptr asm "", "={r12},{r12},~{lr}"(ptr %fptr_arg)
+  tail call void %0()
+; CHECK: bx {{r[0-3]([^0-9]|$)}}
+  ret void
+}
+
+!llvm.module.flags = !{!1, !2}
+
+!1 = !{i32 8, !"sign-return-address", i32 1}
+!2 = !{i32 8, !"sign-return-address-all", i32 1}
+
More information about the llvm-commits
mailing list