[llvm] [AArch64][CodeGen] Optimize security cookie check with New Fixup Pass (PR #121938)

Omair Javaid via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 16 00:30:03 PST 2025


https://github.com/omjavaid updated https://github.com/llvm/llvm-project/pull/121938

>From 87592ba64c03f96611a52084c0194e33711f9020 Mon Sep 17 00:00:00 2001
From: Muhammad Omair Javaid <omair.javaid at linaro.org>
Date: Mon, 6 Jan 2025 16:31:18 +0500
Subject: [PATCH 1/3] Optimize security cookie check with New Fixup Pass

This patch adds the AArch64WinFixupBufferSecurityCheckPass to optimize
the handling of buffer security checks on AArch64 Windows targets.

The pass replaces the unconditional __security_check_cookie call with an
inline comparison of the stack cookie against __security_cookie, transferring
control to the runtime library only when the comparison fails.

A similar pass for X86 Windows targets was added in #95904.
---
 llvm/lib/Target/AArch64/AArch64.h             |   2 +
 .../Target/AArch64/AArch64TargetMachine.cpp   |   1 +
 .../AArch64WinFixupBufferSecurityCheck.cpp    | 275 ++++++++++++++++++
 llvm/lib/Target/AArch64/CMakeLists.txt        |   1 +
 .../irtranslator-stack-protector-windows.ll   |   9 +-
 llvm/test/CodeGen/AArch64/O0-pipeline.ll      |   1 +
 llvm/test/CodeGen/AArch64/O3-pipeline.ll      |   4 +-
 7 files changed, 291 insertions(+), 2 deletions(-)
 create mode 100644 llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp

diff --git a/llvm/lib/Target/AArch64/AArch64.h b/llvm/lib/Target/AArch64/AArch64.h
index ffa578d412b3c2..33e40ae9f89190 100644
--- a/llvm/lib/Target/AArch64/AArch64.h
+++ b/llvm/lib/Target/AArch64/AArch64.h
@@ -72,6 +72,7 @@ FunctionPass *createAArch64PostLegalizerLowering();
 FunctionPass *createAArch64PostSelectOptimize();
 FunctionPass *createAArch64StackTaggingPass(bool IsOptNone);
 FunctionPass *createAArch64StackTaggingPreRAPass();
+FunctionPass *createAArch64WinFixupBufferSecurityCheckPass();
 ModulePass *createAArch64Arm64ECCallLoweringPass();
 
 void initializeAArch64A53Fix835769Pass(PassRegistry&);
@@ -105,6 +106,7 @@ void initializeAArch64SpeculationHardeningPass(PassRegistry &);
 void initializeAArch64StackTaggingPass(PassRegistry &);
 void initializeAArch64StackTaggingPreRAPass(PassRegistry &);
 void initializeAArch64StorePairSuppressPass(PassRegistry&);
+void initializeAArch64WinFixupBufferSecurityCheckPassPass(PassRegistry &);
 void initializeFalkorHWPFFixPass(PassRegistry&);
 void initializeFalkorMarkStridedAccessesLegacyPass(PassRegistry&);
 void initializeLDTLSCleanupPass(PassRegistry&);
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
index 07f072446081a3..30d2be9c2a832b 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -812,6 +812,7 @@ void AArch64PassConfig::addPreRegAlloc() {
   }
   if (TM->getOptLevel() != CodeGenOptLevel::None && EnableMachinePipeliner)
     addPass(&MachinePipelinerID);
+  addPass(createAArch64WinFixupBufferSecurityCheckPass());
 }
 
 void AArch64PassConfig::addPostRegAlloc() {
diff --git a/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp b/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
new file mode 100644
index 00000000000000..f8c7c332aaf7b8
--- /dev/null
+++ b/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
@@ -0,0 +1,275 @@
+//===- AArch64WinFixupBufferSecurityCheck.cpp Fix Buffer Security Check Call
+//-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// The buffer security check inserts a Windows-specific callback into the
+// code: without this fixup, __security_check_cookie is called every time the
+// function returns. Since that function is defined in the runtime library,
+// each call pays the cost of a DLL call that usually just compares and returns.
+// With this fixup, the DLL is called only when the inline comparison fails.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/LivePhysRegs.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/IR/Module.h"
+
+#include "AArch64.h"
+#include "AArch64InstrInfo.h"
+#include "AArch64Subtarget.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "aarch64-win-fixup-bscheck"
+
+namespace {
+
+class AArch64WinFixupBufferSecurityCheckPass : public MachineFunctionPass {
+public:
+  static char ID;
+
+  AArch64WinFixupBufferSecurityCheckPass() : MachineFunctionPass(ID) {}
+
+  StringRef getPassName() const override {
+    return "AArch64 Windows Fixup Buffer Security Check";
+  }
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+
+  std::pair<MachineBasicBlock *, MachineInstr *>
+  getSecurityCheckerBasicBlock(MachineFunction &MF);
+
+  MachineInstr *cloneLoadStackGuard(MachineBasicBlock *CurMBB,
+                                    MachineInstr *CheckCall);
+
+  void getGuardCheckSequence(MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
+                             MachineInstr *SeqMI[5]);
+
+  void SplitBasicBlock(MachineBasicBlock *CurMBB, MachineBasicBlock *NewRetMBB,
+                       MachineBasicBlock::iterator SplitIt);
+
+  void FinishBlock(MachineBasicBlock *MBB);
+
+  void FinishFunction(MachineBasicBlock *FailMBB, MachineBasicBlock *NewRetMBB);
+};
+} // end anonymous namespace
+
+char AArch64WinFixupBufferSecurityCheckPass::ID = 0;
+
+INITIALIZE_PASS(AArch64WinFixupBufferSecurityCheckPass, DEBUG_TYPE, DEBUG_TYPE,
+                false, false)
+
+FunctionPass *llvm::createAArch64WinFixupBufferSecurityCheckPass() {
+  return new AArch64WinFixupBufferSecurityCheckPass();
+}
+
+void AArch64WinFixupBufferSecurityCheckPass::SplitBasicBlock(
+    MachineBasicBlock *CurMBB, MachineBasicBlock *NewRetMBB,
+    MachineBasicBlock::iterator SplitIt) {
+  NewRetMBB->splice(NewRetMBB->end(), CurMBB, SplitIt, CurMBB->end());
+}
+
+std::pair<MachineBasicBlock *, MachineInstr *>
+AArch64WinFixupBufferSecurityCheckPass::getSecurityCheckerBasicBlock(
+    MachineFunction &MF) {
+  for (auto &MBB : MF) {
+    for (auto &MI : MBB) {
+      if (MI.getOpcode() == AArch64::BL && MI.getNumExplicitOperands() == 1) {
+        auto MO = MI.getOperand(0);
+        if (MO.isGlobal()) {
+          auto Callee = dyn_cast<Function>(MO.getGlobal());
+          if (Callee && Callee->getName() == "__security_check_cookie") {
+            return std::make_pair(&MBB, &MI);
+            break;
+          }
+        }
+      }
+    }
+  }
+  return std::make_pair(nullptr, nullptr);
+}
+
+MachineInstr *AArch64WinFixupBufferSecurityCheckPass::cloneLoadStackGuard(
+    MachineBasicBlock *CurMBB, MachineInstr *CheckCall) {
+  // Ensure that we have a valid MachineBasicBlock and CheckCall
+  if (!CurMBB || !CheckCall)
+    return nullptr;
+
+  MachineFunction &MF = *CurMBB->getParent();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+
+  // Initialize reverse iterator starting just before CheckCall
+  MachineBasicBlock::reverse_iterator DIt(CheckCall);
+  MachineBasicBlock::reverse_iterator DEnd = CurMBB->rend();
+
+  // Reverse iterate from CheckCall to find LOAD_STACK_GUARD
+  for (; DIt != DEnd; ++DIt) {
+    MachineInstr &MI = *DIt;
+    if (MI.getOpcode() == TargetOpcode::LOAD_STACK_GUARD) {
+      // Clone the LOAD_STACK_GUARD instruction
+      MachineInstr *ClonedInstr = MF.CloneMachineInstr(&MI);
+
+      // Get the register class of the original destination register
+      Register OrigReg = MI.getOperand(0).getReg();
+      const TargetRegisterClass *RegClass = MRI.getRegClass(OrigReg);
+
+      // Create a new virtual register in the same register class
+      Register NewReg = MRI.createVirtualRegister(RegClass);
+
+      // Update operand 0 (destination) of the cloned instruction
+      MachineOperand &DestOperand = ClonedInstr->getOperand(0);
+      if (DestOperand.isReg() && DestOperand.isDef()) {
+        DestOperand.setReg(NewReg); // Set the new virtual register
+      }
+
+      // Return the modified cloned instruction
+      return ClonedInstr;
+    }
+  }
+
+  // If no LOAD_STACK_GUARD instruction was found, return nullptr
+  return nullptr;
+}
+
+void AArch64WinFixupBufferSecurityCheckPass::getGuardCheckSequence(
+    MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
+    MachineInstr *SeqMI[5]) {
+
+  MachineBasicBlock::iterator UIt(CheckCall);
+  MachineBasicBlock::reverse_iterator DIt(CheckCall);
+
+  // Move forward to find the stack adjustment after the call
+  // to __security_check_cookie
+  ++UIt;
+  SeqMI[4] = &*UIt;
+
+  // Assign the BL instruction (call to __security_check_cookie)
+  SeqMI[3] = CheckCall;
+
+  // COPY function slot cookie
+  ++DIt;
+  SeqMI[2] = &*DIt;
+
+  // Move backward to find the instruction that loads the security cookie from
+  // the stack
+  ++DIt;
+  SeqMI[1] = &*DIt;
+
+  ++DIt; // Find ADJCALLSTACKDOWN
+  SeqMI[0] = &*DIt;
+}
+
+void AArch64WinFixupBufferSecurityCheckPass::FinishBlock(
+    MachineBasicBlock *MBB) {
+  LivePhysRegs LiveRegs;
+  computeAndAddLiveIns(LiveRegs, *MBB);
+}
+
+void AArch64WinFixupBufferSecurityCheckPass::FinishFunction(
+    MachineBasicBlock *FailMBB, MachineBasicBlock *NewRetMBB) {
+  FailMBB->getParent()->RenumberBlocks();
+  // FailMBB includes call to MSCV RT  where __security_check_cookie
+  // function is called. This function uses regcall and it expects cookie
+  // value from stack slot.( even if this is modified)
+  // Before going further we compute back livein for this block to make sure
+  // it is live and provided.
+  FinishBlock(FailMBB);
+  FinishBlock(NewRetMBB);
+}
+
+bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
+    MachineFunction &MF) {
+  bool Changed = false;
+  const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>();
+
+  if (!STI.getTargetTriple().isWindowsMSVCEnvironment())
+    return Changed;
+
+  // Check if security cookie was installed or not
+  Module &M = *MF.getFunction().getParent();
+  GlobalVariable *GV = M.getGlobalVariable("__security_cookie");
+  if (!GV)
+    return Changed;
+
+  const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+
+  // Check if security check cookie call was installed or not
+  auto [CurMBB, CheckCall] = getSecurityCheckerBasicBlock(MF);
+  if (!CheckCall)
+    return Changed;
+
+  // Get sequence of instruction in CurMBB responsible for calling
+  // __security_check_cookie
+  MachineInstr *SeqMI[5];
+  getGuardCheckSequence(CurMBB, CheckCall, SeqMI);
+
+  // Find LOAD_STACK_GUARD in CurrMBB and build a new LOAD_STACK_GUARD
+  // instruction with new destination register
+  MachineInstr *ClonedInstr = cloneLoadStackGuard(CurMBB, CheckCall);
+  if (!ClonedInstr)
+    return Changed;
+
+  // Insert cloned LOAD_STACK_GUARD right before the call to
+  // __security_check_cookie
+  MachineBasicBlock::iterator InsertPt(SeqMI[0]);
+  CurMBB->insert(InsertPt, ClonedInstr);
+
+  auto CookieLoadReg = SeqMI[1]->getOperand(0).getReg();
+  auto GlobalCookieReg = ClonedInstr->getOperand(0).getReg();
+
+  // Move LDRXui that loads __security_cookie from stack, right after
+  // the cloned LOAD_STACK_GUARD
+  CurMBB->splice(InsertPt, CurMBB, std::next(InsertPt));
+
+  // Create a new virtual register for the CMP instruction result
+  Register DiscardReg =
+      MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+
+  // Emit the CMP instruction to compare stack cookie with global cookie
+  BuildMI(*CurMBB, InsertPt, DebugLoc(), TII->get(AArch64::SUBSXrr))
+      .addReg(DiscardReg, RegState::Define | RegState::Dead) // Result discarded
+      .addReg(CookieLoadReg)    // First operand: stack cookie
+      .addReg(GlobalCookieReg); // Second operand: global cookie
+
+  // Create FailMBB basic block to call __security_check_cookie
+  MachineBasicBlock *FailMBB = MF.CreateMachineBasicBlock();
+  MF.insert(MF.end(), FailMBB);
+
+  // Create NewRetMBB basic block to skip call to __security_check_cookie
+  MachineBasicBlock *NewRetMBB = MF.CreateMachineBasicBlock();
+  MF.insert(MF.end(), NewRetMBB);
+
+  // Conditional branch to FailMBB if cookies do not match
+  BuildMI(*CurMBB, InsertPt, DebugLoc(), TII->get(AArch64::Bcc))
+      .addImm(AArch64CC::NE) // Condition: Not Equal
+      .addMBB(FailMBB);      // Failure block
+
+  // Add an unconditional branch to NewRetMBB.
+  BuildMI(*CurMBB, InsertPt, DebugLoc(), TII->get(AArch64::B))
+      .addMBB(NewRetMBB);
+
+  // Move the failure-path call sequence from CurMBB to FailMBB
+  MachineBasicBlock::iterator U2It(SeqMI[4]);
+  ++U2It;
+  FailMBB->splice(FailMBB->end(), CurMBB, InsertPt, U2It);
+
+  // Insert a BRK instruction at the end of the FailMBB
+  BuildMI(*FailMBB, FailMBB->end(), DebugLoc(), TII->get(AArch64::BRK))
+      .addImm(0); // Immediate value for BRK
+
+  // Move remaining instructions after CheckCall to NewRetMBB.
+  NewRetMBB->splice(NewRetMBB->end(), CurMBB, U2It, CurMBB->end());
+
+  // Restructure Basic Blocks
+  CurMBB->addSuccessor(NewRetMBB);
+  CurMBB->addSuccessor(FailMBB);
+
+  FinishFunction(FailMBB, NewRetMBB);
+
+  return !Changed;
+}
diff --git a/llvm/lib/Target/AArch64/CMakeLists.txt b/llvm/lib/Target/AArch64/CMakeLists.txt
index 2300e479bc1106..7ee9c3174f38c5 100644
--- a/llvm/lib/Target/AArch64/CMakeLists.txt
+++ b/llvm/lib/Target/AArch64/CMakeLists.txt
@@ -85,6 +85,7 @@ add_llvm_target(AArch64CodeGen
   AArch64TargetMachine.cpp
   AArch64TargetObjectFile.cpp
   AArch64TargetTransformInfo.cpp
+  AArch64WinFixupBufferSecurityCheck.cpp
   SMEABIPass.cpp
   SMEPeepholeOpt.cpp
   SVEIntrinsicOpts.cpp
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-stack-protector-windows.ll b/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-stack-protector-windows.ll
index 6aefc5341da072..1e2ef173638ec5 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-stack-protector-windows.ll
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-stack-protector-windows.ll
@@ -17,8 +17,12 @@ define void @caller() sspreq {
 ; CHECK-NEXT:    ldr x8, [x8, :lo12:__security_cookie]
 ; CHECK-NEXT:    str x8, [sp, #8]
 ; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    adrp x8, __security_cookie
 ; CHECK-NEXT:    ldr x0, [sp, #8]
-; CHECK-NEXT:    bl __security_check_cookie
+; CHECK-NEXT:    ldr x8, [x8, :lo12:__security_cookie]
+; CHECK-NEXT:    subs x8, x0, x8
+; CHECK-NEXT:    b.ne .LBB0_2
+; CHECK:       // %bb.1:
 ; CHECK-NEXT:    .seh_startepilogue
 ; CHECK-NEXT:    ldr x30, [sp, #16] // 8-byte Folded Reload
 ; CHECK-NEXT:    .seh_save_reg x30, 16
@@ -26,6 +30,9 @@ define void @caller() sspreq {
 ; CHECK-NEXT:    .seh_stackalloc 32
 ; CHECK-NEXT:    .seh_endepilogue
 ; CHECK-NEXT:    ret
+; CHECK:       .LBB0_2:
+; CHECK-NEXT:    bl __security_check_cookie
+; CHECK-NEXT:    brk #0
 ; CHECK-NEXT:    .seh_endfunclet
 ; CHECK-NEXT:    .seh_endproc
 entry:
diff --git a/llvm/test/CodeGen/AArch64/O0-pipeline.ll b/llvm/test/CodeGen/AArch64/O0-pipeline.ll
index 0d079881cb909c..f2e2845d6a95b9 100644
--- a/llvm/test/CodeGen/AArch64/O0-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O0-pipeline.ll
@@ -54,6 +54,7 @@
 ; CHECK-NEXT:       AArch64 Instruction Selection
 ; CHECK-NEXT:       Finalize ISel and expand pseudo-instructions
 ; CHECK-NEXT:       Local Stack Slot Allocation
+; CHECK-NEXT:       AArch64 Windows Fixup Buffer Security Check
 ; CHECK-NEXT:       Eliminate PHI nodes for register allocation
 ; CHECK-NEXT:       Two-Address instruction pass
 ; CHECK-NEXT:       Fast Register Allocator
diff --git a/llvm/test/CodeGen/AArch64/O3-pipeline.ll b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
index b5d5e27afa17ad..82559bfb0b172f 100644
--- a/llvm/test/CodeGen/AArch64/O3-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
@@ -159,14 +159,16 @@
 ; CHECK-NEXT:       Remove dead machine instructions
 ; CHECK-NEXT:       AArch64 MI Peephole Optimization pass
 ; CHECK-NEXT:       AArch64 Dead register definitions
+; CHECK-NEXT:       AArch64 Windows Fixup Buffer Security Check
 ; CHECK-NEXT:       Detect Dead Lanes
 ; CHECK-NEXT:       Init Undef Pass
 ; CHECK-NEXT:       Process Implicit Definitions
 ; CHECK-NEXT:       Remove unreachable machine basic blocks
 ; CHECK-NEXT:       Live Variable Analysis
+; CHECK-NEXT:       MachineDominator Tree Construction
+; CHECK-NEXT:       Machine Natural Loop  Construction
 ; CHECK-NEXT:       Eliminate PHI nodes for register allocation
 ; CHECK-NEXT:       Two-Address instruction pass
-; CHECK-NEXT:       MachineDominator Tree Construction
 ; CHECK-NEXT:       Slot index numbering
 ; CHECK-NEXT:       Live Interval Analysis
 ; CHECK-NEXT:       Register Coalescer

>From 70ad30a22dbafc747e6cfc1c37cfc604f0c062c9 Mon Sep 17 00:00:00 2001
From: Muhammad Omair Javaid <omair.javaid at linaro.org>
Date: Thu, 9 Jan 2025 01:17:24 +0500
Subject: [PATCH 2/3] Fix arsenm's review comments

---
 .../AArch64WinFixupBufferSecurityCheck.cpp    | 29 ++++++-------------
 1 file changed, 9 insertions(+), 20 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp b/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
index f8c7c332aaf7b8..072d6bb69d0c11 100644
--- a/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
+++ b/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
@@ -1,5 +1,4 @@
-//===- AArch64WinFixupBufferSecurityCheck.cpp Fix Buffer Security Check Call
-//-===//
+//===- AArch64WinFixupBufferSecurityCheck.cpp Fixup Buffer Security Check -===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -50,12 +49,9 @@ class AArch64WinFixupBufferSecurityCheckPass : public MachineFunctionPass {
   void getGuardCheckSequence(MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
                              MachineInstr *SeqMI[5]);
 
-  void SplitBasicBlock(MachineBasicBlock *CurMBB, MachineBasicBlock *NewRetMBB,
-                       MachineBasicBlock::iterator SplitIt);
+  void finishBlock(MachineBasicBlock *MBB);
 
-  void FinishBlock(MachineBasicBlock *MBB);
-
-  void FinishFunction(MachineBasicBlock *FailMBB, MachineBasicBlock *NewRetMBB);
+  void finishFunction(MachineBasicBlock *FailMBB, MachineBasicBlock *NewRetMBB);
 };
 } // end anonymous namespace
 
@@ -68,24 +64,17 @@ FunctionPass *llvm::createAArch64WinFixupBufferSecurityCheckPass() {
   return new AArch64WinFixupBufferSecurityCheckPass();
 }
 
-void AArch64WinFixupBufferSecurityCheckPass::SplitBasicBlock(
-    MachineBasicBlock *CurMBB, MachineBasicBlock *NewRetMBB,
-    MachineBasicBlock::iterator SplitIt) {
-  NewRetMBB->splice(NewRetMBB->end(), CurMBB, SplitIt, CurMBB->end());
-}
-
 std::pair<MachineBasicBlock *, MachineInstr *>
 AArch64WinFixupBufferSecurityCheckPass::getSecurityCheckerBasicBlock(
     MachineFunction &MF) {
   for (auto &MBB : MF) {
     for (auto &MI : MBB) {
-      if (MI.getOpcode() == AArch64::BL && MI.getNumExplicitOperands() == 1) {
+      if (MI.isCall() && MI.getNumExplicitOperands() == 1) {
         auto MO = MI.getOperand(0);
         if (MO.isGlobal()) {
           auto Callee = dyn_cast<Function>(MO.getGlobal());
           if (Callee && Callee->getName() == "__security_check_cookie") {
             return std::make_pair(&MBB, &MI);
-            break;
           }
         }
       }
@@ -164,13 +153,13 @@ void AArch64WinFixupBufferSecurityCheckPass::getGuardCheckSequence(
   SeqMI[0] = &*DIt;
 }
 
-void AArch64WinFixupBufferSecurityCheckPass::FinishBlock(
+void AArch64WinFixupBufferSecurityCheckPass::finishBlock(
     MachineBasicBlock *MBB) {
   LivePhysRegs LiveRegs;
   computeAndAddLiveIns(LiveRegs, *MBB);
 }
 
-void AArch64WinFixupBufferSecurityCheckPass::FinishFunction(
+void AArch64WinFixupBufferSecurityCheckPass::finishFunction(
     MachineBasicBlock *FailMBB, MachineBasicBlock *NewRetMBB) {
   FailMBB->getParent()->RenumberBlocks();
   // FailMBB includes call to MSCV RT  where __security_check_cookie
@@ -178,8 +167,8 @@ void AArch64WinFixupBufferSecurityCheckPass::FinishFunction(
   // value from stack slot.( even if this is modified)
   // Before going further we compute back livein for this block to make sure
   // it is live and provided.
-  FinishBlock(FailMBB);
-  FinishBlock(NewRetMBB);
+  finishBlock(FailMBB);
+  finishBlock(NewRetMBB);
 }
 
 bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
@@ -269,7 +258,7 @@ bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
   CurMBB->addSuccessor(NewRetMBB);
   CurMBB->addSuccessor(FailMBB);
 
-  FinishFunction(FailMBB, NewRetMBB);
+  finishFunction(FailMBB, NewRetMBB);
 
   return !Changed;
 }

>From 1abf5e8ed5964ba557107ac8c5ae0a2363c63559 Mon Sep 17 00:00:00 2001
From: Muhammad Omair Javaid <omair.javaid at linaro.org>
Date: Tue, 14 Jan 2025 17:31:23 +0500
Subject: [PATCH 3/3] Fix review comments by Eli

- Fix LOAD_STACK_GUARD detection logic
- Run the transform only when the full security check sequence is matched
- Update and preserve the machine dominator tree
---
 .../AArch64WinFixupBufferSecurityCheck.cpp    | 156 +++++++++++-------
 llvm/test/CodeGen/AArch64/O3-pipeline.ll      |   3 +-
 2 files changed, 96 insertions(+), 63 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp b/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
index 072d6bb69d0c11..0a37024c24e02e 100644
--- a/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
+++ b/llvm/lib/Target/AArch64/AArch64WinFixupBufferSecurityCheck.cpp
@@ -13,8 +13,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/CodeGen/LivePhysRegs.h"
+#include "llvm/CodeGen/MachineDominators.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/CodeGen/MachineLoopInfo.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/IR/Module.h"
 
@@ -40,14 +42,14 @@ class AArch64WinFixupBufferSecurityCheckPass : public MachineFunctionPass {
 
   bool runOnMachineFunction(MachineFunction &MF) override;
 
-  std::pair<MachineBasicBlock *, MachineInstr *>
-  getSecurityCheckerBasicBlock(MachineFunction &MF);
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
 
-  MachineInstr *cloneLoadStackGuard(MachineBasicBlock *CurMBB,
-                                    MachineInstr *CheckCall);
+  std::pair<MachineInstr *, MachineInstr *>
+  findSecurityCheckAndLoadStackGuard(MachineFunction &MF);
 
-  void getGuardCheckSequence(MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
-                             MachineInstr *SeqMI[5]);
+  MachineInstr *cloneLoadStackGuard(MachineFunction &MF, MachineInstr *MI);
+
+  bool getGuardCheckSequence(MachineInstr *CheckCall, MachineInstr *SeqMI[5]);
 
   void finishBlock(MachineBasicBlock *MBB);
 
@@ -64,93 +66,113 @@ FunctionPass *llvm::createAArch64WinFixupBufferSecurityCheckPass() {
   return new AArch64WinFixupBufferSecurityCheckPass();
 }
 
-std::pair<MachineBasicBlock *, MachineInstr *>
-AArch64WinFixupBufferSecurityCheckPass::getSecurityCheckerBasicBlock(
+void AArch64WinFixupBufferSecurityCheckPass::getAnalysisUsage(
+    AnalysisUsage &AU) const {
+  AU.addUsedIfAvailable<MachineDominatorTreeWrapperPass>();
+  AU.addPreserved<MachineDominatorTreeWrapperPass>();
+  AU.addPreserved<MachineLoopInfoWrapperPass>();
+  MachineFunctionPass::getAnalysisUsage(AU);
+}
+
+std::pair<MachineInstr *, MachineInstr *>
+AArch64WinFixupBufferSecurityCheckPass::findSecurityCheckAndLoadStackGuard(
     MachineFunction &MF) {
+
+  MachineInstr *SecurityCheckCall = nullptr;
+  MachineInstr *LoadStackGuard = nullptr;
+
   for (auto &MBB : MF) {
     for (auto &MI : MBB) {
+      if (!LoadStackGuard && MI.getOpcode() == TargetOpcode::LOAD_STACK_GUARD) {
+        LoadStackGuard = &MI;
+      }
+
       if (MI.isCall() && MI.getNumExplicitOperands() == 1) {
         auto MO = MI.getOperand(0);
         if (MO.isGlobal()) {
           auto Callee = dyn_cast<Function>(MO.getGlobal());
           if (Callee && Callee->getName() == "__security_check_cookie") {
-            return std::make_pair(&MBB, &MI);
+            SecurityCheckCall = &MI;
           }
         }
       }
+
+      // If both are found, return them
+      if (LoadStackGuard && SecurityCheckCall) {
+        return std::make_pair(LoadStackGuard, SecurityCheckCall);
+      }
     }
   }
+
   return std::make_pair(nullptr, nullptr);
 }
 
-MachineInstr *AArch64WinFixupBufferSecurityCheckPass::cloneLoadStackGuard(
-    MachineBasicBlock *CurMBB, MachineInstr *CheckCall) {
-  // Ensure that we have a valid MachineBasicBlock and CheckCall
-  if (!CurMBB || !CheckCall)
-    return nullptr;
+MachineInstr *
+AArch64WinFixupBufferSecurityCheckPass::cloneLoadStackGuard(MachineFunction &MF,
+                                                            MachineInstr *MI) {
 
-  MachineFunction &MF = *CurMBB->getParent();
+  MachineInstr *ClonedInstr = MF.CloneMachineInstr(MI);
+
+  // Get the register class of the original destination register
+  Register OrigReg = MI->getOperand(0).getReg();
   MachineRegisterInfo &MRI = MF.getRegInfo();
+  const TargetRegisterClass *RegClass = MRI.getRegClass(OrigReg);
 
-  // Initialize reverse iterator starting just before CheckCall
-  MachineBasicBlock::reverse_iterator DIt(CheckCall);
-  MachineBasicBlock::reverse_iterator DEnd = CurMBB->rend();
-
-  // Reverse iterate from CheckCall to find LOAD_STACK_GUARD
-  for (; DIt != DEnd; ++DIt) {
-    MachineInstr &MI = *DIt;
-    if (MI.getOpcode() == TargetOpcode::LOAD_STACK_GUARD) {
-      // Clone the LOAD_STACK_GUARD instruction
-      MachineInstr *ClonedInstr = MF.CloneMachineInstr(&MI);
-
-      // Get the register class of the original destination register
-      Register OrigReg = MI.getOperand(0).getReg();
-      const TargetRegisterClass *RegClass = MRI.getRegClass(OrigReg);
-
-      // Create a new virtual register in the same register class
-      Register NewReg = MRI.createVirtualRegister(RegClass);
-
-      // Update operand 0 (destination) of the cloned instruction
-      MachineOperand &DestOperand = ClonedInstr->getOperand(0);
-      if (DestOperand.isReg() && DestOperand.isDef()) {
-        DestOperand.setReg(NewReg); // Set the new virtual register
-      }
+  // Create a new virtual register in the same register class
+  Register NewReg = MRI.createVirtualRegister(RegClass);
 
-      // Return the modified cloned instruction
-      return ClonedInstr;
-    }
+  // Update operand 0 (destination) of the cloned instruction
+  MachineOperand &DestOperand = ClonedInstr->getOperand(0);
+  if (DestOperand.isReg() && DestOperand.isDef()) {
+    DestOperand.setReg(NewReg); // Set the new virtual register
   }
 
-  // If no LOAD_STACK_GUARD instruction was found, return nullptr
-  return nullptr;
+  return ClonedInstr;
 }
 
-void AArch64WinFixupBufferSecurityCheckPass::getGuardCheckSequence(
-    MachineBasicBlock *CurMBB, MachineInstr *CheckCall,
-    MachineInstr *SeqMI[5]) {
+bool AArch64WinFixupBufferSecurityCheckPass::getGuardCheckSequence(
+    MachineInstr *CheckCall, MachineInstr *SeqMI[5]) {
+
+  MachineBasicBlock *MBB = CheckCall->getParent();
 
   MachineBasicBlock::iterator UIt(CheckCall);
   MachineBasicBlock::reverse_iterator DIt(CheckCall);
 
   // Move forward to find the stack adjustment after the call
-  // to __security_check_cookie
   ++UIt;
+  if (UIt == MBB->end() || UIt->getOpcode() != AArch64::ADJCALLSTACKUP) {
+    return false;
+  }
   SeqMI[4] = &*UIt;
 
   // Assign the BL instruction (call to __security_check_cookie)
   SeqMI[3] = CheckCall;
 
-  // COPY function slot cookie
+  // Move backward to find the COPY instruction for the function slot cookie
+  // argument passing
   ++DIt;
+  if (DIt == MBB->rend() || DIt->getOpcode() != AArch64::COPY) {
+    return false;
+  }
   SeqMI[2] = &*DIt;
 
   // Move backward to find the instruction that loads the security cookie from
   // the stack
   ++DIt;
+  if (DIt == MBB->rend() || DIt->getOpcode() != AArch64::LDRXui) {
+    return false;
+  }
   SeqMI[1] = &*DIt;
 
-  ++DIt; // Find ADJCALLSTACKDOWN
+  // Move backward to find the stack adjustment before the call
+  ++DIt;
+  if (DIt == MBB->rend() || DIt->getOpcode() != AArch64::ADJCALLSTACKDOWN) {
+    return false;
+  }
   SeqMI[0] = &*DIt;
+
+  // If all instructions are matched and stored, the sequence is valid
+  return true;
 }
 
 void AArch64WinFixupBufferSecurityCheckPass::finishBlock(
@@ -185,21 +207,23 @@ bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
   if (!GV)
     return Changed;
 
-  const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
-
-  // Check if security check cookie call was installed or not
-  auto [CurMBB, CheckCall] = getSecurityCheckerBasicBlock(MF);
-  if (!CheckCall)
+  // Find LOAD_STACK_GUARD and __security_check_cookie instructions
+  auto [StackGuard, CheckCall] = findSecurityCheckAndLoadStackGuard(MF);
+  if (!CheckCall || !StackGuard)
     return Changed;
 
-  // Get sequence of instruction in CurMBB responsible for calling
+  // Get sequence of instructions in current basic block responsible for calling
   // __security_check_cookie
   MachineInstr *SeqMI[5];
-  getGuardCheckSequence(CurMBB, CheckCall, SeqMI);
+  if (!getGuardCheckSequence(CheckCall, SeqMI))
+    return Changed;
+
+  const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+  MachineBasicBlock *CurMBB = CheckCall->getParent();
 
   // Find LOAD_STACK_GUARD in CurrMBB and build a new LOAD_STACK_GUARD
   // instruction with new destination register
-  MachineInstr *ClonedInstr = cloneLoadStackGuard(CurMBB, CheckCall);
+  MachineInstr *ClonedInstr = cloneLoadStackGuard(MF, StackGuard);
   if (!ClonedInstr)
     return Changed;
 
@@ -216,13 +240,14 @@ bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
   CurMBB->splice(InsertPt, CurMBB, std::next(InsertPt));
 
   // Create a new virtual register for the CMP instruction result
-  Register DiscardReg =
-      MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  Register DiscardReg = MRI.createVirtualRegister(&AArch64::GPR64RegClass);
 
   // Emit the CMP instruction to compare stack cookie with global cookie
   BuildMI(*CurMBB, InsertPt, DebugLoc(), TII->get(AArch64::SUBSXrr))
-      .addReg(DiscardReg, RegState::Define | RegState::Dead) // Result discarded
-      .addReg(CookieLoadReg)    // First operand: stack cookie
+      .addReg(DiscardReg,
+              RegState::Define | RegState::Dead) // Result discarded
+      .addReg(CookieLoadReg)                     // First operand: stack cookie
       .addReg(GlobalCookieReg); // Second operand: global cookie
 
   // Create FailMBB basic block to call __security_check_cookie
@@ -258,6 +283,15 @@ bool AArch64WinFixupBufferSecurityCheckPass::runOnMachineFunction(
   CurMBB->addSuccessor(NewRetMBB);
   CurMBB->addSuccessor(FailMBB);
 
+  MachineDominatorTreeWrapperPass *WrapperPass =
+      getAnalysisIfAvailable<MachineDominatorTreeWrapperPass>();
+  MachineDominatorTree *MDT =
+      WrapperPass ? &WrapperPass->getDomTree() : nullptr;
+  if (MDT) {
+    MDT->addNewBlock(FailMBB, CurMBB);
+    MDT->addNewBlock(NewRetMBB, CurMBB);
+  }
+
   finishFunction(FailMBB, NewRetMBB);
 
   return !Changed;
diff --git a/llvm/test/CodeGen/AArch64/O3-pipeline.ll b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
index 82559bfb0b172f..7ac7b7c73fe348 100644
--- a/llvm/test/CodeGen/AArch64/O3-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
@@ -165,10 +165,9 @@
 ; CHECK-NEXT:       Process Implicit Definitions
 ; CHECK-NEXT:       Remove unreachable machine basic blocks
 ; CHECK-NEXT:       Live Variable Analysis
-; CHECK-NEXT:       MachineDominator Tree Construction
-; CHECK-NEXT:       Machine Natural Loop  Construction
 ; CHECK-NEXT:       Eliminate PHI nodes for register allocation
 ; CHECK-NEXT:       Two-Address instruction pass
+; CHECK-NEXT:       MachineDominator Tree Construction
 ; CHECK-NEXT:       Slot index numbering
 ; CHECK-NEXT:       Live Interval Analysis
 ; CHECK-NEXT:       Register Coalescer



More information about the llvm-commits mailing list