[llvm] [AArch64][CodeGen] Optimize security cookie check with New Fixup Pass (PR #121938)

Eli Friedman via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 24 15:23:28 PST 2025


================
@@ -0,0 +1,298 @@
+//===- AArch64WinFixupBufferSecurityCheck.cpp Fixup Buffer Security Check -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// The Buffer Security Check implementation inserts a Windows-specific callback
+// into the code. On Windows, without this fixup, __security_check_cookie gets
+// called every time a function returns. Since this function is defined in the
+// runtime library, it incurs the cost of a call into a DLL which simply does a
+// comparison and returns most of the time. With the fixup, we selectively make
+// the call into the DLL only if the comparison fails.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/LivePhysRegs.h"
+#include "llvm/CodeGen/MachineDominators.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/CodeGen/MachineLoopInfo.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/IR/Module.h"
+
+#include "AArch64.h"
+#include "AArch64InstrInfo.h"
+#include "AArch64Subtarget.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "aarch64-win-fixup-bscheck"
+
+namespace {
+
+/// Machine function pass implementing the Windows buffer security check fixup
+/// described in the comment at the top of this file.
+class AArch64WinFixupBufferSecurityCheckPass : public MachineFunctionPass {
+public:
+  /// Pass ID; its address is handed to the MachineFunctionPass constructor.
+  static char ID;
+
+  AArch64WinFixupBufferSecurityCheckPass() : MachineFunctionPass(ID) {}
+
+  /// Human-readable name of this pass.
+  StringRef getPassName() const override {
+    return "AArch64 Windows Fixup Buffer Security Check";
+  }
+
+  /// Pass entry point, invoked with the machine function \p MF.
+  bool runOnMachineFunction(MachineFunction &MF) override;
+
+  /// Declares the analyses this pass uses and preserves.
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+  /// Scans \p MF for a LOAD_STACK_GUARD instruction and a direct call to
+  /// __security_check_cookie. Returns {LoadStackGuard, SecurityCheckCall}, or
+  /// a pair of null pointers if the two were not both found.
+  std::pair<MachineInstr *, MachineInstr *>
+  findSecurityCheckAndLoadStackGuard(MachineFunction &MF);
+
+  /// Clones \p MI and gives the clone a fresh virtual destination register of
+  /// the same register class as the original's operand 0.
+  MachineInstr *cloneLoadStackGuard(MachineFunction &MF, MachineInstr *MI);
+
+  /// Collects the instructions surrounding \p CheckCall into \p SeqMI
+  /// (SeqMI[3] is the call itself, SeqMI[4] the ADJCALLSTACKUP that follows
+  /// it) and returns false when an expected neighbour is missing.
+  /// NOTE(review): the remaining slots are filled by code not visible in this
+  /// excerpt -- confirm against the full definition.
+  bool getGuardCheckSequence(MachineInstr *CheckCall, MachineInstr *SeqMI[5]);
+
+  /// NOTE(review): definition not visible in this excerpt; presumably
+  /// finalizes \p MBB after it has been rewritten -- confirm.
+  void finishBlock(MachineBasicBlock *MBB);
+
+  /// NOTE(review): definition not visible in this excerpt; presumably
+  /// finalizes the newly created fail and return blocks -- confirm.
+  void finishFunction(MachineBasicBlock *FailMBB, MachineBasicBlock *NewRetMBB);
+};
+} // end anonymous namespace
+
+// Storage for the pass ID declared in the class.
+char AArch64WinFixupBufferSecurityCheckPass::ID = 0;
+
+// Pass registration. DEBUG_TYPE is supplied for both string arguments
+// (presumably the command-line argument and the pass name -- confirm against
+// the INITIALIZE_PASS macro definition).
+INITIALIZE_PASS(AArch64WinFixupBufferSecurityCheckPass, DEBUG_TYPE, DEBUG_TYPE,
+                false, false)
+
+/// Factory: returns a newly heap-allocated instance of the pass.
+FunctionPass *llvm::createAArch64WinFixupBufferSecurityCheckPass() {
+  return new AArch64WinFixupBufferSecurityCheckPass();
+}
+
+/// Registers analysis dependencies on \p AU: the machine dominator tree is
+/// used only if it is already available, and both the machine dominator tree
+/// and machine loop info are marked as preserved. The base-class
+/// implementation is then invoked to add its own requirements.
+void AArch64WinFixupBufferSecurityCheckPass::getAnalysisUsage(
+    AnalysisUsage &AU) const {
+  AU.addUsedIfAvailable<MachineDominatorTreeWrapperPass>();
+  AU.addPreserved<MachineDominatorTreeWrapperPass>();
+  AU.addPreserved<MachineLoopInfoWrapperPass>();
+  MachineFunctionPass::getAnalysisUsage(AU);
+}
+
+/// Walks every instruction of \p MF in block order, remembering the first
+/// LOAD_STACK_GUARD encountered and the most recent direct call to
+/// __security_check_cookie seen so far.
+///
+/// \returns {LoadStackGuard, SecurityCheckCall} as soon as both are known, or
+/// {nullptr, nullptr} when the function does not contain both.
+std::pair<MachineInstr *, MachineInstr *>
+AArch64WinFixupBufferSecurityCheckPass::findSecurityCheckAndLoadStackGuard(
+    MachineFunction &MF) {
+  // True for a call whose single explicit operand is a global that resolves to
+  // a Function named __security_check_cookie.
+  auto IsCookieCheckCall = [](const MachineInstr &Inst) {
+    if (!Inst.isCall() || Inst.getNumExplicitOperands() != 1)
+      return false;
+    const MachineOperand &Target = Inst.getOperand(0);
+    if (!Target.isGlobal())
+      return false;
+    const auto *Fn = dyn_cast<Function>(Target.getGlobal());
+    return Fn && Fn->getName() == "__security_check_cookie";
+  };
+
+  MachineInstr *GuardLoad = nullptr;
+  MachineInstr *CookieCall = nullptr;
+
+  for (MachineBasicBlock &Block : MF) {
+    for (MachineInstr &Inst : Block) {
+      // Only the first guard load is kept; later ones are ignored.
+      if (!GuardLoad && Inst.getOpcode() == TargetOpcode::LOAD_STACK_GUARD)
+        GuardLoad = &Inst;
+
+      // A later matching call replaces an earlier one until we return.
+      if (IsCookieCheckCall(Inst))
+        CookieCall = &Inst;
+
+      if (GuardLoad && CookieCall)
+        return {GuardLoad, CookieCall};
+    }
+  }
+
+  return {nullptr, nullptr};
+}
+
+/// Produces a copy of \p MI (via MachineFunction::CloneMachineInstr) whose
+/// operand 0 is retargeted to a newly created virtual register of the same
+/// register class as the original's operand 0.
+///
+/// \returns the cloned instruction; this function does not insert it into any
+/// basic block.
+MachineInstr *
+AArch64WinFixupBufferSecurityCheckPass::cloneLoadStackGuard(MachineFunction &MF,
+                                                            MachineInstr *MI) {
+  MachineInstr *Duplicate = MF.CloneMachineInstr(MI);
+
+  // Allocate a fresh vreg matching the class of the original destination.
+  MachineRegisterInfo &RegInfo = MF.getRegInfo();
+  Register SourceDef = MI->getOperand(0).getReg();
+  Register FreshDef =
+      RegInfo.createVirtualRegister(RegInfo.getRegClass(SourceDef));
+
+  // Rewrite the clone's destination only when operand 0 is a register def.
+  MachineOperand &Def = Duplicate->getOperand(0);
+  if (Def.isReg() && Def.isDef())
+    Def.setReg(FreshDef);
+
+  return Duplicate;
+}
+
+bool AArch64WinFixupBufferSecurityCheckPass::getGuardCheckSequence(
+    MachineInstr *CheckCall, MachineInstr *SeqMI[5]) {
+
+  MachineBasicBlock *MBB = CheckCall->getParent();
+
+  MachineBasicBlock::iterator UIt(CheckCall);
+  MachineBasicBlock::reverse_iterator DIt(CheckCall);
+
+  // Move forward to find the stack adjustment after the call
+  ++UIt;
+  if (UIt == MBB->end() || UIt->getOpcode() != AArch64::ADJCALLSTACKUP) {
+    return false;
+  }
+  SeqMI[4] = &*UIt;
+
+  // Assign the BL instruction (call to __security_check_cookie)
+  SeqMI[3] = CheckCall;
+
+  // Move backward to find the COPY instruction for the function slot cookie
+  // argument passing
+  ++DIt;
+  if (DIt == MBB->rend() || DIt->getOpcode() != AArch64::COPY) {
+    return false;
----------------
efriedma-quic wrote:

Also check that the COPY destination is x0, and the source is the ldr.

https://github.com/llvm/llvm-project/pull/121938


More information about the llvm-commits mailing list