[llvm] [AArch64][CodeGen] Optimize security cookie check with New Fixup Pass (PR #121938)
Omair Javaid via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 16 00:30:03 PST 2025
https://github.com/omjavaid updated https://github.com/llvm/llvm-project/pull/121938
>From 87592ba64c03f96611a52084c0194e33711f9020 Mon Sep 17 00:00:00 2001
From: Muhammad Omair Javaid <omair.javaid at linaro.org>
Date: Mon, 6 Jan 2025 16:31:18 +0500
Subject: [PATCH 1/3] Optimize security cookie check with New Fixup Pass
This patch adds the AArch64WinFixupBufferSecurityCheckPass to optimize
the handling of buffer security checks on AArch64 Windows targets.
The pass guards the __security_check_cookie call with an inline comparison
against the global cookie, so control transfers to the runtime library
only when the comparison fails. A similar pass for the X86 Windows target
was added in #95904.
---
llvm/lib/Target/AArch64/AArch64.h | 2 +
.../Target/AArch64/AArch64TargetMachine.cpp | 1 +
.../AArch64WinFixupBufferSecurityCheck.cpp | 275 ++++++++++++++++++
llvm/lib/Target/AArch64/CMakeLists.txt | 1 +
.../irtranslator-stack-protector-windows.ll | 9 +-
llvm/test/CodeGen/AArch64/O0-pipeline.ll | 1 +
llvm/test/CodeGen/AArch64/O3-pipeline.ll | 4 +-
7 files changed, 291 insertions(+), 2 deletions(-)
create mode 100644 llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
diff --git a/llvm/lib/Target/AArch64/AArch64.h b/llvm/lib/Target/AArch64/AArch64.h
index ffa578d412b3c2..33e40ae9f89190 100644
--- a/llvm/lib/Target/AArch64/AArch64.h
+++ b/llvm/lib/Target/AArch64/AArch64.h
@@ -72,6 +72,7 @@ FunctionPass *createAArch64PostLegalizerLowering();
FunctionPass *createAArch64PostSelectOptimize();
FunctionPass *createAArch64StackTaggingPass(bool IsOptNone);
FunctionPass *createAArch64StackTaggingPreRAPass();
+FunctionPass *createAArch64WinFixupBufferSecurityCheckPass();
ModulePass *createAArch64Arm64ECCallLoweringPass();
void initializeAArch64A53Fix835769Pass(PassRegistry&);
@@ -105,6 +106,7 @@ void initializeAArch64SpeculationHardeningPass(PassRegistry &);
void initializeAArch64StackTaggingPass(PassRegistry &);
void initializeAArch64StackTaggingPreRAPass(PassRegistry &);
void initializeAArch64StorePairSuppressPass(PassRegistry&);
+void initializeAArch64WinFixupBufferSecurityCheckPassPass(PassRegistry &);
void initializeFalkorHWPFFixPass(PassRegistry&);
void initializeFalkorMarkStridedAccessesLegacyPass(PassRegistry&);
void initializeLDTLSCleanupPass(PassRegistry&);
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
index 07f072446081a3..30d2be9c2a832b 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -812,6 +812,7 @@ void AArch64PassConfig::addPreRegAlloc() {
}
if (TM->getOptLevel() != CodeGenOptLevel::None && EnableMachinePipeliner)
addPass(&MachinePipelinerID);
+ addPass(createAArch64WinFixupBufferSecurityCheckPass());
}
void AArch64PassConfig::addPostRegAlloc() {
diff --git a/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp b/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
new file mode 100644
index 00000000000000..f8c7c332aaf7b8
--- /dev/null
+++ b/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
@@ -0,0 +1,275 @@
+//===- AArch64WinFixupBufferSecurityCheck.cpp Fix Buffer Security Check Call
+//-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// The buffer security check inserts a Windows-specific runtime call into the
+// code: without this fixup, __security_check_cookie is called every time the
+// function returns. Since that function lives in the runtime library, each
+// check pays for a DLL call that usually just compares and returns. With the
+// fixup, the cookie is compared inline and the DLL is called only on mismatch.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/LivePhysRegs.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/IR/Module.h"
+
+#include "AArch64.h"
+#include "AArch64InstrInfo.h"
+#include "AArch64Subtarget.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "aarch64-win-fixup-bscheck"
+
+namespace {
+
+class AArch64WinFixupBufferSecurityCheckPass : public MachineFunctionPass {
+public:
+ static char ID;
+
+ AArch64WinFixupBufferSecurityCheckPass() : MachineFunctionPass(ID) {}
+
+ StringRef getPassName() const override {
+ return "AArch64 Windows Fixup Buffer Security Check";
+ }
+
+ bool runOnMachineFunction(MachineFunction &MF) override;
+
+ std::pair<MachineBasicBlock *, MachineInstr *>
+ getSecurityCheckerBasicBlock(MachineFunction &MF);
+
+ MachineInstr *cloneLoadStackGuard(MachineBasicBlock *CurMBB,
+ MachineInstr *CheckCall);
+
+ void getGuardCheckSequence(MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
+ MachineInstr *SeqMI[5]);
+
+ void SplitBasicBlock(MachineBasicBlock *CurMBB, MachineBasicBlock *NewRetMBB,
+ MachineBasicBlock::iterator SplitIt);
+
+ void FinishBlock(MachineBasicBlock *MBB);
+
+ void FinishFunction(MachineBasicBlock *FailMBB, MachineBasicBlock *NewRetMBB);
+};
+} // end anonymous namespace
+
+char AArch64WinFixupBufferSecurityCheckPass::ID = 0;
+
+INITIALIZE_PASS(AArch64WinFixupBufferSecurityCheckPass, DEBUG_TYPE, DEBUG_TYPE,
+ false, false)
+
+FunctionPass *llvm::createAArch64WinFixupBufferSecurityCheckPass() {
+ return new AArch64WinFixupBufferSecurityCheckPass();
+}
+
+void AArch64WinFixupBufferSecurityCheckPass::SplitBasicBlock(
+ MachineBasicBlock *CurMBB, MachineBasicBlock *NewRetMBB,
+ MachineBasicBlock::iterator SplitIt) {
+ NewRetMBB->splice(NewRetMBB->end(), CurMBB, SplitIt, CurMBB->end());
+}
+
+std::pair<MachineBasicBlock *, MachineInstr *>
+AArch64WinFixupBufferSecurityCheckPass::getSecurityCheckerBasicBlock(
+ MachineFunction &MF) {
+ for (auto &MBB : MF) {
+ for (auto &MI : MBB) {
+ if (MI.getOpcode() == AArch64::BL && MI.getNumExplicitOperands() == 1) {
+ auto MO = MI.getOperand(0);
+ if (MO.isGlobal()) {
+ auto Callee = dyn_cast<Function>(MO.getGlobal());
+ if (Callee && Callee->getName() == "__security_check_cookie") {
+ return std::make_pair(&MBB, &MI);
+ break;
+ }
+ }
+ }
+ }
+ }
+ return std::make_pair(nullptr, nullptr);
+}
+
+MachineInstr *AArch64WinFixupBufferSecurityCheckPass::cloneLoadStackGuard(
+ MachineBasicBlock *CurMBB, MachineInstr *CheckCall) {
+ // Ensure that we have a valid MachineBasicBlock and CheckCall
+ if (!CurMBB || !CheckCall)
+ return nullptr;
+
+ MachineFunction &MF = *CurMBB->getParent();
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+
+ // Initialize reverse iterator starting just before CheckCall
+ MachineBasicBlock::reverse_iterator DIt(CheckCall);
+ MachineBasicBlock::reverse_iterator DEnd = CurMBB->rend();
+
+ // Reverse iterate from CheckCall to find LOAD_STACK_GUARD
+ for (; DIt != DEnd; ++DIt) {
+ MachineInstr &MI = *DIt;
+ if (MI.getOpcode() == TargetOpcode::LOAD_STACK_GUARD) {
+ // Clone the LOAD_STACK_GUARD instruction
+ MachineInstr *ClonedInstr = MF.CloneMachineInstr(&MI);
+
+ // Get the register class of the original destination register
+ Register OrigReg = MI.getOperand(0).getReg();
+ const TargetRegisterClass *RegClass = MRI.getRegClass(OrigReg);
+
+ // Create a new virtual register in the same register class
+ Register NewReg = MRI.createVirtualRegister(RegClass);
+
+ // Update operand 0 (destination) of the cloned instruction
+ MachineOperand &DestOperand = ClonedInstr->getOperand(0);
+ if (DestOperand.isReg() && DestOperand.isDef()) {
+ DestOperand.setReg(NewReg); // Set the new virtual register
+ }
+
+ // Return the modified cloned instruction
+ return ClonedInstr;
+ }
+ }
+
+ // If no LOAD_STACK_GUARD instruction was found, return nullptr
+ return nullptr;
+}
+
+void AArch64WinFixupBufferSecurityCheckPass::getGuardCheckSequence(
+ MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
+ MachineInstr *SeqMI[5]) {
+
+ MachineBasicBlock::iterator UIt(CheckCall);
+ MachineBasicBlock::reverse_iterator DIt(CheckCall);
+
+ // Move forward to find the stack adjustment after the call
+ // to __security_check_cookie
+ ++UIt;
+ SeqMI[4] = &*UIt;
+
+ // Assign the BL instruction (call to __security_check_cookie)
+ SeqMI[3] = CheckCall;
+
+ // COPY function slot cookie
+ ++DIt;
+ SeqMI[2] = &*DIt;
+
+ // Move backward to find the instruction that loads the security cookie from
+ // the stack
+ ++DIt;
+ SeqMI[1] = &*DIt;
+
+ ++DIt; // Find ADJCALLSTACKDOWN
+ SeqMI[0] = &*DIt;
+}
+
+void AArch64WinFixupBufferSecurityCheckPass::FinishBlock(
+ MachineBasicBlock *MBB) {
+ LivePhysRegs LiveRegs;
+ computeAndAddLiveIns(LiveRegs, *MBB);
+}
+
+void AArch64WinFixupBufferSecurityCheckPass::FinishFunction(
+ MachineBasicBlock *FailMBB, MachineBasicBlock *NewRetMBB) {
+ FailMBB->getParent()->RenumberBlocks();
+ // FailMBB includes call to MSCV RT where __security_check_cookie
+ // function is called. This function uses regcall and it expects cookie
+ // value from stack slot.( even if this is modified)
+ // Before going further we compute back livein for this block to make sure
+ // it is live and provided.
+ FinishBlock(FailMBB);
+ FinishBlock(NewRetMBB);
+}
+
+bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
+ MachineFunction &MF) {
+ bool Changed = false;
+ const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>();
+
+ if (!STI.getTargetTriple().isWindowsMSVCEnvironment())
+ return Changed;
+
+ // Check if security cookie was installed or not
+ Module &M = *MF.getFunction().getParent();
+ GlobalVariable *GV = M.getGlobalVariable("__security_cookie");
+ if (!GV)
+ return Changed;
+
+ const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+
+ // Check if security check cookie call was installed or not
+ auto [CurMBB, CheckCall] = getSecurityCheckerBasicBlock(MF);
+ if (!CheckCall)
+ return Changed;
+
+ // Get sequence of instruction in CurMBB responsible for calling
+ // __security_check_cookie
+ MachineInstr *SeqMI[5];
+ getGuardCheckSequence(CurMBB, CheckCall, SeqMI);
+
+ // Find LOAD_STACK_GUARD in CurrMBB and build a new LOAD_STACK_GUARD
+ // instruction with new destination register
+ MachineInstr *ClonedInstr = cloneLoadStackGuard(CurMBB, CheckCall);
+ if (!ClonedInstr)
+ return Changed;
+
+ // Insert the cloned LOAD_STACK_GUARD right before the ADJCALLSTACKDOWN
+ // that opens the __security_check_cookie call sequence (SeqMI[0]).
+ MachineBasicBlock::iterator InsertPt(SeqMI[0]);
+ CurMBB->insert(InsertPt, ClonedInstr);
+
+ auto CookieLoadReg = SeqMI[1]->getOperand(0).getReg();
+ auto GlobalCookieReg = ClonedInstr->getOperand(0).getReg();
+
+ // Move LDRXui that loads __security_cookie from stack, right after
+ // the cloned LOAD_STACK_GUARD
+ CurMBB->splice(InsertPt, CurMBB, std::next(InsertPt));
+
+ // Create a new virtual register for the CMP instruction result
+ Register DiscardReg =
+ MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+
+ // Emit the CMP instruction to compare stack cookie with global cookie
+ BuildMI(*CurMBB, InsertPt, DebugLoc(), TII->get(AArch64::SUBSXrr))
+ .addReg(DiscardReg, RegState::Define | RegState::Dead) // Result discarded
+ .addReg(CookieLoadReg) // First operand: stack cookie
+ .addReg(GlobalCookieReg); // Second operand: global cookie
+
+ // Create FailMBB basic block to call __security_check_cookie
+ MachineBasicBlock *FailMBB = MF.CreateMachineBasicBlock();
+ MF.insert(MF.end(), FailMBB);
+
+ // Create NewRetMBB basic block to skip call to __security_check_cookie
+ MachineBasicBlock *NewRetMBB = MF.CreateMachineBasicBlock();
+ MF.insert(MF.end(), NewRetMBB);
+
+ // Conditional branch to FailMBB if cookies do not match
+ BuildMI(*CurMBB, InsertPt, DebugLoc(), TII->get(AArch64::Bcc))
+ .addImm(AArch64CC::NE) // Condition: Not Equal
+ .addMBB(FailMBB); // Failure block
+
+ // Add an unconditional branch to NewRetMBB.
+ BuildMI(*CurMBB, InsertPt, DebugLoc(), TII->get(AArch64::B))
+ .addMBB(NewRetMBB);
+
+ // Move the call sequence (ADJCALLSTACKDOWN..ADJCALLSTACKUP) into FailMBB.
+ MachineBasicBlock::iterator U2It(SeqMI[4]);
+ ++U2It;
+ FailMBB->splice(FailMBB->end(), CurMBB, InsertPt, U2It);
+
+ // Insert a BRK instruction at the end of the FailMBB
+ BuildMI(*FailMBB, FailMBB->end(), DebugLoc(), TII->get(AArch64::BRK))
+ .addImm(0); // Immediate value for BRK
+
+ // Move remaining instructions after CheckCall to NewRetMBB.
+ NewRetMBB->splice(NewRetMBB->end(), CurMBB, U2It, CurMBB->end());
+
+ // Wire up CFG. NOTE(review): CurMBB's old successors are not transferred.
+ CurMBB->addSuccessor(NewRetMBB);
+ CurMBB->addSuccessor(FailMBB);
+
+ FinishFunction(FailMBB, NewRetMBB);
+
+ return !Changed;
+}
diff --git a/llvm/lib/Target/AArch64/CMakeLists.txt b/llvm/lib/Target/AArch64/CMakeLists.txt
index 2300e479bc1106..7ee9c3174f38c5 100644
--- a/llvm/lib/Target/AArch64/CMakeLists.txt
+++ b/llvm/lib/Target/AArch64/CMakeLists.txt
@@ -85,6 +85,7 @@ add_llvm_target(AArch64CodeGen
AArch64TargetMachine.cpp
AArch64TargetObjectFile.cpp
AArch64TargetTransformInfo.cpp
+ AArch64WinFixupBufferSecurityCheck.cpp
SMEABIPass.cpp
SMEPeepholeOpt.cpp
SVEIntrinsicOpts.cpp
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-stack-protector-windows.ll b/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-stack-protector-windows.ll
index 6aefc5341da072..1e2ef173638ec5 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-stack-protector-windows.ll
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-stack-protector-windows.ll
@@ -17,8 +17,12 @@ define void @caller() sspreq {
; CHECK-NEXT: ldr x8, [x8, :lo12:__security_cookie]
; CHECK-NEXT: str x8, [sp, #8]
; CHECK-NEXT: bl callee
+; CHECK-NEXT: adrp x8, __security_cookie
; CHECK-NEXT: ldr x0, [sp, #8]
-; CHECK-NEXT: bl __security_check_cookie
+; CHECK-NEXT: ldr x8, [x8, :lo12:__security_cookie]
+; CHECK-NEXT: subs x8, x0, x8
+; CHECK-NEXT: b.ne .LBB0_2
+; CHECK: // %bb.1:
; CHECK-NEXT: .seh_startepilogue
; CHECK-NEXT: ldr x30, [sp, #16] // 8-byte Folded Reload
; CHECK-NEXT: .seh_save_reg x30, 16
@@ -26,6 +30,9 @@ define void @caller() sspreq {
; CHECK-NEXT: .seh_stackalloc 32
; CHECK-NEXT: .seh_endepilogue
; CHECK-NEXT: ret
+; CHECK: .LBB0_2:
+; CHECK-NEXT: bl __security_check_cookie
+; CHECK-NEXT: brk #0
; CHECK-NEXT: .seh_endfunclet
; CHECK-NEXT: .seh_endproc
entry:
diff --git a/llvm/test/CodeGen/AArch64/O0-pipeline.ll b/llvm/test/CodeGen/AArch64/O0-pipeline.ll
index 0d079881cb909c..f2e2845d6a95b9 100644
--- a/llvm/test/CodeGen/AArch64/O0-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O0-pipeline.ll
@@ -54,6 +54,7 @@
; CHECK-NEXT: AArch64 Instruction Selection
; CHECK-NEXT: Finalize ISel and expand pseudo-instructions
; CHECK-NEXT: Local Stack Slot Allocation
+; CHECK-NEXT: AArch64 Windows Fixup Buffer Security Check
; CHECK-NEXT: Eliminate PHI nodes for register allocation
; CHECK-NEXT: Two-Address instruction pass
; CHECK-NEXT: Fast Register Allocator
diff --git a/llvm/test/CodeGen/AArch64/O3-pipeline.ll b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
index b5d5e27afa17ad..82559bfb0b172f 100644
--- a/llvm/test/CodeGen/AArch64/O3-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
@@ -159,14 +159,16 @@
; CHECK-NEXT: Remove dead machine instructions
; CHECK-NEXT: AArch64 MI Peephole Optimization pass
; CHECK-NEXT: AArch64 Dead register definitions
+; CHECK-NEXT: AArch64 Windows Fixup Buffer Security Check
; CHECK-NEXT: Detect Dead Lanes
; CHECK-NEXT: Init Undef Pass
; CHECK-NEXT: Process Implicit Definitions
; CHECK-NEXT: Remove unreachable machine basic blocks
; CHECK-NEXT: Live Variable Analysis
+; CHECK-NEXT: MachineDominator Tree Construction
+; CHECK-NEXT: Machine Natural Loop Construction
; CHECK-NEXT: Eliminate PHI nodes for register allocation
; CHECK-NEXT: Two-Address instruction pass
-; CHECK-NEXT: MachineDominator Tree Construction
; CHECK-NEXT: Slot index numbering
; CHECK-NEXT: Live Interval Analysis
; CHECK-NEXT: Register Coalescer
>From 70ad30a22dbafc747e6cfc1c37cfc604f0c062c9 Mon Sep 17 00:00:00 2001
From: Muhammad Omair Javaid <omair.javaid at linaro.org>
Date: Thu, 9 Jan 2025 01:17:24 +0500
Subject: [PATCH 2/3] Fix arsenm's review comments
---
.../AArch64WinFixupBufferSecurityCheck.cpp | 29 ++++++-------------
1 file changed, 9 insertions(+), 20 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp b/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
index f8c7c332aaf7b8..072d6bb69d0c11 100644
--- a/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
+++ b/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
@@ -1,5 +1,4 @@
-//===- AArch64WinFixupBufferSecurityCheck.cpp Fix Buffer Security Check Call
-//-===//
+//===- AArch64WinFixupBufferSecurityCheck.cpp Fixup Buffer Security Check -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -50,12 +49,9 @@ class AArch64WinFixupBufferSecurityCheckPass : public MachineFunctionPass {
void getGuardCheckSequence(MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
MachineInstr *SeqMI[5]);
- void SplitBasicBlock(MachineBasicBlock *CurMBB, MachineBasicBlock *NewRetMBB,
- MachineBasicBlock::iterator SplitIt);
+ void finishBlock(MachineBasicBlock *MBB);
- void FinishBlock(MachineBasicBlock *MBB);
-
- void FinishFunction(MachineBasicBlock *FailMBB, MachineBasicBlock *NewRetMBB);
+ void finishFunction(MachineBasicBlock *FailMBB, MachineBasicBlock *NewRetMBB);
};
} // end anonymous namespace
@@ -68,24 +64,17 @@ FunctionPass *llvm::createAArch64WinFixupBufferSecurityCheckPass() {
return new AArch64WinFixupBufferSecurityCheckPass();
}
-void AArch64WinFixupBufferSecurityCheckPass::SplitBasicBlock(
- MachineBasicBlock *CurMBB, MachineBasicBlock *NewRetMBB,
- MachineBasicBlock::iterator SplitIt) {
- NewRetMBB->splice(NewRetMBB->end(), CurMBB, SplitIt, CurMBB->end());
-}
-
std::pair<MachineBasicBlock *, MachineInstr *>
AArch64WinFixupBufferSecurityCheckPass::getSecurityCheckerBasicBlock(
MachineFunction &MF) {
for (auto &MBB : MF) {
for (auto &MI : MBB) {
- if (MI.getOpcode() == AArch64::BL && MI.getNumExplicitOperands() == 1) {
+ if (MI.isCall() && MI.getNumExplicitOperands() == 1) {
auto MO = MI.getOperand(0);
if (MO.isGlobal()) {
auto Callee = dyn_cast<Function>(MO.getGlobal());
if (Callee && Callee->getName() == "__security_check_cookie") {
return std::make_pair(&MBB, &MI);
- break;
}
}
}
@@ -164,13 +153,13 @@ void AArch64WinFixupBufferSecurityCheckPass::getGuardCheckSequence(
SeqMI[0] = &*DIt;
}
-void AArch64WinFixupBufferSecurityCheckPass::FinishBlock(
+void AArch64WinFixupBufferSecurityCheckPass::finishBlock(
MachineBasicBlock *MBB) {
LivePhysRegs LiveRegs;
computeAndAddLiveIns(LiveRegs, *MBB);
}
-void AArch64WinFixupBufferSecurityCheckPass::FinishFunction(
+void AArch64WinFixupBufferSecurityCheckPass::finishFunction(
MachineBasicBlock *FailMBB, MachineBasicBlock *NewRetMBB) {
FailMBB->getParent()->RenumberBlocks();
// FailMBB includes call to MSCV RT where __security_check_cookie
@@ -178,8 +167,8 @@ void AArch64WinFixupBufferSecurityCheckPass::FinishFunction(
// value from stack slot.( even if this is modified)
// Before going further we compute back livein for this block to make sure
// it is live and provided.
- FinishBlock(FailMBB);
- FinishBlock(NewRetMBB);
+ finishBlock(FailMBB);
+ finishBlock(NewRetMBB);
}
bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
@@ -269,7 +258,7 @@ bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
CurMBB->addSuccessor(NewRetMBB);
CurMBB->addSuccessor(FailMBB);
- FinishFunction(FailMBB, NewRetMBB);
+ finishFunction(FailMBB, NewRetMBB);
return !Changed;
}
>From 1abf5e8ed5964ba557107ac8c5ae0a2363c63559 Mon Sep 17 00:00:00 2001
From: Muhammad Omair Javaid <omair.javaid at linaro.org>
Date: Tue, 14 Jan 2025 17:31:23 +0500
Subject: [PATCH 3/3] Fix review comments by Eli
- Fix LOAD_STACK_GUARD detection logic
- Make transform run when security check sequence matched
- Preserve dominator tree
---
.../AArch64WinFixupBufferSecurityCheck.cpp | 156 +++++++++++-------
llvm/test/CodeGen/AArch64/O3-pipeline.ll | 3 +-
2 files changed, 96 insertions(+), 63 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp b/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
index 072d6bb69d0c11..0a37024c24e02e 100644
--- a/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
+++ b/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
@@ -13,8 +13,10 @@
//===----------------------------------------------------------------------===//
#include "llvm/CodeGen/LivePhysRegs.h"
+#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/IR/Module.h"
@@ -40,14 +42,14 @@ class AArch64WinFixupBufferSecurityCheckPass : public MachineFunctionPass {
bool runOnMachineFunction(MachineFunction &MF) override;
- std::pair<MachineBasicBlock *, MachineInstr *>
- getSecurityCheckerBasicBlock(MachineFunction &MF);
+ void getAnalysisUsage(AnalysisUsage &AU) const override;
- MachineInstr *cloneLoadStackGuard(MachineBasicBlock *CurMBB,
- MachineInstr *CheckCall);
+ std::pair<MachineInstr *, MachineInstr *>
+ findSecurityCheckAndLoadStackGuard(MachineFunction &MF);
- void getGuardCheckSequence(MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
- MachineInstr *SeqMI[5]);
+ MachineInstr *cloneLoadStackGuard(MachineFunction &MF, MachineInstr *MI);
+
+ bool getGuardCheckSequence(MachineInstr *CheckCall, MachineInstr *SeqMI[5]);
void finishBlock(MachineBasicBlock *MBB);
@@ -64,93 +66,113 @@ FunctionPass *llvm::createAArch64WinFixupBufferSecurityCheckPass() {
return new AArch64WinFixupBufferSecurityCheckPass();
}
-std::pair<MachineBasicBlock *, MachineInstr *>
-AArch64WinFixupBufferSecurityCheckPass::getSecurityCheckerBasicBlock(
+void AArch64WinFixupBufferSecurityCheckPass::getAnalysisUsage(
+ AnalysisUsage &AU) const {
+ AU.addUsedIfAvailable<MachineDominatorTreeWrapperPass>();
+ AU.addPreserved<MachineDominatorTreeWrapperPass>();
+ AU.addPreserved<MachineLoopInfoWrapperPass>();
+ MachineFunctionPass::getAnalysisUsage(AU);
+}
+
+std::pair<MachineInstr *, MachineInstr *>
+AArch64WinFixupBufferSecurityCheckPass::findSecurityCheckAndLoadStackGuard(
MachineFunction &MF) {
+
+ MachineInstr *SecurityCheckCall = nullptr;
+ MachineInstr *LoadStackGuard = nullptr;
+
for (auto &MBB : MF) {
for (auto &MI : MBB) {
+ if (!LoadStackGuard && MI.getOpcode() == TargetOpcode::LOAD_STACK_GUARD) {
+ LoadStackGuard = &MI;
+ }
+
if (MI.isCall() && MI.getNumExplicitOperands() == 1) {
auto MO = MI.getOperand(0);
if (MO.isGlobal()) {
auto Callee = dyn_cast<Function>(MO.getGlobal());
if (Callee && Callee->getName() == "__security_check_cookie") {
- return std::make_pair(&MBB, &MI);
+ SecurityCheckCall = &MI;
}
}
}
+
+ // If both are found, return them
+ if (LoadStackGuard && SecurityCheckCall) {
+ return std::make_pair(LoadStackGuard, SecurityCheckCall);
+ }
}
}
+
return std::make_pair(nullptr, nullptr);
}
-MachineInstr *AArch64WinFixupBufferSecurityCheckPass::cloneLoadStackGuard(
- MachineBasicBlock *CurMBB, MachineInstr *CheckCall) {
- // Ensure that we have a valid MachineBasicBlock and CheckCall
- if (!CurMBB || !CheckCall)
- return nullptr;
+MachineInstr *
+AArch64WinFixupBufferSecurityCheckPass::cloneLoadStackGuard(MachineFunction &MF,
+ MachineInstr *MI) {
- MachineFunction &MF = *CurMBB->getParent();
+ MachineInstr *ClonedInstr = MF.CloneMachineInstr(MI);
+
+ // Get the register class of the original destination register
+ Register OrigReg = MI->getOperand(0).getReg();
MachineRegisterInfo &MRI = MF.getRegInfo();
+ const TargetRegisterClass *RegClass = MRI.getRegClass(OrigReg);
- // Initialize reverse iterator starting just before CheckCall
- MachineBasicBlock::reverse_iterator DIt(CheckCall);
- MachineBasicBlock::reverse_iterator DEnd = CurMBB->rend();
-
- // Reverse iterate from CheckCall to find LOAD_STACK_GUARD
- for (; DIt != DEnd; ++DIt) {
- MachineInstr &MI = *DIt;
- if (MI.getOpcode() == TargetOpcode::LOAD_STACK_GUARD) {
- // Clone the LOAD_STACK_GUARD instruction
- MachineInstr *ClonedInstr = MF.CloneMachineInstr(&MI);
-
- // Get the register class of the original destination register
- Register OrigReg = MI.getOperand(0).getReg();
- const TargetRegisterClass *RegClass = MRI.getRegClass(OrigReg);
-
- // Create a new virtual register in the same register class
- Register NewReg = MRI.createVirtualRegister(RegClass);
-
- // Update operand 0 (destination) of the cloned instruction
- MachineOperand &DestOperand = ClonedInstr->getOperand(0);
- if (DestOperand.isReg() && DestOperand.isDef()) {
- DestOperand.setReg(NewReg); // Set the new virtual register
- }
+ // Create a new virtual register in the same register class
+ Register NewReg = MRI.createVirtualRegister(RegClass);
- // Return the modified cloned instruction
- return ClonedInstr;
- }
+ // Update operand 0 (destination) of the cloned instruction
+ MachineOperand &DestOperand = ClonedInstr->getOperand(0);
+ if (DestOperand.isReg() && DestOperand.isDef()) {
+ DestOperand.setReg(NewReg); // Set the new virtual register
}
- // If no LOAD_STACK_GUARD instruction was found, return nullptr
- return nullptr;
+ return ClonedInstr;
}
-void AArch64WinFixupBufferSecurityCheckPass::getGuardCheckSequence(
- MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
- MachineInstr *SeqMI[5]) {
+bool AArch64WinFixupBufferSecurityCheckPass::getGuardCheckSequence(
+ MachineInstr *CheckCall, MachineInstr *SeqMI[5]) {
+
+ MachineBasicBlock *MBB = CheckCall->getParent();
MachineBasicBlock::iterator UIt(CheckCall);
MachineBasicBlock::reverse_iterator DIt(CheckCall);
// Move forward to find the stack adjustment after the call
- // to __security_check_cookie
++UIt;
+ if (UIt == MBB->end() || UIt->getOpcode() != AArch64::ADJCALLSTACKUP) {
+ return false;
+ }
SeqMI[4] = &*UIt;
// Assign the BL instruction (call to __security_check_cookie)
SeqMI[3] = CheckCall;
- // COPY function slot cookie
+ // Move backward to find the COPY instruction for the function slot cookie
+ // argument passing
++DIt;
+ if (DIt == MBB->rend() || DIt->getOpcode() != AArch64::COPY) {
+ return false;
+ }
SeqMI[2] = &*DIt;
// Move backward to find the instruction that loads the security cookie from
// the stack
++DIt;
+ if (DIt == MBB->rend() || DIt->getOpcode() != AArch64::LDRXui) {
+ return false;
+ }
SeqMI[1] = &*DIt;
- ++DIt; // Find ADJCALLSTACKDOWN
+ // Move backward to find the stack adjustment before the call
+ ++DIt;
+ if (DIt == MBB->rend() || DIt->getOpcode() != AArch64::ADJCALLSTACKDOWN) {
+ return false;
+ }
SeqMI[0] = &*DIt;
+
+ // If all instructions are matched and stored, the sequence is valid
+ return true;
}
void AArch64WinFixupBufferSecurityCheckPass::finishBlock(
@@ -185,21 +207,23 @@ bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
if (!GV)
return Changed;
- const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
-
- // Check if security check cookie call was installed or not
- auto [CurMBB, CheckCall] = getSecurityCheckerBasicBlock(MF);
- if (!CheckCall)
+ // Find LOAD_STACK_GUARD and __security_check_cookie instructions
+ auto [StackGuard, CheckCall] = findSecurityCheckAndLoadStackGuard(MF);
+ if (!CheckCall || !StackGuard)
return Changed;
- // Get sequence of instruction in CurMBB responsible for calling
+ // Get sequence of instructions in current basic block responsible for calling
// __security_check_cookie
MachineInstr *SeqMI[5];
- getGuardCheckSequence(CurMBB, CheckCall, SeqMI);
+ if (!getGuardCheckSequence(CheckCall, SeqMI))
+ return Changed;
+
+ const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+ MachineBasicBlock *CurMBB = CheckCall->getParent();
// Find LOAD_STACK_GUARD in CurrMBB and build a new LOAD_STACK_GUARD
// instruction with new destination register
- MachineInstr *ClonedInstr = cloneLoadStackGuard(CurMBB, CheckCall);
+ MachineInstr *ClonedInstr = cloneLoadStackGuard(MF, StackGuard);
if (!ClonedInstr)
return Changed;
@@ -216,13 +240,14 @@ bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
CurMBB->splice(InsertPt, CurMBB, std::next(InsertPt));
// Create a new virtual register for the CMP instruction result
- Register DiscardReg =
- MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ Register DiscardReg = MRI.createVirtualRegister(&AArch64::GPR64RegClass);
// Emit the CMP instruction to compare stack cookie with global cookie
BuildMI(*CurMBB, InsertPt, DebugLoc(), TII->get(AArch64::SUBSXrr))
- .addReg(DiscardReg, RegState::Define | RegState::Dead) // Result discarded
- .addReg(CookieLoadReg) // First operand: stack cookie
+ .addReg(DiscardReg,
+ RegState::Define | RegState::Dead) // Result discarded
+ .addReg(CookieLoadReg) // First operand: stack cookie
.addReg(GlobalCookieReg); // Second operand: global cookie
// Create FailMBB basic block to call __security_check_cookie
@@ -258,6 +283,15 @@ bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
CurMBB->addSuccessor(NewRetMBB);
CurMBB->addSuccessor(FailMBB);
+ MachineDominatorTreeWrapperPass *WrapperPass =
+ getAnalysisIfAvailable<MachineDominatorTreeWrapperPass>();
+ MachineDominatorTree *MDT =
+ WrapperPass ? &WrapperPass->getDomTree() : nullptr;
+ if (MDT) {
+ MDT->addNewBlock(FailMBB, CurMBB);
+ MDT->addNewBlock(NewRetMBB, CurMBB);
+ }
+
finishFunction(FailMBB, NewRetMBB);
return !Changed;
diff --git a/llvm/test/CodeGen/AArch64/O3-pipeline.ll b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
index 82559bfb0b172f..7ac7b7c73fe348 100644
--- a/llvm/test/CodeGen/AArch64/O3-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
@@ -165,10 +165,9 @@
; CHECK-NEXT: Process Implicit Definitions
; CHECK-NEXT: Remove unreachable machine basic blocks
; CHECK-NEXT: Live Variable Analysis
-; CHECK-NEXT: MachineDominator Tree Construction
-; CHECK-NEXT: Machine Natural Loop Construction
; CHECK-NEXT: Eliminate PHI nodes for register allocation
; CHECK-NEXT: Two-Address instruction pass
+; CHECK-NEXT: MachineDominator Tree Construction
; CHECK-NEXT: Slot index numbering
; CHECK-NEXT: Live Interval Analysis
; CHECK-NEXT: Register Coalescer
More information about the llvm-commits
mailing list