[llvm] [AArch64][CodeGen] Optimize security cookie check with New Fixup Pass (PR #121938)

Omair Javaid via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 7 06:19:29 PST 2025


https://github.com/omjavaid created https://github.com/llvm/llvm-project/pull/121938

This patch adds the AArch64WinFixupBufferSecurityCheckPass to optimize the handling of buffer security checks on AArch64 Windows targets.

The pass selectively replaces __security_check_cookie calls with inline comparisons, transferring control to the runtime library only on failure.

A similar optimization for the X86 Windows target was implemented in #95904.

>From 87592ba64c03f96611a52084c0194e33711f9020 Mon Sep 17 00:00:00 2001
From: Muhammad Omair Javaid <omair.javaid at linaro.org>
Date: Mon, 6 Jan 2025 16:31:18 +0500
Subject: [PATCH] Optimize security cookie check with New Fixup Pass

This patch adds the AArch64WinFixupBufferSecurityCheckPass to optimize
the handling of buffer security checks on AArch64 Windows targets.

The pass selectively replaces __security_check_cookie calls with inline
comparisons, transferring control to the runtime library only on failure.

A similar optimization for the X86 Windows target was implemented in #95904.
---
 llvm/lib/Target/AArch64/AArch64.h             |   2 +
 .../Target/AArch64/AArch64TargetMachine.cpp   |   1 +
 .../AArch64WinFixupBufferSecurityCheck.cpp    | 275 ++++++++++++++++++
 llvm/lib/Target/AArch64/CMakeLists.txt        |   1 +
 .../irtranslator-stack-protector-windows.ll   |   9 +-
 llvm/test/CodeGen/AArch64/O0-pipeline.ll      |   1 +
 llvm/test/CodeGen/AArch64/O3-pipeline.ll      |   4 +-
 7 files changed, 291 insertions(+), 2 deletions(-)
 create mode 100644 llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp

diff --git a/llvm/lib/Target/AArch64/AArch64.h b/llvm/lib/Target/AArch64/AArch64.h
index ffa578d412b3c2..33e40ae9f89190 100644
--- a/llvm/lib/Target/AArch64/AArch64.h
+++ b/llvm/lib/Target/AArch64/AArch64.h
@@ -72,6 +72,7 @@ FunctionPass *createAArch64PostLegalizerLowering();
 FunctionPass *createAArch64PostSelectOptimize();
 FunctionPass *createAArch64StackTaggingPass(bool IsOptNone);
 FunctionPass *createAArch64StackTaggingPreRAPass();
+FunctionPass *createAArch64WinFixupBufferSecurityCheckPass();
 ModulePass *createAArch64Arm64ECCallLoweringPass();
 
 void initializeAArch64A53Fix835769Pass(PassRegistry&);
@@ -105,6 +106,7 @@ void initializeAArch64SpeculationHardeningPass(PassRegistry &);
 void initializeAArch64StackTaggingPass(PassRegistry &);
 void initializeAArch64StackTaggingPreRAPass(PassRegistry &);
 void initializeAArch64StorePairSuppressPass(PassRegistry&);
+void initializeAArch64WinFixupBufferSecurityCheckPassPass(PassRegistry &);
 void initializeFalkorHWPFFixPass(PassRegistry&);
 void initializeFalkorMarkStridedAccessesLegacyPass(PassRegistry&);
 void initializeLDTLSCleanupPass(PassRegistry&);
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
index 07f072446081a3..30d2be9c2a832b 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -812,6 +812,7 @@ void AArch64PassConfig::addPreRegAlloc() {
   }
   if (TM->getOptLevel() != CodeGenOptLevel::None && EnableMachinePipeliner)
     addPass(&MachinePipelinerID);
+  addPass(createAArch64WinFixupBufferSecurityCheckPass());
 }
 
 void AArch64PassConfig::addPostRegAlloc() {
diff --git a/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp b/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
new file mode 100644
index 00000000000000..f8c7c332aaf7b8
--- /dev/null
+++ b/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
@@ -0,0 +1,275 @@
+//===- AArch64WinFixupBufferSecurityCheck.cpp Fix Buffer Security Check Call
+//-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// The buffer security check implementation inserts a Windows-specific callback
+// into the code. Without this fixup, __security_check_cookie is called every
+// time the function returns. Because that function is defined in the runtime
+// library, every return pays for a DLL call that usually just compares and
+// returns. With the fixup, the DLL is called only if the comparison fails.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/LivePhysRegs.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/IR/Module.h"
+
+#include "AArch64.h"
+#include "AArch64InstrInfo.h"
+#include "AArch64Subtarget.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "aarch64-win-fixup-bscheck"
+
+namespace {
+
+class AArch64WinFixupBufferSecurityCheckPass : public MachineFunctionPass {
+public:
+  static char ID;
+
+  AArch64WinFixupBufferSecurityCheckPass() : MachineFunctionPass(ID) {}
+
+  StringRef getPassName() const override {
+    return "AArch64 Windows Fixup Buffer Security Check";
+  }
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+
+  std::pair<MachineBasicBlock *, MachineInstr *>
+  getSecurityCheckerBasicBlock(MachineFunction &MF);
+
+  MachineInstr *cloneLoadStackGuard(MachineBasicBlock *CurMBB,
+                                    MachineInstr *CheckCall);
+
+  void getGuardCheckSequence(MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
+                             MachineInstr *SeqMI[5]);
+
+  void SplitBasicBlock(MachineBasicBlock *CurMBB, MachineBasicBlock *NewRetMBB,
+                       MachineBasicBlock::iterator SplitIt);
+
+  void FinishBlock(MachineBasicBlock *MBB);
+
+  void FinishFunction(MachineBasicBlock *FailMBB, MachineBasicBlock *NewRetMBB);
+};
+} // end anonymous namespace
+
+char AArch64WinFixupBufferSecurityCheckPass::ID = 0;
+
+INITIALIZE_PASS(AArch64WinFixupBufferSecurityCheckPass, DEBUG_TYPE, DEBUG_TYPE,
+                false, false)
+
+FunctionPass *llvm::createAArch64WinFixupBufferSecurityCheckPass() {
+  return new AArch64WinFixupBufferSecurityCheckPass();
+}
+
+void AArch64WinFixupBufferSecurityCheckPass::SplitBasicBlock(
+    MachineBasicBlock *CurMBB, MachineBasicBlock *NewRetMBB,
+    MachineBasicBlock::iterator SplitIt) {
+  NewRetMBB->splice(NewRetMBB->end(), CurMBB, SplitIt, CurMBB->end());
+}
+
+std::pair<MachineBasicBlock *, MachineInstr *>
+AArch64WinFixupBufferSecurityCheckPass::getSecurityCheckerBasicBlock(
+    MachineFunction &MF) {
+  for (auto &MBB : MF) {
+    for (auto &MI : MBB) {
+      if (MI.getOpcode() == AArch64::BL && MI.getNumExplicitOperands() == 1) {
+        auto MO = MI.getOperand(0);
+        if (MO.isGlobal()) {
+          auto Callee = dyn_cast<Function>(MO.getGlobal());
+          if (Callee && Callee->getName() == "__security_check_cookie") {
+            return std::make_pair(&MBB, &MI);
+            break;
+          }
+        }
+      }
+    }
+  }
+  return std::make_pair(nullptr, nullptr);
+}
+
+MachineInstr *AArch64WinFixupBufferSecurityCheckPass::cloneLoadStackGuard(
+    MachineBasicBlock *CurMBB, MachineInstr *CheckCall) {
+  // Ensure that we have a valid MachineBasicBlock and CheckCall
+  if (!CurMBB || !CheckCall)
+    return nullptr;
+
+  MachineFunction &MF = *CurMBB->getParent();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+
+  // Initialize reverse iterator starting just before CheckCall
+  MachineBasicBlock::reverse_iterator DIt(CheckCall);
+  MachineBasicBlock::reverse_iterator DEnd = CurMBB->rend();
+
+  // Reverse iterate from CheckCall to find LOAD_STACK_GUARD
+  for (; DIt != DEnd; ++DIt) {
+    MachineInstr &MI = *DIt;
+    if (MI.getOpcode() == TargetOpcode::LOAD_STACK_GUARD) {
+      // Clone the LOAD_STACK_GUARD instruction
+      MachineInstr *ClonedInstr = MF.CloneMachineInstr(&MI);
+
+      // Get the register class of the original destination register
+      Register OrigReg = MI.getOperand(0).getReg();
+      const TargetRegisterClass *RegClass = MRI.getRegClass(OrigReg);
+
+      // Create a new virtual register in the same register class
+      Register NewReg = MRI.createVirtualRegister(RegClass);
+
+      // Update operand 0 (destination) of the cloned instruction
+      MachineOperand &DestOperand = ClonedInstr->getOperand(0);
+      if (DestOperand.isReg() && DestOperand.isDef()) {
+        DestOperand.setReg(NewReg); // Set the new virtual register
+      }
+
+      // Return the modified cloned instruction
+      return ClonedInstr;
+    }
+  }
+
+  // If no LOAD_STACK_GUARD instruction was found, return nullptr
+  return nullptr;
+}
+
+void AArch64WinFixupBufferSecurityCheckPass::getGuardCheckSequence(
+    MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
+    MachineInstr *SeqMI[5]) {
+
+  MachineBasicBlock::iterator UIt(CheckCall);
+  MachineBasicBlock::reverse_iterator DIt(CheckCall);
+
+  // Move forward to find the stack adjustment after the call
+  // to __security_check_cookie
+  ++UIt;
+  SeqMI[4] = &*UIt;
+
+  // Assign the BL instruction (call to __security_check_cookie)
+  SeqMI[3] = CheckCall;
+
+  // COPY function slot cookie
+  ++DIt;
+  SeqMI[2] = &*DIt;
+
+  // Move backward to find the instruction that loads the security cookie from
+  // the stack
+  ++DIt;
+  SeqMI[1] = &*DIt;
+
+  ++DIt; // Find ADJCALLSTACKDOWN
+  SeqMI[0] = &*DIt;
+}
+
+void AArch64WinFixupBufferSecurityCheckPass::FinishBlock(
+    MachineBasicBlock *MBB) {
+  LivePhysRegs LiveRegs;
+  computeAndAddLiveIns(LiveRegs, *MBB);
+}
+
+void AArch64WinFixupBufferSecurityCheckPass::FinishFunction(
+    MachineBasicBlock *FailMBB, MachineBasicBlock *NewRetMBB) {
+  FailMBB->getParent()->RenumberBlocks();
+  // FailMBB contains the call into the MSVC runtime, where the
+  // __security_check_cookie function is called. This function uses regcall and
+  // expects the cookie value from the stack slot (even if modified).
+  // Before going further, we recompute the live-ins for this block to make sure
+  // the value is live and provided.
+  FinishBlock(FailMBB);
+  FinishBlock(NewRetMBB);
+}
+
+bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
+    MachineFunction &MF) {
+  bool Changed = false;
+  const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>();
+
+  if (!STI.getTargetTriple().isWindowsMSVCEnvironment())
+    return Changed;
+
+  // Check if security cookie was installed or not
+  Module &M = *MF.getFunction().getParent();
+  GlobalVariable *GV = M.getGlobalVariable("__security_cookie");
+  if (!GV)
+    return Changed;
+
+  const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+
+  // Check if security check cookie call was installed or not
+  auto [CurMBB, CheckCall] = getSecurityCheckerBasicBlock(MF);
+  if (!CheckCall)
+    return Changed;
+
+  // Get sequence of instruction in CurMBB responsible for calling
+  // __security_check_cookie
+  MachineInstr *SeqMI[5];
+  getGuardCheckSequence(CurMBB, CheckCall, SeqMI);
+
+  // Find LOAD_STACK_GUARD in CurMBB and build a new LOAD_STACK_GUARD
+  // instruction with a new destination register
+  MachineInstr *ClonedInstr = cloneLoadStackGuard(CurMBB, CheckCall);
+  if (!ClonedInstr)
+    return Changed;
+
+  // Insert cloned LOAD_STACK_GUARD right before the call to
+  // __security_check_cookie
+  MachineBasicBlock::iterator InsertPt(SeqMI[0]);
+  CurMBB->insert(InsertPt, ClonedInstr);
+
+  auto CookieLoadReg = SeqMI[1]->getOperand(0).getReg();
+  auto GlobalCookieReg = ClonedInstr->getOperand(0).getReg();
+
+  // Move LDRXui that loads __security_cookie from stack, right after
+  // the cloned LOAD_STACK_GUARD
+  CurMBB->splice(InsertPt, CurMBB, std::next(InsertPt));
+
+  // Create a new virtual register for the CMP instruction result
+  Register DiscardReg =
+      MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+
+  // Emit the CMP instruction to compare stack cookie with global cookie
+  BuildMI(*CurMBB, InsertPt, DebugLoc(), TII->get(AArch64::SUBSXrr))
+      .addReg(DiscardReg, RegState::Define | RegState::Dead) // Result discarded
+      .addReg(CookieLoadReg)    // First operand: stack cookie
+      .addReg(GlobalCookieReg); // Second operand: global cookie
+
+  // Create FailMBB basic block to call __security_check_cookie
+  MachineBasicBlock *FailMBB = MF.CreateMachineBasicBlock();
+  MF.insert(MF.end(), FailMBB);
+
+  // Create NewRetMBB basic block to skip call to __security_check_cookie
+  MachineBasicBlock *NewRetMBB = MF.CreateMachineBasicBlock();
+  MF.insert(MF.end(), NewRetMBB);
+
+  // Conditional branch to FailMBB if cookies do not match
+  BuildMI(*CurMBB, InsertPt, DebugLoc(), TII->get(AArch64::Bcc))
+      .addImm(AArch64CC::NE) // Condition: Not Equal
+      .addMBB(FailMBB);      // Failure block
+
+  // Add an unconditional branch to NewRetMBB.
+  BuildMI(*CurMBB, InsertPt, DebugLoc(), TII->get(AArch64::B))
+      .addMBB(NewRetMBB);
+
+  // Move fail check sequence from CurMBB to FailMBB
+  MachineBasicBlock::iterator U2It(SeqMI[4]);
+  ++U2It;
+  FailMBB->splice(FailMBB->end(), CurMBB, InsertPt, U2It);
+
+  // Insert a BRK instruction at the end of the FailMBB
+  BuildMI(*FailMBB, FailMBB->end(), DebugLoc(), TII->get(AArch64::BRK))
+      .addImm(0); // Immediate value for BRK
+
+  // Move remaining instructions after CheckCall to NewRetMBB.
+  NewRetMBB->splice(NewRetMBB->end(), CurMBB, U2It, CurMBB->end());
+
+  // Restructure Basic Blocks
+  CurMBB->addSuccessor(NewRetMBB);
+  CurMBB->addSuccessor(FailMBB);
+
+  FinishFunction(FailMBB, NewRetMBB);
+
+  return !Changed;
+}
diff --git a/llvm/lib/Target/AArch64/CMakeLists.txt b/llvm/lib/Target/AArch64/CMakeLists.txt
index 2300e479bc1106..7ee9c3174f38c5 100644
--- a/llvm/lib/Target/AArch64/CMakeLists.txt
+++ b/llvm/lib/Target/AArch64/CMakeLists.txt
@@ -85,6 +85,7 @@ add_llvm_target(AArch64CodeGen
   AArch64TargetMachine.cpp
   AArch64TargetObjectFile.cpp
   AArch64TargetTransformInfo.cpp
+  AArch64WinFixupBufferSecurityCheck.cpp
   SMEABIPass.cpp
   SMEPeepholeOpt.cpp
   SVEIntrinsicOpts.cpp
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-stack-protector-windows.ll b/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-stack-protector-windows.ll
index 6aefc5341da072..1e2ef173638ec5 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-stack-protector-windows.ll
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-stack-protector-windows.ll
@@ -17,8 +17,12 @@ define void @caller() sspreq {
 ; CHECK-NEXT:    ldr x8, [x8, :lo12:__security_cookie]
 ; CHECK-NEXT:    str x8, [sp, #8]
 ; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    adrp x8, __security_cookie
 ; CHECK-NEXT:    ldr x0, [sp, #8]
-; CHECK-NEXT:    bl __security_check_cookie
+; CHECK-NEXT:    ldr x8, [x8, :lo12:__security_cookie]
+; CHECK-NEXT:    subs x8, x0, x8
+; CHECK-NEXT:    b.ne .LBB0_2
+; CHECK:       // %bb.1:
 ; CHECK-NEXT:    .seh_startepilogue
 ; CHECK-NEXT:    ldr x30, [sp, #16] // 8-byte Folded Reload
 ; CHECK-NEXT:    .seh_save_reg x30, 16
@@ -26,6 +30,9 @@ define void @caller() sspreq {
 ; CHECK-NEXT:    .seh_stackalloc 32
 ; CHECK-NEXT:    .seh_endepilogue
 ; CHECK-NEXT:    ret
+; CHECK:       .LBB0_2:
+; CHECK-NEXT:    bl __security_check_cookie
+; CHECK-NEXT:    brk #0
 ; CHECK-NEXT:    .seh_endfunclet
 ; CHECK-NEXT:    .seh_endproc
 entry:
diff --git a/llvm/test/CodeGen/AArch64/O0-pipeline.ll b/llvm/test/CodeGen/AArch64/O0-pipeline.ll
index 0d079881cb909c..f2e2845d6a95b9 100644
--- a/llvm/test/CodeGen/AArch64/O0-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O0-pipeline.ll
@@ -54,6 +54,7 @@
 ; CHECK-NEXT:       AArch64 Instruction Selection
 ; CHECK-NEXT:       Finalize ISel and expand pseudo-instructions
 ; CHECK-NEXT:       Local Stack Slot Allocation
+; CHECK-NEXT:       AArch64 Windows Fixup Buffer Security Check
 ; CHECK-NEXT:       Eliminate PHI nodes for register allocation
 ; CHECK-NEXT:       Two-Address instruction pass
 ; CHECK-NEXT:       Fast Register Allocator
diff --git a/llvm/test/CodeGen/AArch64/O3-pipeline.ll b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
index b5d5e27afa17ad..82559bfb0b172f 100644
--- a/llvm/test/CodeGen/AArch64/O3-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
@@ -159,14 +159,16 @@
 ; CHECK-NEXT:       Remove dead machine instructions
 ; CHECK-NEXT:       AArch64 MI Peephole Optimization pass
 ; CHECK-NEXT:       AArch64 Dead register definitions
+; CHECK-NEXT:       AArch64 Windows Fixup Buffer Security Check
 ; CHECK-NEXT:       Detect Dead Lanes
 ; CHECK-NEXT:       Init Undef Pass
 ; CHECK-NEXT:       Process Implicit Definitions
 ; CHECK-NEXT:       Remove unreachable machine basic blocks
 ; CHECK-NEXT:       Live Variable Analysis
+; CHECK-NEXT:       MachineDominator Tree Construction
+; CHECK-NEXT:       Machine Natural Loop  Construction
 ; CHECK-NEXT:       Eliminate PHI nodes for register allocation
 ; CHECK-NEXT:       Two-Address instruction pass
-; CHECK-NEXT:       MachineDominator Tree Construction
 ; CHECK-NEXT:       Slot index numbering
 ; CHECK-NEXT:       Live Interval Analysis
 ; CHECK-NEXT:       Register Coalescer



More information about the llvm-commits mailing list