[llvm] c12bc57 - Do not use R12 for indirect tail calls with PACBTI (#82661)

via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 30 07:29:11 PDT 2024


Author: Eleanor Bonnici
Date: 2024-04-30T15:29:07+01:00
New Revision: c12bc57e23f8c37380ac25e774a60a684fce7bd3

URL: https://github.com/llvm/llvm-project/commit/c12bc57e23f8c37380ac25e774a60a684fce7bd3
DIFF: https://github.com/llvm/llvm-project/commit/c12bc57e23f8c37380ac25e774a60a684fce7bd3.diff

LOG: Do not use R12 for indirect tail calls with PACBTI (#82661)

When compiling for thumbv8.1m with +pacbti and making an indirect tail
call, the compiler was free to put the function pointer into R12.

This is incorrect because R12 is restored to contain authentication code
for the caller's return address.

This patch excludes R12 from the set of registers the compiler can put
the function pointer in.

Fixes https://github.com/llvm/llvm-project/issues/75998

Added: 
    llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-function-flags.ll
    llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags1.ll
    llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags2.ll

Modified: 
    llvm/lib/Target/ARM/ARMBaseInstrInfo.h
    llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp
    llvm/lib/Target/ARM/ARMFrameLowering.cpp
    llvm/lib/Target/ARM/ARMInstrInfo.td
    llvm/lib/Target/ARM/ARMPredicates.td
    llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp
    llvm/lib/Target/ARM/ARMRegisterInfo.td
    llvm/lib/Target/ARM/Thumb1FrameLowering.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/ARM/ARMBaseInstrInfo.h b/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
index 30f0730774b78c..a3c2684ac1fb97 100644
--- a/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
+++ b/llvm/lib/Target/ARM/ARMBaseInstrInfo.h
@@ -683,6 +683,7 @@ static inline bool isIndirectCall(const MachineInstr &MI) {
   case ARM::BX_CALL:
   case ARM::BMOVPCRX_CALL:
   case ARM::TCRETURNri:
+  case ARM::TCRETURNrinotr12:
   case ARM::TAILJMPr:
   case ARM::TAILJMPr4:
   case ARM::tBLXr:

diff --git a/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp b/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp
index 0f7858a3be9f9f..df10613fcc7c93 100644
--- a/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/ARM/ARMExpandPseudoInsts.cpp
@@ -2197,7 +2197,8 @@ bool ARMExpandPseudo::ExpandMI(MachineBasicBlock &MBB,
     }
 
     case ARM::TCRETURNdi:
-    case ARM::TCRETURNri: {
+    case ARM::TCRETURNri:
+    case ARM::TCRETURNrinotr12: {
       MachineBasicBlock::iterator MBBI = MBB.getLastNonDebugInstr();
       if (MBBI->getOpcode() == ARM::SEH_EpilogEnd)
         MBBI--;
@@ -2241,7 +2242,8 @@ bool ARMExpandPseudo::ExpandMI(MachineBasicBlock &MBB,
         // Add the default predicate in Thumb mode.
         if (STI->isThumb())
           MIB.add(predOps(ARMCC::AL));
-      } else if (RetOpcode == ARM::TCRETURNri) {
+      } else if (RetOpcode == ARM::TCRETURNri ||
+                 RetOpcode == ARM::TCRETURNrinotr12) {
         unsigned Opcode =
           STI->isThumb() ? ARM::tTAILJMPr
                          : (STI->hasV4TOps() ? ARM::TAILJMPr : ARM::TAILJMPr4);

diff --git a/llvm/lib/Target/ARM/ARMFrameLowering.cpp b/llvm/lib/Target/ARM/ARMFrameLowering.cpp
index a332f743f495b8..11496a6e032dde 100644
--- a/llvm/lib/Target/ARM/ARMFrameLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMFrameLowering.cpp
@@ -257,7 +257,8 @@ static int getArgumentStackToRestore(MachineFunction &MF,
   if (MBB.end() != MBBI) {
     unsigned RetOpcode = MBBI->getOpcode();
     IsTailCallReturn = RetOpcode == ARM::TCRETURNdi ||
-                       RetOpcode == ARM::TCRETURNri;
+                       RetOpcode == ARM::TCRETURNri ||
+                       RetOpcode == ARM::TCRETURNrinotr12;
   }
   ARMFunctionInfo *AFI = MF.getInfo<ARMFunctionInfo>();
 
@@ -486,6 +487,7 @@ static MachineBasicBlock::iterator insertSEH(MachineBasicBlock::iterator MBBI,
 
   case ARM::tBX_RET:
   case ARM::TCRETURNri:
+  case ARM::TCRETURNrinotr12:
     MIB = BuildMI(MF, DL, TII.get(ARM::SEH_Nop_Ret))
               .addImm(/*Wide=*/0)
               .setMIFlags(Flags);
@@ -1615,7 +1617,9 @@ void ARMFrameLowering::emitPopInst(MachineBasicBlock &MBB,
   if (MBB.end() != MI) {
     DL = MI->getDebugLoc();
     unsigned RetOpcode = MI->getOpcode();
-    isTailCall = (RetOpcode == ARM::TCRETURNdi || RetOpcode == ARM::TCRETURNri);
+    isTailCall =
+        (RetOpcode == ARM::TCRETURNdi || RetOpcode == ARM::TCRETURNri ||
+         RetOpcode == ARM::TCRETURNrinotr12);
     isInterrupt =
         RetOpcode == ARM::SUBS_PC_LR || RetOpcode == ARM::t2SUBS_PC_LR;
     isTrap =

diff --git a/llvm/lib/Target/ARM/ARMInstrInfo.td b/llvm/lib/Target/ARM/ARMInstrInfo.td
index 08b519e4d5cbf5..1f7bd8dd3121d8 100644
--- a/llvm/lib/Target/ARM/ARMInstrInfo.td
+++ b/llvm/lib/Target/ARM/ARMInstrInfo.td
@@ -2677,6 +2677,9 @@ let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Uses = [SP] in {
   def TCRETURNri : PseudoInst<(outs), (ins tcGPR:$dst, i32imm:$SPDiff), IIC_Br, []>,
                    Sched<[WriteBr]>;
 
+  def TCRETURNrinotr12 : PseudoInst<(outs), (ins tcGPRnotr12:$dst, i32imm:$SPDiff), IIC_Br, []>,
+                   Sched<[WriteBr]>;
+
   def TAILJMPd : ARMPseudoExpand<(outs), (ins arm_br_target:$dst),
                                  4, IIC_Br, [],
                                  (Bcc arm_br_target:$dst, (ops 14, zero_reg))>,
@@ -6081,8 +6084,14 @@ def : ARMPat<(ARMWrapperJT tjumptable:$dst),
 // TODO: add,sub,and, 3-instr forms?
 
 // Tail calls. These patterns also apply to Thumb mode.
+// Regular indirect tail call
 def : Pat<(ARMtcret tcGPR:$dst, (i32 timm:$SPDiff)),
-          (TCRETURNri tcGPR:$dst, timm:$SPDiff)>;
+          (TCRETURNri tcGPR:$dst, timm:$SPDiff)>,
+          Requires<[NoSignRetAddr]>;
+// Indirect tail call when PACBTI is enabled
+def : Pat<(ARMtcret tcGPRnotr12:$dst, (i32 timm:$SPDiff)),
+          (TCRETURNrinotr12 tcGPRnotr12:$dst, timm:$SPDiff)>,
+          Requires<[SignRetAddr]>;
 def : Pat<(ARMtcret (i32 tglobaladdr:$dst), (i32 timm:$SPDiff)),
           (TCRETURNdi texternalsym:$dst, (i32 timm:$SPDiff))>;
 def : Pat<(ARMtcret (i32 texternalsym:$dst), (i32 timm:$SPDiff)),

diff --git a/llvm/lib/Target/ARM/ARMPredicates.td b/llvm/lib/Target/ARM/ARMPredicates.td
index aca970d900a8dc..ddc5ad8754eee1 100644
--- a/llvm/lib/Target/ARM/ARMPredicates.td
+++ b/llvm/lib/Target/ARM/ARMPredicates.td
@@ -228,6 +228,10 @@ def DontGenExecuteOnly : Predicate<"!Subtarget->genExecuteOnly()">;
 def GenT1ExecuteOnly : Predicate<"Subtarget->genExecuteOnly() && "
                                  "Subtarget->isThumb1Only() && "
                                  "!Subtarget->hasV8MBaselineOps()">;
+let RecomputePerFunction = 1 in {
+  def SignRetAddr : Predicate<[{ MF->getInfo<ARMFunctionInfo>()->shouldSignReturnAddress(true) }]>;
+  def NoSignRetAddr : Predicate<[{ !MF->getInfo<ARMFunctionInfo>()->shouldSignReturnAddress(true) }]>;
+}
 
 // Armv8.5-A extensions
 def HasSB            : Predicate<"Subtarget->hasSB()">,

diff --git a/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp b/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp
index 5d4ae9a7648e69..a6fdece10ba47c 100644
--- a/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp
@@ -156,11 +156,6 @@ ARMRegisterBankInfo::ARMRegisterBankInfo(const TargetRegisterInfo &TRI) {
            "Subclass not added?");
     assert(RBGPR.covers(*TRI.getRegClass(ARM::tcGPRRegClassID)) &&
            "Subclass not added?");
-    assert(RBGPR.covers(*TRI.getRegClass(ARM::GPRnoip_and_tcGPRRegClassID)) &&
-           "Subclass not added?");
-    assert(RBGPR.covers(*TRI.getRegClass(
-               ARM::tGPREven_and_GPRnoip_and_tcGPRRegClassID)) &&
-           "Subclass not added?");
     assert(RBGPR.covers(*TRI.getRegClass(ARM::tGPROdd_and_tcGPRRegClassID)) &&
            "Subclass not added?");
     assert(getMaximumSize(RBGPR.getID()) == 32 &&
@@ -188,16 +183,16 @@ ARMRegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
   case GPRnoip_and_GPRnopcRegClassID:
   case rGPRRegClassID:
   case GPRspRegClassID:
-  case GPRnoip_and_tcGPRRegClassID:
   case tcGPRRegClassID:
+  case tcGPRnotr12RegClassID:
   case tGPRRegClassID:
   case tGPREvenRegClassID:
   case tGPROddRegClassID:
   case tGPR_and_tGPREvenRegClassID:
   case tGPR_and_tGPROddRegClassID:
   case tGPREven_and_tcGPRRegClassID:
-  case tGPREven_and_GPRnoip_and_tcGPRRegClassID:
   case tGPROdd_and_tcGPRRegClassID:
+  case tGPREven_and_tcGPRnotr12RegClassID:
     return getRegBank(ARM::GPRRegBankID);
   case HPRRegClassID:
   case SPR_8RegClassID:

diff --git a/llvm/lib/Target/ARM/ARMRegisterInfo.td b/llvm/lib/Target/ARM/ARMRegisterInfo.td
index 194d65cad8d170..212f22651f9f94 100644
--- a/llvm/lib/Target/ARM/ARMRegisterInfo.td
+++ b/llvm/lib/Target/ARM/ARMRegisterInfo.td
@@ -373,6 +373,16 @@ def tcGPR : RegisterClass<"ARM", [i32], 32, (add R0, R1, R2, R3, R12)> {
   }];
 }
 
+// Some pointer authentication instructions require the use of R12. When return
+// address signing is enabled, authentication of the caller's return address
+// must be performed before a tail call is made. Therefore, indirect tail call
+// jump cannot be from R12.
+// FIXME: All PACBTI instruction currently implemented in the compiler
+// implicitly use R12. When instructions that allow PAC to be placed in a
+// specific register are implemented the restriction needs to be updated to
+// make sure that PACBTI signature and indirect tail call both use a different register.
+def tcGPRnotr12 : RegisterClass<"ARM", [i32], 32, (add R0, R1, R2, R3)>;
+
 def tGPROdd : RegisterClass<"ARM", [i32], 32, (add R1, R3, R5, R7, R9, R11)> {
   let AltOrders = [(and tGPROdd, tGPR)];
   let AltOrderSelect = [{

diff --git a/llvm/lib/Target/ARM/Thumb1FrameLowering.cpp b/llvm/lib/Target/ARM/Thumb1FrameLowering.cpp
index 047c6731333c9b..e908f1fb951247 100644
--- a/llvm/lib/Target/ARM/Thumb1FrameLowering.cpp
+++ b/llvm/lib/Target/ARM/Thumb1FrameLowering.cpp
@@ -1044,9 +1044,9 @@ static void popRegsFromStack(MachineBasicBlock &MBB,
         continue;
 
       if (Reg == ARM::LR) {
-        if (!MBB.succ_empty() ||
-            MI->getOpcode() == ARM::TCRETURNdi ||
-            MI->getOpcode() == ARM::TCRETURNri)
+        if (!MBB.succ_empty() || MI->getOpcode() == ARM::TCRETURNdi ||
+            MI->getOpcode() == ARM::TCRETURNri ||
+            MI->getOpcode() == ARM::TCRETURNrinotr12)
           // LR may only be popped into PC, as part of return sequence.
           // If this isn't the return sequence, we'll need emitPopSpecialFixUp
           // to restore LR the hard way.

diff --git a/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-function-flags.ll b/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-function-flags.ll
new file mode 100644
index 00000000000000..7362b63a0ad62f
--- /dev/null
+++ b/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-function-flags.ll
@@ -0,0 +1,37 @@
+; RUN: llc -mtriple=thumbv8.1m.main-none-none-eabi -mattr=+pacbti< %s | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64"
+target triple = "thumbv8.1m.main-m.main-unknown"
+
+; When PACBTI is enabled, indirect tail-calls must not use R12 that is used
+; to store authentication code.
+
+define void @pacbti_disabled(ptr %p) "sign-return-address"="none" {
+entry:
+  tail call void %p()
+; CHECK: bx {{r0|r1|r2|r3|r12}}
+  ret void
+}
+
+define void @pacbti_enabled(ptr %p) "sign-return-address"="all" {
+entry:
+  tail call void %p()
+; CHECK: bx {{r0|r1|r2|r3}}
+  ret void
+}
+
+define void @pacbti_disabled_force_r12(ptr %p) "sign-return-address"="none" {
+entry:
+  %p_r12 = tail call ptr asm "", "={r12},{r12},~{lr}"(ptr %p)
+  tail call void %p_r12()
+; CHECK: bx r12
+  ret void
+}
+
+define void @pacbti_enabled_force_r12(ptr %p) "sign-return-address"="all" {
+entry:
+  %p_r12 = tail call ptr asm "", "={r12},{r12},~{lr}"(ptr %p)
+  tail call void %p_r12()
+; CHECK: bx {{r0|r1|r2|r3}}
+  ret void
+}

diff --git a/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags1.ll b/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags1.ll
new file mode 100644
index 00000000000000..59499a240fb371
--- /dev/null
+++ b/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags1.ll
@@ -0,0 +1,17 @@
+; RUN: llc -mtriple=thumbv8.1m.main-none-none-eabi -mattr=+pacbti< %s | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64"
+target triple = "thumbv8.1m.main-m.main-unknown"
+
+define dso_local void @sgign_return_address(ptr noundef readonly %fptr_arg) local_unnamed_addr #0 {
+entry:
+  %0 = tail call ptr asm "", "={r12},{r12},~{lr}"(ptr %fptr_arg)
+  tail call void %0()
+; CHECK: bx {{r0|r1|r2|r3}}
+  ret void
+}
+
+!llvm.module.flags = !{!1}
+
+!1 = !{i32 8, !"sign-return-address", i32 1}
+

diff --git a/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags2.ll b/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags2.ll
new file mode 100644
index 00000000000000..b2ae55c43c3390
--- /dev/null
+++ b/llvm/test/CodeGen/ARM/pacbti-indirect-tail-calls-module-flags2.ll
@@ -0,0 +1,18 @@
+; RUN: llc -mtriple=thumbv8.1m.main-none-none-eabi -mattr=+pacbti< %s | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64"
+target triple = "thumbv8.1m.main-m.main-unknown"
+
+define dso_local void @sgign_return_address_all(ptr noundef readonly %fptr_arg) local_unnamed_addr #0 {
+entry:
+  %0 = tail call ptr asm "", "={r12},{r12},~{lr}"(ptr %fptr_arg)
+  tail call void %0()
+; CHECK: bx {{r0|r1|r2|r3}}
+  ret void
+}
+
+!llvm.module.flags = !{!1}
+
+!1 = !{i32 8, !"sign-return-address", i32 1}
+!2 = !{i32 8, !"sign-return-address-all", i32 1}
+


        


More information about the llvm-commits mailing list