[llvm] [CodeGen] Port `CFGuard` to new pass manager (PR #75146)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 12 00:02:39 PST 2023


https://github.com/paperchalice created https://github.com/llvm/llvm-project/pull/75146

Port `CFGuard` to the new pass manager and add a pass parameter to choose the guard mechanism (`check` or `dispatch`).

>From 1bcaa82c55866afda6afe0350d028a81b274611d Mon Sep 17 00:00:00 2001
From: PaperChalice <liujunchang97 at outlook.com>
Date: Tue, 12 Dec 2023 15:27:03 +0800
Subject: [PATCH] [CodeGen] Port `CFGuard` to new pass manager

---
 .../include/llvm/CodeGen/CodeGenPassBuilder.h |  1 +
 .../llvm/CodeGen/MachinePassRegistry.def      |  3 +-
 llvm/include/llvm/Transforms/CFGuard.h        | 13 +++
 llvm/lib/Passes/CMakeLists.txt                |  1 +
 llvm/lib/Passes/PassBuilder.cpp               | 21 +++++
 llvm/lib/Passes/PassRegistry.def              |  4 +
 llvm/lib/Transforms/CFGuard/CFGuard.cpp       | 85 +++++++++++--------
 7 files changed, 89 insertions(+), 39 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
index 4f4f00b4e0e03b..8c6561ec8fdf99 100644
--- a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
+++ b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
@@ -47,6 +47,7 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Target/CGPassBuilderOption.h"
 #include "llvm/Target/TargetMachine.h"
+#include "llvm/Transforms/CFGuard.h"
 #include "llvm/Transforms/Scalar/ConstantHoisting.h"
 #include "llvm/Transforms/Scalar/LoopPassManager.h"
 #include "llvm/Transforms/Scalar/LoopStrengthReduce.h"
diff --git a/llvm/include/llvm/CodeGen/MachinePassRegistry.def b/llvm/include/llvm/CodeGen/MachinePassRegistry.def
index b98daa485cf8f3..d690b913e6b6ce 100644
--- a/llvm/include/llvm/CodeGen/MachinePassRegistry.def
+++ b/llvm/include/llvm/CodeGen/MachinePassRegistry.def
@@ -38,6 +38,7 @@ FUNCTION_ANALYSIS("targetir", TargetIRAnalysis,
 #define FUNCTION_PASS(NAME, PASS_NAME, CONSTRUCTOR)
 #endif
 FUNCTION_PASS("callbrprepare", CallBrPreparePass, ())
+FUNCTION_PASS("cfguard", CFGuardPass, (CFGuardPass::Mechanism::Check))
 FUNCTION_PASS("consthoist", ConstantHoistingPass, ())
 FUNCTION_PASS("dwarf-eh-prepare", DwarfEHPreparePass, (TM))
 FUNCTION_PASS("ee-instrument", EntryExitInstrumenterPass, (false))
@@ -123,8 +124,6 @@ MACHINE_FUNCTION_ANALYSIS("pass-instrumentation", PassInstrumentationAnalysis,
 #define DUMMY_FUNCTION_PASS(NAME, PASS_NAME, CONSTRUCTOR)
 #endif
 DUMMY_FUNCTION_PASS("atomic-expand", AtomicExpandPass, ())
-DUMMY_FUNCTION_PASS("cfguard-check", CFGuardCheckPass, ())
-DUMMY_FUNCTION_PASS("cfguard-dispatch", CFGuardDispatchPass, ())
 DUMMY_FUNCTION_PASS("codegenprepare", CodeGenPreparePass, ())
 DUMMY_FUNCTION_PASS("expandmemcmp", ExpandMemCmpPass, ())
 DUMMY_FUNCTION_PASS("gc-lowering", GCLoweringPass, ())
diff --git a/llvm/include/llvm/Transforms/CFGuard.h b/llvm/include/llvm/Transforms/CFGuard.h
index 86fcbc3c13e8b7..caf822a2ec9fb3 100644
--- a/llvm/include/llvm/Transforms/CFGuard.h
+++ b/llvm/include/llvm/Transforms/CFGuard.h
@@ -11,10 +11,23 @@
 #ifndef LLVM_TRANSFORMS_CFGUARD_H
 #define LLVM_TRANSFORMS_CFGUARD_H
 
+#include "llvm/IR/PassManager.h"
+
 namespace llvm {
 
 class FunctionPass;
 
+class CFGuardPass : public PassInfoMixin<CFGuardPass> {
+public:
+  enum class Mechanism { Check, Dispatch };
+  /// Creates the pass using guard mechanism \p M (Check by default).
+  explicit CFGuardPass(Mechanism M = Mechanism::Check) : GuardMechanism(M) {}
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM);
+
+private:
+  Mechanism GuardMechanism;
+};
+
 /// Insert Control FLow Guard checks on indirect function calls.
 FunctionPass *createCFGuardCheckPass();
 
diff --git a/llvm/lib/Passes/CMakeLists.txt b/llvm/lib/Passes/CMakeLists.txt
index e42edfe9496974..98d2de76c0e114 100644
--- a/llvm/lib/Passes/CMakeLists.txt
+++ b/llvm/lib/Passes/CMakeLists.txt
@@ -16,6 +16,7 @@ add_llvm_component_library(LLVMPasses
   LINK_COMPONENTS
   AggressiveInstCombine
   Analysis
+  CFGuard
   CodeGen
   Core
   Coroutines
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 41f1ce966ad2eb..9eb27892580b3f 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -99,6 +99,7 @@
 #include "llvm/Support/Regex.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
+#include "llvm/Transforms/CFGuard.h"
 #include "llvm/Transforms/Coroutines/CoroCleanup.h"
 #include "llvm/Transforms/Coroutines/CoroConditionalWrapper.h"
 #include "llvm/Transforms/Coroutines/CoroEarly.h"
@@ -737,6 +738,26 @@ Expected<bool> parsePostOrderFunctionAttrsPassOptions(StringRef Params) {
                                "PostOrderFunctionAttrs");
 }
 
+Expected<CFGuardPass::Mechanism> parseCFGuardPassOptions(StringRef Params) {
+  // No parameter selects the default guard-check mechanism.
+  if (Params.empty())
+    return CFGuardPass::Mechanism::Check;
+
+  auto [Option, Remainder] = Params.split(';');
+  if (!Remainder.empty())
+    return make_error<StringError>(
+        formatv("too many CFGuardPass parameters '{0}' ", Params).str(),
+        inconvertibleErrorCode());
+
+  if (Option == "dispatch")
+    return CFGuardPass::Mechanism::Dispatch;
+  if (Option == "check")
+    return CFGuardPass::Mechanism::Check;
+  return make_error<StringError>(
+      formatv("invalid CFGuardPass mechanism: '{0}' ", Option).str(),
+      inconvertibleErrorCode());
+}
+
 Expected<bool> parseEarlyCSEPassOptions(StringRef Params) {
   return parseSinglePassOption(Params, "memssa", "EarlyCSE");
 }
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 9afb73b2e8c168..cb0b8c5898d8d3 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -432,6 +432,10 @@ FUNCTION_PASS("wasm-eh-prepare", WasmEHPreparePass())
 #ifndef FUNCTION_PASS_WITH_PARAMS
 #define FUNCTION_PASS_WITH_PARAMS(NAME, CLASS, CREATE_PASS, PARSER, PARAMS)
 #endif
+FUNCTION_PASS_WITH_PARAMS(
+    "cfguard", "CFGuardPass",
+    [](CFGuardPass::Mechanism Kind) { return CFGuardPass(Kind); },
+    parseCFGuardPassOptions, "check;dispatch")
 FUNCTION_PASS_WITH_PARAMS(
     "early-cse", "EarlyCSEPass",
     [](bool UseMemorySSA) { return EarlyCSEPass(UseMemorySSA); },
diff --git a/llvm/lib/Transforms/CFGuard/CFGuard.cpp b/llvm/lib/Transforms/CFGuard/CFGuard.cpp
index 387734358775b3..4d4306576017be 100644
--- a/llvm/lib/Transforms/CFGuard/CFGuard.cpp
+++ b/llvm/lib/Transforms/CFGuard/CFGuard.cpp
@@ -34,25 +34,22 @@ namespace {
 
 /// Adds Control Flow Guard (CFG) checks on indirect function calls/invokes.
 /// These checks ensure that the target address corresponds to the start of an
-/// address-taken function. X86_64 targets use the CF_Dispatch mechanism. X86,
-/// ARM, and AArch64 targets use the CF_Check machanism.
-class CFGuard : public FunctionPass {
+/// address-taken function. X86_64 targets use the Mechanism::Dispatch
+/// mechanism. X86, ARM, and AArch64 targets use the Mechanism::Check mechanism.
+class CFGuardImpl {
 public:
-  static char ID;
-
-  enum Mechanism { CF_Check, CF_Dispatch };
-
-  // Default constructor required for the INITIALIZE_PASS macro.
-  CFGuard() : FunctionPass(ID) {
-    initializeCFGuardPass(*PassRegistry::getPassRegistry());
-    // By default, use the guard check mechanism.
-    GuardMechanism = CF_Check;
-  }
-
-  // Recommended constructor used to specify the type of guard mechanism.
-  CFGuard(Mechanism Var) : FunctionPass(ID) {
-    initializeCFGuardPass(*PassRegistry::getPassRegistry());
-    GuardMechanism = Var;
+  using Mechanism = CFGuardPass::Mechanism;
+
+  explicit CFGuardImpl(Mechanism M) : GuardMechanism(M) {
+    // Choose the guard global's name; doInitialization() gets or inserts it.
+    switch (GuardMechanism) {
+    case Mechanism::Check:
+      GuardFnName = "__guard_check_icall_fptr";
+      break;
+    case Mechanism::Dispatch:
+      GuardFnName = "__guard_dispatch_icall_fptr";
+      break;
+    }
   }
 
   /// Inserts a Control Flow Guard (CFG) check on an indirect call using the CFG
@@ -141,21 +138,37 @@ class CFGuard : public FunctionPass {
   /// \param CB indirect call to instrument.
   void insertCFGuardDispatch(CallBase *CB);
 
-  bool doInitialization(Module &M) override;
-  bool runOnFunction(Function &F) override;
+  bool doInitialization(Module &M);
+  bool runOnFunction(Function &F);
 
 private:
   // Only add checks if the module has the cfguard=2 flag.
   int cfguard_module_flag = 0;
-  Mechanism GuardMechanism = CF_Check;
+  StringRef GuardFnName;
+  Mechanism GuardMechanism = Mechanism::Check;
   FunctionType *GuardFnType = nullptr;
   PointerType *GuardFnPtrType = nullptr;
   Constant *GuardFnGlobal = nullptr;
 };
 
+class CFGuard : public FunctionPass {
+  CFGuardImpl Impl;
+
+public:
+  static char ID;
+
+  // Not default-constructible; use createCFGuard{Check,Dispatch}Pass().
+  explicit CFGuard(CFGuardImpl::Mechanism M) : FunctionPass(ID), Impl(M) {
+    initializeCFGuardPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool doInitialization(Module &M) override { return Impl.doInitialization(M); }
+  bool runOnFunction(Function &F) override { return Impl.runOnFunction(F); }
+};
+
 } // end anonymous namespace
 
-void CFGuard::insertCFGuardCheck(CallBase *CB) {
+void CFGuardImpl::insertCFGuardCheck(CallBase *CB) {
 
   assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
          "Only applicable for Windows targets");
@@ -184,7 +197,7 @@ void CFGuard::insertCFGuardCheck(CallBase *CB) {
   GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
 }
 
-void CFGuard::insertCFGuardDispatch(CallBase *CB) {
+void CFGuardImpl::insertCFGuardDispatch(CallBase *CB) {
 
   assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
          "Only applicable for Windows targets");
@@ -218,7 +231,7 @@ void CFGuard::insertCFGuardDispatch(CallBase *CB) {
   CB->eraseFromParent();
 }
 
-bool CFGuard::doInitialization(Module &M) {
+bool CFGuardImpl::doInitialization(Module &M) {
 
   // Check if this module has the cfguard flag and read its value.
   if (auto *MD =
@@ -235,15 +248,6 @@ bool CFGuard::doInitialization(Module &M) {
                         {PointerType::getUnqual(M.getContext())}, false);
   GuardFnPtrType = PointerType::get(GuardFnType, 0);
 
-  // Get or insert the guard check or dispatch global symbols.
-  llvm::StringRef GuardFnName;
-  if (GuardMechanism == CF_Check) {
-    GuardFnName = "__guard_check_icall_fptr";
-  } else if (GuardMechanism == CF_Dispatch) {
-    GuardFnName = "__guard_dispatch_icall_fptr";
-  } else {
-    assert(false && "Invalid CFGuard mechanism");
-  }
   GuardFnGlobal = M.getOrInsertGlobal(GuardFnName, GuardFnPtrType, [&] {
     auto *Var = new GlobalVariable(M, GuardFnPtrType, false,
                                    GlobalVariable::ExternalLinkage, nullptr,
@@ -255,7 +259,7 @@ bool CFGuard::doInitialization(Module &M) {
   return true;
 }
 
-bool CFGuard::runOnFunction(Function &F) {
+bool CFGuardImpl::runOnFunction(Function &F) {
 
   // Skip modules for which CFGuard checks have been disabled.
   if (cfguard_module_flag != 2)
@@ -283,7 +287,7 @@ bool CFGuard::runOnFunction(Function &F) {
   }
 
   // For each indirect call/invoke, add the appropriate dispatch or check.
-  if (GuardMechanism == CF_Dispatch) {
+  if (GuardMechanism == Mechanism::Dispatch) {
     for (CallBase *CB : IndirectCalls) {
       insertCFGuardDispatch(CB);
     }
@@ -296,13 +300,20 @@ bool CFGuard::runOnFunction(Function &F) {
   return true;
 }
 
+PreservedAnalyses CFGuardPass::run(Function &F, FunctionAnalysisManager &FAM) {
+  CFGuardImpl Guard(GuardMechanism);
+  bool Changed = Guard.doInitialization(*F.getParent());
+  Changed = Guard.runOnFunction(F) || Changed; // Call first so it always runs.
+  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
+}
+
 char CFGuard::ID = 0;
 INITIALIZE_PASS(CFGuard, "CFGuard", "CFGuard", false, false)
 
 FunctionPass *llvm::createCFGuardCheckPass() {
-  return new CFGuard(CFGuard::CF_Check);
+  return new CFGuard(CFGuardImpl::Mechanism::Check);
 }

 FunctionPass *llvm::createCFGuardDispatchPass() {
-  return new CFGuard(CFGuard::CF_Dispatch);
+  return new CFGuard(CFGuardImpl::Mechanism::Dispatch);
 }



More information about the llvm-commits mailing list