[llvm] Do not use R12 for indirect tail calls with PACBTI (PR #82661)

Eleanor Bonnici via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 30 02:50:03 PDT 2024


https://github.com/eleanor-arm updated https://github.com/llvm/llvm-project/pull/82661

>From 5e77391493eef0606f85525e12102a19e1d3e198 Mon Sep 17 00:00:00 2001
From: Eleanor Bonnici <eleanor.bonnici at arm.com>
Date: Tue, 30 Apr 2024 10:46:38 +0100
Subject: [PATCH] WIP

---
 llvm/lib/Target/ARM/ARMBaseInstrInfo.h        |  1 +
 llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp  |  6 ++-
 llvm/lib/Target/ARM/ARMFrameLowering.cpp      |  8 +++-
 llvm/lib/Target/ARM/ARMInstrInfo.td           | 11 +++++-
 llvm/lib/Target/ARM/ARMPredicates.td          |  4 ++
 llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp   |  9 +----
 llvm/lib/Target/ARM/ARMRegisterInfo.td        | 10 +++++
 llvm/lib/Target/ARM/Thumb1FrameLowering.cpp   |  6 +--
 ...cbti-indirect-tail-calls-function-flags.ll | 37 +++++++++++++++++++
 ...acbti-indirect-tail-calls-module-flags1.ll | 17 +++++++++
 ...acbti-indirect-tail-calls-module-flags2.ll | 18 +++++++++
 11 files changed, 112 insertions(+), 15 deletions(-)
 create mode 100644 llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-function-flags.ll
 create mode 100644 llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags1.ll
 create mode 100644 llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags2.ll

diff --git a/llvm/lib/Target/ARM/ARMBaseInstrInfo.h b/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
index 30f0730774b78c..a3c2684ac1fb97 100644
--- a/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
+++ b/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
@@ -683,6 +683,7 @@ static inline bool isIndirectCall(const MachineInstr &MI) {
   case ARM::BX_CALL:
   case ARM::BMOVPCRX_CALL:
   case ARM::TCRETURNri:
+  case ARM::TCRETURNrinotr12:
   case ARM::TAILJMPr:
   case ARM::TAILJMPr4:
   case ARM::tBLXr:
diff --git a/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp b/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp
index 0f7858a3be9f9f..df10613fcc7c93 100644
--- a/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp
@@ -2197,7 +2197,8 @@ bool ARMExpandPseudo::ExpandMI(MachineBasicBlock &MBB,
     }
 
     case ARM::TCRETURNdi:
-    case ARM::TCRETURNri: {
+    case ARM::TCRETURNri:
+    case ARM::TCRETURNrinotr12: {
       MachineBasicBlock::iterator MBBI = MBB.getLastNonDebugInstr();
       if (MBBI->getOpcode() == ARM::SEH_EpilogEnd)
         MBBI--;
@@ -2241,7 +2242,8 @@ bool ARMExpandPseudo::ExpandMI(MachineBasicBlock &MBB,
         // Add the default predicate in Thumb mode.
         if (STI->isThumb())
           MIB.add(predOps(ARMCC::AL));
-      } else if (RetOpcode == ARM::TCRETURNri) {
+      } else if (RetOpcode == ARM::TCRETURNri ||
+                 RetOpcode == ARM::TCRETURNrinotr12) {
         unsigned Opcode =
           STI->isThumb() ? ARM::tTAILJMPr
                          : (STI->hasV4TOps() ? ARM::TAILJMPr : ARM::TAILJMPr4);
diff --git a/llvm/lib/Target/ARM/ARMFrameLowering.cpp b/llvm/lib/Target/ARM/ARMFrameLowering.cpp
index 9b54dd4e4e618d..425a5535b4e750 100644
--- a/llvm/lib/Target/ARM/ARMFrameLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMFrameLowering.cpp
@@ -257,7 +257,8 @@ static int getArgumentStackToRestore(MachineFunction &MF,
   if (MBB.end() != MBBI) {
     unsigned RetOpcode = MBBI->getOpcode();
     IsTailCallReturn = RetOpcode == ARM::TCRETURNdi ||
-                       RetOpcode == ARM::TCRETURNri;
+                       RetOpcode == ARM::TCRETURNri ||
+                       RetOpcode == ARM::TCRETURNrinotr12;
   }
   ARMFunctionInfo *AFI = MF.getInfo<ARMFunctionInfo>();
 
@@ -486,6 +487,7 @@ static MachineBasicBlock::iterator insertSEH(MachineBasicBlock::iterator MBBI,
 
   case ARM::tBX_RET:
   case ARM::TCRETURNri:
+  case ARM::TCRETURNrinotr12:
     MIB = BuildMI(MF, DL, TII.get(ARM::SEH_Nop_Ret))
               .addImm(/*Wide=*/0)
               .setMIFlags(Flags);
@@ -1615,7 +1617,9 @@ void ARMFrameLowering::emitPopInst(MachineBasicBlock &MBB,
   if (MBB.end() != MI) {
     DL = MI->getDebugLoc();
     unsigned RetOpcode = MI->getOpcode();
-    isTailCall = (RetOpcode == ARM::TCRETURNdi || RetOpcode == ARM::TCRETURNri);
+    isTailCall =
+        (RetOpcode == ARM::TCRETURNdi || RetOpcode == ARM::TCRETURNri ||
+         RetOpcode == ARM::TCRETURNrinotr12);
     isInterrupt =
         RetOpcode == ARM::SUBS_PC_LR || RetOpcode == ARM::t2SUBS_PC_LR;
     isTrap =
diff --git a/llvm/lib/Target/ARM/ARMInstrInfo.td b/llvm/lib/Target/ARM/ARMInstrInfo.td
index 08b519e4d5cbf5..1f7bd8dd3121d8 100644
--- a/llvm/lib/Target/ARM/ARMInstrInfo.td
+++ b/llvm/lib/Target/ARM/ARMInstrInfo.td
@@ -2677,6 +2677,9 @@ let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Uses = [SP] in {
   def TCRETURNri : PseudoInst<(outs), (ins tcGPR:$dst, i32imm:$SPDiff), IIC_Br, []>,
                    Sched<[WriteBr]>;
 
+  def TCRETURNrinotr12 : PseudoInst<(outs), (ins tcGPRnotr12:$dst, i32imm:$SPDiff), IIC_Br, []>,
+                   Sched<[WriteBr]>;
+
   def TAILJMPd : ARMPseudoExpand<(outs), (ins arm_br_target:$dst),
                                  4, IIC_Br, [],
                                  (Bcc arm_br_target:$dst, (ops 14, zero_reg))>,
@@ -6081,8 +6084,14 @@ def : ARMPat<(ARMWrapperJT tjumptable:$dst),
 // TODO: add,sub,and, 3-instr forms?
 
 // Tail calls. These patterns also apply to Thumb mode.
+// Regular indirect tail call
 def : Pat<(ARMtcret tcGPR:$dst, (i32 timm:$SPDiff)),
-          (TCRETURNri tcGPR:$dst, timm:$SPDiff)>;
+          (TCRETURNri tcGPR:$dst, timm:$SPDiff)>,
+          Requires<[NoSignRetAddr]>;
+// Indirect tail call when PACBTI is enabled
+def : Pat<(ARMtcret tcGPRnotr12:$dst, (i32 timm:$SPDiff)),
+          (TCRETURNrinotr12 tcGPRnotr12:$dst, timm:$SPDiff)>,
+          Requires<[SignRetAddr]>;
 def : Pat<(ARMtcret (i32 tglobaladdr:$dst), (i32 timm:$SPDiff)),
           (TCRETURNdi texternalsym:$dst, (i32 timm:$SPDiff))>;
 def : Pat<(ARMtcret (i32 texternalsym:$dst), (i32 timm:$SPDiff)),
diff --git a/llvm/lib/Target/ARM/ARMPredicates.td b/llvm/lib/Target/ARM/ARMPredicates.td
index aca970d900a8dc..ddc5ad8754eee1 100644
--- a/llvm/lib/Target/ARM/ARMPredicates.td
+++ b/llvm/lib/Target/ARM/ARMPredicates.td
@@ -228,6 +228,10 @@ def DontGenExecuteOnly : Predicate<"!Subtarget->genExecuteOnly()">;
 def GenT1ExecuteOnly : Predicate<"Subtarget->genExecuteOnly() && "
                                  "Subtarget->isThumb1Only() && "
                                  "!Subtarget->hasV8MBaselineOps()">;
+let RecomputePerFunction = 1 in {
+  def SignRetAddr : Predicate<[{ MF->getInfo<ARMFunctionInfo>()->shouldSignReturnAddress(true) }]>;
+  def NoSignRetAddr : Predicate<[{ !MF->getInfo<ARMFunctionInfo>()->shouldSignReturnAddress(true) }]>;
+}
 
 // Armv8.5-A extensions
 def HasSB            : Predicate<"Subtarget->hasSB()">,
diff --git a/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp b/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp
index 5d4ae9a7648e69..a6fdece10ba47c 100644
--- a/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp
@@ -156,11 +156,6 @@ ARMRegisterBankInfo::ARMRegisterBankInfo(const TargetRegisterInfo &TRI) {
            "Subclass not added?");
     assert(RBGPR.covers(*TRI.getRegClass(ARM::tcGPRRegClassID)) &&
            "Subclass not added?");
-    assert(RBGPR.covers(*TRI.getRegClass(ARM::GPRnoip_and_tcGPRRegClassID)) &&
-           "Subclass not added?");
-    assert(RBGPR.covers(*TRI.getRegClass(
-               ARM::tGPREven_and_GPRnoip_and_tcGPRRegClassID)) &&
-           "Subclass not added?");
     assert(RBGPR.covers(*TRI.getRegClass(ARM::tGPROdd_and_tcGPRRegClassID)) &&
            "Subclass not added?");
     assert(getMaximumSize(RBGPR.getID()) == 32 &&
@@ -188,16 +183,16 @@ ARMRegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
   case GPRnoip_and_GPRnopcRegClassID:
   case rGPRRegClassID:
   case GPRspRegClassID:
-  case GPRnoip_and_tcGPRRegClassID:
   case tcGPRRegClassID:
+  case tcGPRnotr12RegClassID:
   case tGPRRegClassID:
   case tGPREvenRegClassID:
   case tGPROddRegClassID:
   case tGPR_and_tGPREvenRegClassID:
   case tGPR_and_tGPROddRegClassID:
   case tGPREven_and_tcGPRRegClassID:
-  case tGPREven_and_GPRnoip_and_tcGPRRegClassID:
   case tGPROdd_and_tcGPRRegClassID:
+  case tGPREven_and_tcGPRnotr12RegClassID:
     return getRegBank(ARM::GPRRegBankID);
   case HPRRegClassID:
   case SPR_8RegClassID:
diff --git a/llvm/lib/Target/ARM/ARMRegisterInfo.td b/llvm/lib/Target/ARM/ARMRegisterInfo.td
index 194d65cad8d170..212f22651f9f94 100644
--- a/llvm/lib/Target/ARM/ARMRegisterInfo.td
+++ b/llvm/lib/Target/ARM/ARMRegisterInfo.td
@@ -373,6 +373,16 @@ def tcGPR : RegisterClass<"ARM", [i32], 32, (add R0, R1, R2, R3, R12)> {
   }];
 }
 
+// Some pointer authentication instructions require the use of R12. When return
+// address signing is enabled, authentication of the caller's return address
+// must be performed before a tail call is made. Therefore, an indirect tail
+// call cannot branch through R12.
+// FIXME: All PACBTI instructions currently implemented in the compiler
+// implicitly use R12. When instructions that allow the PAC to be placed in a
+// specific register are implemented, this restriction needs to be updated so
+// that the PAC and the indirect tail-call target use different registers.
+def tcGPRnotr12 : RegisterClass<"ARM", [i32], 32, (add R0, R1, R2, R3)>;
+
 def tGPROdd : RegisterClass<"ARM", [i32], 32, (add R1, R3, R5, R7, R9, R11)> {
   let AltOrders = [(and tGPROdd, tGPR)];
   let AltOrderSelect = [{
diff --git a/llvm/lib/Target/ARM/Thumb1FrameLowering.cpp b/llvm/lib/Target/ARM/Thumb1FrameLowering.cpp
index 0f4ece64bff532..366173cc65a5bc 100644
--- a/llvm/lib/Target/ARM/Thumb1FrameLowering.cpp
+++ b/llvm/lib/Target/ARM/Thumb1FrameLowering.cpp
@@ -1049,9 +1049,9 @@ static void popRegsFromStack(MachineBasicBlock &MBB,
         continue;
 
       if (Reg == ARM::LR) {
-        if (!MBB.succ_empty() ||
-            MI->getOpcode() == ARM::TCRETURNdi ||
-            MI->getOpcode() == ARM::TCRETURNri)
+        if (!MBB.succ_empty() || MI->getOpcode() == ARM::TCRETURNdi ||
+            MI->getOpcode() == ARM::TCRETURNri ||
+            MI->getOpcode() == ARM::TCRETURNrinotr12)
           // LR may only be popped into PC, as part of return sequence.
           // If this isn't the return sequence, we'll need emitPopSpecialFixUp
           // to restore LR the hard way.
diff --git a/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-function-flags.ll b/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-function-flags.ll
new file mode 100644
index 00000000000000..7362b63a0ad62f
--- /dev/null
+++ b/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-function-flags.ll
@@ -0,0 +1,37 @@
+; RUN: llc -mtriple=thumbv8.1m.main-none-none-eabi -mattr=+pacbti < %s | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64"
+target triple = "thumbv8.1m.main-m.main-unknown"
+
+; When PACBTI is enabled, indirect tail calls must not branch through R12,
+; which holds the pointer authentication code.
+; CHECK-LABEL: pacbti_disabled:
+define void @pacbti_disabled(ptr %p) "sign-return-address"="none" {
+entry:
+  tail call void %p()
+; CHECK: bx {{r0|r1|r2|r3|r12}}
+  ret void
+}
+; CHECK-LABEL: pacbti_enabled:
+define void @pacbti_enabled(ptr %p) "sign-return-address"="all" {
+entry:
+  tail call void %p()
+; CHECK: bx {{r0|r1|r2|r3}}
+  ret void
+}
+; CHECK-LABEL: pacbti_disabled_force_r12:
+define void @pacbti_disabled_force_r12(ptr %p) "sign-return-address"="none" {
+entry:
+  %p_r12 = tail call ptr asm "", "={r12},{r12},~{lr}"(ptr %p)
+  tail call void %p_r12()
+; CHECK: bx r12
+  ret void
+}
+; CHECK-LABEL: pacbti_enabled_force_r12:
+define void @pacbti_enabled_force_r12(ptr %p) "sign-return-address"="all" {
+entry:
+  %p_r12 = tail call ptr asm "", "={r12},{r12},~{lr}"(ptr %p)
+  tail call void %p_r12()
+; CHECK: bx {{r0|r1|r2|r3}}
+  ret void
+}
diff --git a/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags1.ll b/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags1.ll
new file mode 100644
index 00000000000000..59499a240fb371
--- /dev/null
+++ b/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags1.ll
@@ -0,0 +1,17 @@
+; RUN: llc -mtriple=thumbv8.1m.main-none-none-eabi -mattr=+pacbti < %s | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64"
+target triple = "thumbv8.1m.main-m.main-unknown"
+
+define dso_local void @sign_return_address(ptr noundef readonly %fptr_arg) local_unnamed_addr {
+entry:
+  %0 = tail call ptr asm "", "={r12},{r12},~{lr}"(ptr %fptr_arg)
+  tail call void %0()
+; CHECK: bx {{r0|r1|r2|r3}}
+  ret void
+}
+
+!llvm.module.flags = !{!1}
+
+!1 = !{i32 8, !"sign-return-address", i32 1}
+
diff --git a/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags2.ll b/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags2.ll
new file mode 100644
index 00000000000000..b2ae55c43c3390
--- /dev/null
+++ b/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags2.ll
@@ -0,0 +1,18 @@
+; RUN: llc -mtriple=thumbv8.1m.main-none-none-eabi -mattr=+pacbti < %s | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64"
+target triple = "thumbv8.1m.main-m.main-unknown"
+
+define dso_local void @sign_return_address_all(ptr noundef readonly %fptr_arg) local_unnamed_addr {
+entry:
+  %0 = tail call ptr asm "", "={r12},{r12},~{lr}"(ptr %fptr_arg)
+  tail call void %0()
+; CHECK: bx {{r0|r1|r2|r3}}
+  ret void
+}
+
+!llvm.module.flags = !{!1, !2}
+
+!1 = !{i32 8, !"sign-return-address", i32 1}
+!2 = !{i32 8, !"sign-return-address-all", i32 1}
+



More information about the llvm-commits mailing list