[llvm] [AArch64][CodeGen] Optimize security cookie check with New Fixup Pass (PR #121938)
Omair Javaid via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 7 06:19:29 PST 2025
https://github.com/omjavaid created https://github.com/llvm/llvm-project/pull/121938
This patch adds the AArch64WinFixupBufferSecurityCheckPass to optimize the handling of buffer security checks on AArch64 Windows targets.
The pass selectively replaces __security_check_cookie calls with inline comparisons, transferring control to the runtime library only on failure.
A similar implementation for the X86 Windows target was added in #95904.
>From 87592ba64c03f96611a52084c0194e33711f9020 Mon Sep 17 00:00:00 2001
From: Muhammad Omair Javaid <omair.javaid at linaro.org>
Date: Mon, 6 Jan 2025 16:31:18 +0500
Subject: [PATCH] Optimize security cookie check with New Fixup Pass
This patch adds the AArch64WinFixupBufferSecurityCheckPass to optimize
the handling of buffer security checks on AArch64 Windows targets.
The pass selectively replaces __security_check_cookie calls with inline
comparisons, transferring control to the runtime library only on failure.
A similar implementation for the X86 Windows target was added in #95904.
---
llvm/lib/Target/AArch64/AArch64.h | 2 +
.../Target/AArch64/AArch64TargetMachine.cpp | 1 +
.../AArch64WinFixupBufferSecurityCheck.cpp | 275 ++++++++++++++++++
llvm/lib/Target/AArch64/CMakeLists.txt | 1 +
.../irtranslator-stack-protector-windows.ll | 9 +-
llvm/test/CodeGen/AArch64/O0-pipeline.ll | 1 +
llvm/test/CodeGen/AArch64/O3-pipeline.ll | 4 +-
7 files changed, 291 insertions(+), 2 deletions(-)
create mode 100644 llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
diff --git a/llvm/lib/Target/AArch64/AArch64.h b/llvm/lib/Target/AArch64/AArch64.h
index ffa578d412b3c2..33e40ae9f89190 100644
--- a/llvm/lib/Target/AArch64/AArch64.h
+++ b/llvm/lib/Target/AArch64/AArch64.h
@@ -72,6 +72,7 @@ FunctionPass *createAArch64PostLegalizerLowering();
FunctionPass *createAArch64PostSelectOptimize();
FunctionPass *createAArch64StackTaggingPass(bool IsOptNone);
FunctionPass *createAArch64StackTaggingPreRAPass();
+FunctionPass *createAArch64WinFixupBufferSecurityCheckPass();
ModulePass *createAArch64Arm64ECCallLoweringPass();
void initializeAArch64A53Fix835769Pass(PassRegistry&);
@@ -105,6 +106,7 @@ void initializeAArch64SpeculationHardeningPass(PassRegistry &);
void initializeAArch64StackTaggingPass(PassRegistry &);
void initializeAArch64StackTaggingPreRAPass(PassRegistry &);
void initializeAArch64StorePairSuppressPass(PassRegistry&);
+void initializeAArch64WinFixupBufferSecurityCheckPassPass(PassRegistry &);
void initializeFalkorHWPFFixPass(PassRegistry&);
void initializeFalkorMarkStridedAccessesLegacyPass(PassRegistry&);
void initializeLDTLSCleanupPass(PassRegistry&);
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
index 07f072446081a3..30d2be9c2a832b 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -812,6 +812,7 @@ void AArch64PassConfig::addPreRegAlloc() {
}
if (TM->getOptLevel() != CodeGenOptLevel::None && EnableMachinePipeliner)
addPass(&MachinePipelinerID);
+ addPass(createAArch64WinFixupBufferSecurityCheckPass());
}
void AArch64PassConfig::addPostRegAlloc() {
diff --git a/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp b/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
new file mode 100644
index 00000000000000..f8c7c332aaf7b8
--- /dev/null
+++ b/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
@@ -0,0 +1,275 @@
+//===- AArch64WinFixupBufferSecurityCheck.cpp - Fix buffer security check -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// On Windows, the buffer security check inserts a call to
+// __security_check_cookie before every function return. That routine lives in
+// the runtime library and usually just compares the stack cookie with the
+// global cookie and returns, so every return pays for a call into a DLL.
+// This pass performs the comparison inline and only calls
+// __security_check_cookie when the comparison fails.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/LivePhysRegs.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/IR/Module.h"
+
+#include "AArch64.h"
+#include "AArch64InstrInfo.h"
+#include "AArch64Subtarget.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "aarch64-win-fixup-bscheck"
+
+namespace {
+
+/// Replaces the out-of-line __security_check_cookie call emitted for the
+/// Windows buffer security check with an inline comparison of the stack slot
+/// cookie against a freshly loaded global cookie, so that the runtime routine
+/// is only called when the two values differ.
+class AArch64WinFixupBufferSecurityCheckPass final
+    : public MachineFunctionPass {
+public:
+  static char ID;
+
+  AArch64WinFixupBufferSecurityCheckPass() : MachineFunctionPass(ID) {}
+
+  StringRef getPassName() const override {
+    return "AArch64 Windows Fixup Buffer Security Check";
+  }
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+
+private:
+  /// Returns the block and instruction of the first BL to
+  /// __security_check_cookie in \p MF, or {nullptr, nullptr} if none exists.
+  std::pair<MachineBasicBlock *, MachineInstr *>
+  getSecurityCheckerBasicBlock(MachineFunction &MF);
+
+  /// Clones the nearest LOAD_STACK_GUARD preceding \p CheckCall in \p CurMBB
+  /// and gives the clone a fresh virtual destination register. The clone is
+  /// not inserted into any block. Returns nullptr if none is found.
+  MachineInstr *cloneLoadStackGuard(MachineBasicBlock *CurMBB,
+                                    MachineInstr *CheckCall);
+
+  /// Fills \p SeqMI with the five instructions around \p CheckCall that form
+  /// the call sequence: [0] expected ADJCALLSTACKDOWN, [1] the load of the
+  /// cookie from its stack slot, [2] the argument COPY, [3] the call itself
+  /// and [4] the stack adjustment after the call.
+  void getGuardCheckSequence(MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
+                             MachineInstr *SeqMI[5]);
+
+  /// Moves every instruction from \p SplitIt to the end of \p CurMBB onto the
+  /// end of \p NewRetMBB.
+  void SplitBasicBlock(MachineBasicBlock *CurMBB, MachineBasicBlock *NewRetMBB,
+                       MachineBasicBlock::iterator SplitIt);
+
+  /// Recomputes the physical-register live-ins of \p MBB.
+  void FinishBlock(MachineBasicBlock *MBB);
+
+  /// Renumbers the function's blocks and recomputes live-ins for the two
+  /// blocks created by the rewrite.
+  void FinishFunction(MachineBasicBlock *FailMBB, MachineBasicBlock *NewRetMBB);
+};
+} // end anonymous namespace
+
+char AArch64WinFixupBufferSecurityCheckPass::ID{0};
+
+INITIALIZE_PASS(AArch64WinFixupBufferSecurityCheckPass, DEBUG_TYPE, DEBUG_TYPE,
+                /*CFGOnly=*/false, /*IsAnalysis=*/false)
+
+/// Factory used by AArch64PassConfig::addPreRegAlloc to add this pass.
+/// NOTE(review): initializeAArch64WinFixupBufferSecurityCheckPassPass() is
+/// declared in AArch64.h, but this patch does not call it from the target
+/// initializer; -run-pass on this pass likely needs that registration.
+FunctionPass *llvm::createAArch64WinFixupBufferSecurityCheckPass() {
+  auto *Pass = new AArch64WinFixupBufferSecurityCheckPass();
+  return Pass;
+}
+
+/// Moves the tail of \p CurMBB, from \p SplitIt to its end, onto the end of
+/// \p NewRetMBB. NOTE(review): this helper is not called anywhere in this file.
+void AArch64WinFixupBufferSecurityCheckPass::SplitBasicBlock(
+    MachineBasicBlock *CurMBB, MachineBasicBlock *NewRetMBB,
+    MachineBasicBlock::iterator SplitIt) {
+  MachineBasicBlock::iterator TailEnd = CurMBB->end();
+  MachineBasicBlock::iterator Dest = NewRetMBB->end();
+  NewRetMBB->splice(Dest, CurMBB, SplitIt, TailEnd);
+}
+
+/// Scans \p MF in layout order for the first direct call (BL with a single
+/// explicit operand) to __security_check_cookie. Returns the containing block
+/// and the call instruction, or {nullptr, nullptr} if there is no such call.
+std::pair<MachineBasicBlock *, MachineInstr *>
+AArch64WinFixupBufferSecurityCheckPass::getSecurityCheckerBasicBlock(
+    MachineFunction &MF) {
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      if (MI.getOpcode() != AArch64::BL || MI.getNumExplicitOperands() != 1)
+        continue;
+      // Bind by reference; there is no need to copy the operand.
+      const MachineOperand &MO = MI.getOperand(0);
+      if (!MO.isGlobal())
+        continue;
+      const auto *Callee = dyn_cast<Function>(MO.getGlobal());
+      if (Callee && Callee->getName() == "__security_check_cookie")
+        return {&MBB, &MI};
+    }
+  }
+  return {nullptr, nullptr};
+}
+
+/// Searches \p CurMBB backwards from \p CheckCall for the nearest
+/// LOAD_STACK_GUARD and returns a clone of it whose destination is a new
+/// virtual register of the same register class. The clone is created in the
+/// function but is NOT inserted into any block; the caller must insert it.
+/// Returns nullptr if either argument is null, if no LOAD_STACK_GUARD precedes
+/// the call in this block, or if its destination is not a virtual register.
+MachineInstr *AArch64WinFixupBufferSecurityCheckPass::cloneLoadStackGuard(
+    MachineBasicBlock *CurMBB, MachineInstr *CheckCall) {
+  if (!CurMBB || !CheckCall)
+    return nullptr;
+
+  MachineFunction &MF = *CurMBB->getParent();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+
+  // A reverse_iterator built from CheckCall points at CheckCall itself, so
+  // step once to start at the instruction just before the call.
+  auto It = std::next(MachineBasicBlock::reverse_iterator(CheckCall));
+  for (auto End = CurMBB->rend(); It != End; ++It) {
+    MachineInstr &MI = *It;
+    if (MI.getOpcode() != TargetOpcode::LOAD_STACK_GUARD)
+      continue;
+
+    // Only a virtual destination has a register class to copy. Check this
+    // before cloning so no orphaned instruction is left in the function.
+    Register OrigReg = MI.getOperand(0).getReg();
+    if (!OrigReg.isVirtual())
+      return nullptr;
+
+    MachineInstr *ClonedInstr = MF.CloneMachineInstr(&MI);
+    Register NewReg = MRI.createVirtualRegister(MRI.getRegClass(OrigReg));
+    ClonedInstr->getOperand(0).setReg(NewReg);
+    return ClonedInstr;
+  }
+
+  // No LOAD_STACK_GUARD before the call in this block.
+  return nullptr;
+}
+
+/// Collects the five instructions around \p CheckCall in \p CurMBB that form
+/// the __security_check_cookie call sequence, purely by position:
+///   SeqMI[0] - third instruction before the call (expected ADJCALLSTACKDOWN)
+///   SeqMI[1] - second before the call (expected load of the cookie from its
+///              stack slot)
+///   SeqMI[2] - instruction right before the call (expected argument COPY)
+///   SeqMI[3] - the call itself
+///   SeqMI[4] - instruction right after the call (expected stack adjustment)
+/// Opcodes are not checked here, so callers must validate the result. The call
+/// must have at least three predecessors and one successor in \p CurMBB.
+void AArch64WinFixupBufferSecurityCheckPass::getGuardCheckSequence(
+    MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
+    MachineInstr *SeqMI[5]) {
+  MachineBasicBlock::iterator UIt(CheckCall);
+  MachineBasicBlock::reverse_iterator DIt(CheckCall);
+
+  // Stack adjustment following the call.
+  ++UIt;
+  assert(UIt != CurMBB->end() &&
+         "expected an instruction after the security check call");
+  SeqMI[4] = &*UIt;
+
+  SeqMI[3] = CheckCall;
+
+  // Walk backwards over the argument COPY, the cookie load and the call
+  // frame setup, filling SeqMI[2], SeqMI[1], SeqMI[0] in that order.
+  for (int Idx = 2; Idx >= 0; --Idx) {
+    ++DIt;
+    assert(DIt != CurMBB->rend() &&
+           "expected three instructions before the security check call");
+    SeqMI[Idx] = &*DIt;
+  }
+}
+
+/// Recomputes and adds the physical-register live-ins of \p MBB.
+void AArch64WinFixupBufferSecurityCheckPass::FinishBlock(
+    MachineBasicBlock *MBB) {
+  LivePhysRegs PhysRegs;
+  computeAndAddLiveIns(PhysRegs, *MBB);
+}
+
+/// Final bookkeeping after the rewrite: renumbers the blocks of the function,
+/// then recomputes live-ins for the failure block and the new return block.
+void AArch64WinFixupBufferSecurityCheckPass::FinishFunction(
+    MachineBasicBlock *FailMBB, MachineBasicBlock *NewRetMBB) {
+  MachineFunction &MF = *FailMBB->getParent();
+  MF.RenumberBlocks();
+  // FailMBB calls __security_check_cookie in the runtime library, passing the
+  // cookie value loaded from the stack slot (which may have been modified).
+  // Recompute live-ins so registers read on entry are marked live.
+  for (MachineBasicBlock *MBB : {FailMBB, NewRetMBB})
+    FinishBlock(MBB);
+}
+
+/// Rewrites the first __security_check_cookie call in \p MF into an inline
+/// check. Starting from
+///   CurMBB:  ...; ADJCALLSTACKDOWN; %c = <load cookie from stack slot>;
+///            $x0 = COPY %c; BL __security_check_cookie; ADJCALLSTACKUP; tail
+/// it produces
+///   CurMBB:    ...; %g = LOAD_STACK_GUARD; %c = <load>; SUBS %c, %g;
+///              B.NE FailMBB; B NewRetMBB
+///   NewRetMBB: tail                         (placed right after CurMBB)
+///   FailMBB:   ADJCALLSTACKDOWN; COPY; BL __security_check_cookie;
+///              ADJCALLSTACKUP; BRK #0       (placed at the end of MF)
+/// Returns false and leaves \p MF untouched in these cases:
+///   - the target is not MSVC Windows;
+///   - __security_cookie is not declared;
+///   - there is no check call;
+///   - the instructions around the call do not have the shape shown above;
+///   - no LOAD_STACK_GUARD precedes the call in its block.
+bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
+    MachineFunction &MF) {
+  const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>();
+  if (!STI.getTargetTriple().isWindowsMSVCEnvironment())
+    return false;
+
+  // Check if the security cookie was installed.
+  Module &M = *MF.getFunction().getParent();
+  if (!M.getGlobalVariable("__security_cookie"))
+    return false;
+
+  const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+
+  // Check if the security check cookie call was installed.
+  auto [CurMBB, CheckCall] = getSecurityCheckerBasicBlock(MF);
+  if (!CheckCall)
+    return false;
+
+  // getGuardCheckSequence reads three instructions before and one after the
+  // call by position; make sure they exist before asking for them.
+  MachineBasicBlock::iterator CallIt(CheckCall);
+  if (std::next(CallIt) == CurMBB->end() ||
+      std::distance(CurMBB->begin(), CallIt) < 3)
+    return false;
+
+  MachineInstr *SeqMI[5];
+  getGuardCheckSequence(CurMBB, CheckCall, SeqMI);
+  MachineInstr *CallSeqStart = SeqMI[0];
+  MachineInstr *CookieLoad = SeqMI[1];
+  MachineInstr *ArgCopy = SeqMI[2];
+  MachineInstr *CallSeqEnd = SeqMI[4];
+
+  // The rewrite reads the cookie register from CookieLoad and moves
+  // CallSeqStart..CallSeqEnd into FailMBB, so it is only sound for exactly
+  // the expected shape. For any other order, keep the original call.
+  if (CallSeqStart->getOpcode() != TII->getCallFrameSetupOpcode() ||
+      CallSeqEnd->getOpcode() != TII->getCallFrameDestroyOpcode() ||
+      !CookieLoad->mayLoad() || CookieLoad->getNumOperands() == 0 ||
+      !CookieLoad->getOperand(0).isReg() ||
+      !CookieLoad->getOperand(0).isDef() || !ArgCopy->isCopy() ||
+      ArgCopy->getOperand(1).getReg() != CookieLoad->getOperand(0).getReg())
+    return false;
+
+  // Build a new LOAD_STACK_GUARD (with a fresh destination) so the global
+  // cookie can be reloaded right before the comparison.
+  MachineInstr *ClonedInstr = cloneLoadStackGuard(CurMBB, CheckCall);
+  if (!ClonedInstr)
+    return false;
+
+  const DebugLoc &DL = CheckCall->getDebugLoc();
+  Register CookieLoadReg = CookieLoad->getOperand(0).getReg();
+  Register GlobalCookieReg = ClonedInstr->getOperand(0).getReg();
+
+  // Place the reloaded global cookie, then the stack-cookie load, in front of
+  // the call sequence.
+  MachineBasicBlock::iterator InsertPt(CallSeqStart);
+  CurMBB->insert(InsertPt, ClonedInstr);
+  CurMBB->splice(InsertPt, CurMBB, MachineBasicBlock::iterator(CookieLoad));
+
+  // Compare stack cookie with global cookie; only the flags are used.
+  Register DiscardReg =
+      MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+  BuildMI(*CurMBB, InsertPt, DL, TII->get(AArch64::SUBSXrr))
+      .addReg(DiscardReg, RegState::Define | RegState::Dead)
+      .addReg(CookieLoadReg)
+      .addReg(GlobalCookieReg);
+
+  // FailMBB goes to the end of the function. NewRetMBB goes right after
+  // CurMBB so that, if CurMBB used to fall through to its layout successor,
+  // NewRetMBB (which receives CurMBB's tail) still does.
+  MachineBasicBlock *FailMBB = MF.CreateMachineBasicBlock();
+  MF.insert(MF.end(), FailMBB);
+  MachineBasicBlock *NewRetMBB = MF.CreateMachineBasicBlock();
+  MF.insert(std::next(CurMBB->getIterator()), NewRetMBB);
+
+  // Branch to FailMBB if the cookies differ, otherwise to NewRetMBB.
+  BuildMI(*CurMBB, InsertPt, DL, TII->get(AArch64::Bcc))
+      .addImm(AArch64CC::NE)
+      .addMBB(FailMBB);
+  BuildMI(*CurMBB, InsertPt, DL, TII->get(AArch64::B)).addMBB(NewRetMBB);
+
+  // Move the original call sequence into FailMBB and end it with a BRK so
+  // control never falls out of that block.
+  MachineBasicBlock::iterator SeqEnd(CallSeqEnd);
+  ++SeqEnd;
+  FailMBB->splice(FailMBB->end(), CurMBB, InsertPt, SeqEnd);
+  BuildMI(*FailMBB, FailMBB->end(), DL, TII->get(AArch64::BRK)).addImm(0);
+
+  // The rest of CurMBB, including its terminators, moves to NewRetMBB, and so
+  // must CurMBB's original CFG successors (PHIs are updated accordingly).
+  NewRetMBB->splice(NewRetMBB->end(), CurMBB, SeqEnd, CurMBB->end());
+  NewRetMBB->transferSuccessorsAndUpdatePHIs(CurMBB);
+  CurMBB->addSuccessor(NewRetMBB);
+  CurMBB->addSuccessor(FailMBB);
+
+  FinishFunction(FailMBB, NewRetMBB);
+  return true;
+}
diff --git a/llvm/lib/Target/AArch64/CMakeLists.txt b/llvm/lib/Target/AArch64/CMakeLists.txt
index 2300e479bc1106..7ee9c3174f38c5 100644
--- a/llvm/lib/Target/AArch64/CMakeLists.txt
+++ b/llvm/lib/Target/AArch64/CMakeLists.txt
@@ -85,6 +85,7 @@ add_llvm_target(AArch64CodeGen
AArch64TargetMachine.cpp
AArch64TargetObjectFile.cpp
AArch64TargetTransformInfo.cpp
+ AArch64WinFixupBufferSecurityCheck.cpp
SMEABIPass.cpp
SMEPeepholeOpt.cpp
SVEIntrinsicOpts.cpp
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-stack-protector-windows.ll b/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-stack-protector-windows.ll
index 6aefc5341da072..1e2ef173638ec5 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-stack-protector-windows.ll
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-stack-protector-windows.ll
@@ -17,8 +17,12 @@ define void @caller() sspreq {
; CHECK-NEXT: ldr x8, [x8, :lo12:__security_cookie]
; CHECK-NEXT: str x8, [sp, #8]
; CHECK-NEXT: bl callee
+; CHECK-NEXT: adrp x8, __security_cookie
; CHECK-NEXT: ldr x0, [sp, #8]
-; CHECK-NEXT: bl __security_check_cookie
+; CHECK-NEXT: ldr x8, [x8, :lo12:__security_cookie]
+; CHECK-NEXT: subs x8, x0, x8
+; CHECK-NEXT: b.ne .LBB0_2
+; CHECK: // %bb.1:
; CHECK-NEXT: .seh_startepilogue
; CHECK-NEXT: ldr x30, [sp, #16] // 8-byte Folded Reload
; CHECK-NEXT: .seh_save_reg x30, 16
@@ -26,6 +30,9 @@ define void @caller() sspreq {
; CHECK-NEXT: .seh_stackalloc 32
; CHECK-NEXT: .seh_endepilogue
; CHECK-NEXT: ret
+; CHECK: .LBB0_2:
+; CHECK-NEXT: bl __security_check_cookie
+; CHECK-NEXT: brk #0
; CHECK-NEXT: .seh_endfunclet
; CHECK-NEXT: .seh_endproc
entry:
diff --git a/llvm/test/CodeGen/AArch64/O0-pipeline.ll b/llvm/test/CodeGen/AArch64/O0-pipeline.ll
index 0d079881cb909c..f2e2845d6a95b9 100644
--- a/llvm/test/CodeGen/AArch64/O0-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O0-pipeline.ll
@@ -54,6 +54,7 @@
; CHECK-NEXT: AArch64 Instruction Selection
; CHECK-NEXT: Finalize ISel and expand pseudo-instructions
; CHECK-NEXT: Local Stack Slot Allocation
+; CHECK-NEXT: AArch64 Windows Fixup Buffer Security Check
; CHECK-NEXT: Eliminate PHI nodes for register allocation
; CHECK-NEXT: Two-Address instruction pass
; CHECK-NEXT: Fast Register Allocator
diff --git a/llvm/test/CodeGen/AArch64/O3-pipeline.ll b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
index b5d5e27afa17ad..82559bfb0b172f 100644
--- a/llvm/test/CodeGen/AArch64/O3-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
@@ -159,14 +159,16 @@
; CHECK-NEXT: Remove dead machine instructions
; CHECK-NEXT: AArch64 MI Peephole Optimization pass
; CHECK-NEXT: AArch64 Dead register definitions
+; CHECK-NEXT: AArch64 Windows Fixup Buffer Security Check
; CHECK-NEXT: Detect Dead Lanes
; CHECK-NEXT: Init Undef Pass
; CHECK-NEXT: Process Implicit Definitions
; CHECK-NEXT: Remove unreachable machine basic blocks
; CHECK-NEXT: Live Variable Analysis
+; CHECK-NEXT: MachineDominator Tree Construction
+; CHECK-NEXT: Machine Natural Loop Construction
; CHECK-NEXT: Eliminate PHI nodes for register allocation
; CHECK-NEXT: Two-Address instruction pass
-; CHECK-NEXT: MachineDominator Tree Construction
; CHECK-NEXT: Slot index numbering
; CHECK-NEXT: Live Interval Analysis
; CHECK-NEXT: Register Coalescer
More information about the llvm-commits
mailing list