[llvm] [AArch64][CodeGen] Optimize security cookie check with New Fixup Pass (PR #121938)
Eli Friedman via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 9 11:54:44 PST 2025
================
@@ -0,0 +1,275 @@
+//===- AArch64WinFixupBufferSecurityCheck.cpp - Fix Buffer Security Check -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// The buffer security check implementation inserts a Windows-specific callback
+// into the code. Without this fixup, __security_check_cookie is called on
+// Windows every time the function returns. Because that function is defined in
+// the runtime library, each return pays for a call into the DLL, which in most
+// cases simply performs a comparison and returns. With this fixup, the call
+// into the DLL is made selectively: only when the comparison fails.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/LivePhysRegs.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/IR/Module.h"
+
+#include "AArch64.h"
+#include "AArch64InstrInfo.h"
+#include "AArch64Subtarget.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "aarch64-win-fixup-bscheck"
+
+namespace {
+
+/// Machine-function pass that reworks the Windows buffer security check so
+/// that the out-of-line __security_check_cookie helper is only reached when
+/// the cookie comparison fails.
+class AArch64WinFixupBufferSecurityCheckPass : public MachineFunctionPass {
+public:
+  // Pass identifier; handed to MachineFunctionPass by the constructor below.
+  static char ID;
+
+  AArch64WinFixupBufferSecurityCheckPass() : MachineFunctionPass(ID) {}
+
+  // --- MachineFunctionPass interface ---------------------------------------
+
+  StringRef getPassName() const override {
+    return "AArch64 Windows Fixup Buffer Security Check";
+  }
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+
+  // --- Analysis helpers ----------------------------------------------------
+
+  /// Returns the block holding the first BL to __security_check_cookie
+  /// together with that call, or {nullptr, nullptr} when there is none.
+  std::pair<MachineBasicBlock *, MachineInstr *>
+  getSecurityCheckerBasicBlock(MachineFunction &MF);
+
+  /// Returns an uninserted clone of the nearest LOAD_STACK_GUARD preceding
+  /// \p CheckCall in \p CurMBB, defining a fresh virtual register, or
+  /// nullptr if no LOAD_STACK_GUARD is found.
+  MachineInstr *cloneLoadStackGuard(MachineBasicBlock *CurMBB,
+                                    MachineInstr *CheckCall);
+
+  /// Fills \p SeqMI with the instruction sequence around \p CheckCall.
+  /// NOTE(review): confirm the meaning of each of the five slots against the
+  /// definition.
+  void getGuardCheckSequence(MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
+                             MachineInstr *SeqMI[5]);
+
+  // --- CFG rewriting helpers -----------------------------------------------
+
+  /// Moves the instructions from \p SplitIt through the end of \p CurMBB onto
+  /// the end of \p NewRetMBB.
+  void SplitBasicBlock(MachineBasicBlock *CurMBB, MachineBasicBlock *NewRetMBB,
+                       MachineBasicBlock::iterator SplitIt);
+
+  // NOTE(review): presumably finalize the rewritten block / function after
+  // the split; confirm against the definitions.
+  void FinishBlock(MachineBasicBlock *MBB);
+
+  void FinishFunction(MachineBasicBlock *FailMBB, MachineBasicBlock *NewRetMBB);
+};
+} // end anonymous namespace
+
+// Storage for the pass identifier; the constructor passes it to
+// MachineFunctionPass.
+char AArch64WinFixupBufferSecurityCheckPass::ID = 0;
+
+// Register the pass under DEBUG_TYPE; it neither only looks at the CFG nor is
+// it an analysis.
+INITIALIZE_PASS(AArch64WinFixupBufferSecurityCheckPass, DEBUG_TYPE, DEBUG_TYPE,
+                /*cfg=*/false, /*analysis=*/false)
+
+/// Creates a new instance of the Windows buffer-security-check fixup pass.
+/// Ownership of the returned pass goes to the caller.
+FunctionPass *llvm::createAArch64WinFixupBufferSecurityCheckPass() {
+  auto *FixupPass = new AArch64WinFixupBufferSecurityCheckPass();
+  return FixupPass;
+}
+
+/// Moves every instruction in [\p SplitIt, CurMBB->end()) onto the end of
+/// \p NewRetMBB. Only instructions are transferred; CFG successor edges are
+/// not touched here.
+void AArch64WinFixupBufferSecurityCheckPass::SplitBasicBlock(
+    MachineBasicBlock *CurMBB, MachineBasicBlock *NewRetMBB,
+    MachineBasicBlock::iterator SplitIt) {
+  MachineBasicBlock::iterator TailEnd = CurMBB->end();
+  MachineBasicBlock::iterator InsertPt = NewRetMBB->end();
+  NewRetMBB->splice(InsertPt, CurMBB, SplitIt, TailEnd);
+}
+
+/// Finds the first direct call (BL) to __security_check_cookie in \p MF.
+///
+/// \returns the basic block containing the call together with the call
+/// itself, or {nullptr, nullptr} if \p MF contains no such call.
+std::pair<MachineBasicBlock *, MachineInstr *>
+AArch64WinFixupBufferSecurityCheckPass::getSecurityCheckerBasicBlock(
+    MachineFunction &MF) {
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      // Only consider a plain BL whose single explicit operand is the callee.
+      if (MI.getOpcode() != AArch64::BL || MI.getNumExplicitOperands() != 1)
+        continue;
+
+      // Inspect the callee operand in place; there is no need to copy it.
+      const MachineOperand &MO = MI.getOperand(0);
+      if (!MO.isGlobal())
+        continue;
+
+      const auto *Callee = dyn_cast<Function>(MO.getGlobal());
+      if (Callee && Callee->getName() == "__security_check_cookie")
+        return std::make_pair(&MBB, &MI);
+    }
+  }
+  return std::make_pair(nullptr, nullptr);
+}
+
+/// Finds the nearest LOAD_STACK_GUARD at or before \p CheckCall in \p CurMBB
+/// and returns a clone of it whose destination is a fresh virtual register of
+/// the same register class. The clone is NOT inserted into any block; the
+/// caller is responsible for placing it.
+///
+/// \returns the clone, or nullptr if either argument is null, no
+/// LOAD_STACK_GUARD is found, or that instruction's operand 0 is not a
+/// virtual-register definition (the caller should then abort the transform).
+MachineInstr *AArch64WinFixupBufferSecurityCheckPass::cloneLoadStackGuard(
+    MachineBasicBlock *CurMBB, MachineInstr *CheckCall) {
+  // Ensure that we have a valid MachineBasicBlock and CheckCall.
+  if (!CurMBB || !CheckCall)
+    return nullptr;
+
+  MachineFunction &MF = *CurMBB->getParent();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+
+  // Walk backwards through CurMBB. An LLVM reverse_iterator constructed from a
+  // MachineInstr refers to that instruction itself (unlike
+  // std::reverse_iterator), so the scan starts AT CheckCall and moves towards
+  // the first instruction of the block.
+  for (MachineBasicBlock::reverse_iterator It(CheckCall), End = CurMBB->rend();
+       It != End; ++It) {
+    MachineInstr &MI = *It;
+    if (MI.getOpcode() != TargetOpcode::LOAD_STACK_GUARD)
+      continue;
+
+    // Validate before cloning. Cloning with the original destination would
+    // create a second def of the same vreg, and MRI.getRegClass() asserts on
+    // physical registers, so bail out on any unexpected shape.
+    const MachineOperand &OrigDef = MI.getOperand(0);
+    if (!OrigDef.isReg() || !OrigDef.isDef() || !OrigDef.getReg().isVirtual())
+      return nullptr;
+
+    // Clone the LOAD_STACK_GUARD and retarget its def to a new virtual
+    // register in the same register class as the original.
+    const TargetRegisterClass *RegClass = MRI.getRegClass(OrigDef.getReg());
+    MachineInstr *ClonedInstr = MF.CloneMachineInstr(&MI);
+    ClonedInstr->getOperand(0).setReg(MRI.createVirtualRegister(RegClass));
+    return ClonedInstr;
+  }
+
+  // No LOAD_STACK_GUARD at or before CheckCall in this block.
+  return nullptr;
+}
+
+void AArch64WinFixupBufferSecurityCheckPass::getGuardCheckSequence(
+ MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
+ MachineInstr *SeqMI[5]) {
+
+ MachineBasicBlock::iterator UIt(CheckCall);
+ MachineBasicBlock::reverse_iterator DIt(CheckCall);
+
+ // Move forward to find the stack adjustment after the call
+ // to __security_check_cookie
----------------
efriedma-quic wrote:
I think we usually don't reschedule the code because ADJCALLSTACKDOWN exists at this point in the pipeline, and acts as a scheduling boundary. Which might work out today, but isn't reliable; it could easily change in the future.
I'm also not sure off the top of my head if we can consistently load the stack guard in a single instruction.
It's probably fine if you abort the transform on an unexpected sequence.
https://github.com/llvm/llvm-project/pull/121938
More information about the llvm-commits
mailing list