[llvm] b12449f - [CodeGen] Refactor and document ThunkInserter (#97468)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jul 4 07:03:51 PDT 2024
Author: Anatoly Trosinenko
Date: 2024-07-04T17:03:47+03:00
New Revision: b12449fb289708f3d31c107d6e7977044a01da62
URL: https://github.com/llvm/llvm-project/commit/b12449fb289708f3d31c107d6e7977044a01da62
DIFF: https://github.com/llvm/llvm-project/commit/b12449fb289708f3d31c107d6e7977044a01da62.diff
LOG: [CodeGen] Refactor and document ThunkInserter (#97468)
In preparation for supporting BLRA* instructions in SLS Hardening on
AArch64, refactor the ThunkInserter class.
The main intention of this commit is to document the way to merge the
BLR-rewriting logic of the AArch64SLSHardening pass into the
SLSBLRThunkInserter class. This makes it possible to only call
createThunkFunction for the thunks that are actually referenced.
Ultimately, it will prevent SLSBLRThunkInserter from unconditionally
generating about 1800 thunk functions corresponding to every possible
combination of operands passed to BLRAA or BLRAB instructions.
This particular commit does not affect the generated machine code and
consists of the following changes:
* document the existing behavior of the ThunkInserter class
* introduce the ThunkInserterPass template class to get rid of mostly
identical boilerplate code in the ARM, AArch64 and X86 implementations
* move the InsertedThunks parameter from the `mayUseThunk` method to the
`insertThunks` method
Added:
Modified:
llvm/include/llvm/CodeGen/IndirectThunks.h
llvm/lib/Target/AArch64/AArch64SLSHardening.cpp
llvm/lib/Target/ARM/ARMSLSHardening.cpp
llvm/lib/Target/X86/X86IndirectThunks.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/IndirectThunks.h b/llvm/include/llvm/CodeGen/IndirectThunks.h
index 9b064ab788bf7..6c16b326fedd0 100644
--- a/llvm/include/llvm/CodeGen/IndirectThunks.h
+++ b/llvm/include/llvm/CodeGen/IndirectThunks.h
@@ -1,4 +1,4 @@
-//===---- IndirectThunks.h - Indirect Thunk Base Class ----------*- C++ -*-===//
+//===---- IndirectThunks.h - Indirect thunk insertion helpers ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,7 +7,9 @@
//===----------------------------------------------------------------------===//
///
/// \file
-/// Contains a base class for Passes that inject an MI thunk.
+/// Contains a base ThunkInserter class that simplifies injection of MI thunks
+/// as well as a default implementation of MachineFunctionPass wrapping
+/// several `ThunkInserter`s for targets to extend.
///
//===----------------------------------------------------------------------===//
@@ -15,26 +17,95 @@
#define LLVM_CODEGEN_INDIRECTTHUNKS_H
#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
namespace llvm {
+/// This class assists in inserting MI thunk functions into the module and
+/// rewriting the existing machine functions to call these thunks.
+///
+/// One of the common cases is implementing security mitigations that involve
+/// replacing some machine code patterns with calls to special thunk functions.
+///
+/// Inserting a module pass late in the codegen pipeline may increase memory
+/// usage, as it serializes the transformations and forces preceding passes to
+/// produce machine code for all functions before running the module pass.
+/// For that reason, ThunkInserter can be driven by a MachineFunctionPass by
+/// passing one MachineFunction at a time to its `run(MMI, MF)` method.
+/// Then, the derived class should
+/// * call createThunkFunction from its insertThunks method exactly once for
+/// each of the thunk functions to be inserted
+/// * populate the thunk in its populateThunk method
+///
+/// Note that if some other pass is responsible for rewriting the functions,
+/// the insertThunks method may simply create all possible thunks at once,
+/// probably postponed until the first occurrence of possibly affected MF.
+///
+/// Alternatively, insertThunks method can rewrite MF by itself and only insert
+/// the thunks being called. In that case InsertedThunks variable can be used
+/// to track which thunks were already inserted.
+///
+/// In any case, the thunk function has to be inserted on behalf of some other
+/// function and then populated on its own "iteration" later - this is because
+/// MachineFunctionPass will see the newly created functions, but they first
+/// have to go through the preceding passes from the same pass manager,
+/// possibly even through the instruction selector.
+//
+// FIXME Maybe implement a documented and less surprising way of modifying
+// the module from a MachineFunctionPass that is restricted to inserting
+// completely new functions to the module.
template <typename Derived, typename InsertedThunksTy = bool>
class ThunkInserter {
Derived &getDerived() { return *static_cast<Derived *>(this); }
-protected:
// A variable used to track whether (and possible which) thunks have been
// inserted so far. InsertedThunksTy is usually a bool, but can be other types
// to represent more than one type of thunk. Requires an |= operator to
// accumulate results.
InsertedThunksTy InsertedThunks;
- void doInitialization(Module &M) {}
+
+protected:
+ // Interface for subclasses to use.
+
+ /// Create an empty thunk function.
+ ///
+ /// The new function will eventually be passed to populateThunk. If multiple
+ /// thunks are created, populateThunk can distinguish them by their names.
void createThunkFunction(MachineModuleInfo &MMI, StringRef Name,
bool Comdat = true, StringRef TargetAttrs = "");
+protected:
+ // Interface for subclasses to implement.
+ //
+ // Note: all functions are non-virtual and are called via getDerived().
+ // Note: only doInitialization() has an implementation.
+
+ /// Initializes thunk inserter.
+ void doInitialization(Module &M) {}
+
+ /// Returns common prefix for thunk function's names.
+ const char *getThunkPrefix(); // undefined
+
+ /// Checks if MF may use thunks (true - maybe, false - definitely not).
+ bool mayUseThunk(const MachineFunction &MF); // undefined
+
+ /// Rewrites the function if necessary, returns the set of thunks added.
+ InsertedThunksTy insertThunks(MachineModuleInfo &MMI, MachineFunction &MF,
+ InsertedThunksTy ExistingThunks); // undefined
+
+ /// Populate the thunk function with instructions.
+ ///
+ /// If multiple thunks are created, the content that must be inserted in the
+ /// thunk function body should be derived from the MF's name.
+ ///
+ /// Depending on the preceding passes in the pass manager, by the time
+ /// populateThunk is called, MF may have a few target-specific instructions
+ /// (such as a single MBB containing the return instruction).
+ void populateThunk(MachineFunction &MF); // undefined
+
public:
void init(Module &M) {
InsertedThunks = InsertedThunksTy{};
@@ -53,7 +124,7 @@ void ThunkInserter<Derived, InsertedThunksTy>::createThunkFunction(
Module &M = const_cast<Module &>(*MMI.getModule());
LLVMContext &Ctx = M.getContext();
- auto Type = FunctionType::get(Type::getVoidTy(Ctx), false);
+ auto *Type = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(Type,
Comdat ? GlobalValue::LinkOnceODRLinkage
: GlobalValue::InternalLinkage,
@@ -95,19 +166,15 @@ bool ThunkInserter<Derived, InsertedThunksTy>::run(MachineModuleInfo &MMI,
MachineFunction &MF) {
// If MF is not a thunk, check to see if we need to insert a thunk.
if (!MF.getName().starts_with(getDerived().getThunkPrefix())) {
- // Only add a thunk if one of the functions has the corresponding feature
- // enabled in its subtarget, and doesn't enable external thunks. The target
- // can use InsertedThunks to detect whether relevant thunks have already
- // been inserted.
- // FIXME: Conditionalize on indirect calls so we don't emit a thunk when
- // nothing will end up calling it.
- // FIXME: It's a little silly to look at every function just to enumerate
- // the subtargets, but eventually we'll want to look at them for indirect
- // calls, so maybe this is OK.
- if (!getDerived().mayUseThunk(MF, InsertedThunks))
+ // Only add thunks if one of the functions may use them.
+ if (!getDerived().mayUseThunk(MF))
return false;
- InsertedThunks |= getDerived().insertThunks(MMI, MF);
+ // The target can use InsertedThunks to detect whether relevant thunks
+ // have already been inserted.
+ // FIXME: Provide the way for insertThunks to notify us whether it changed
+ // the MF, instead of conservatively assuming it did.
+ InsertedThunks |= getDerived().insertThunks(MMI, MF, InsertedThunks);
return true;
}
@@ -116,6 +183,40 @@ bool ThunkInserter<Derived, InsertedThunksTy>::run(MachineModuleInfo &MMI,
return true;
}
+/// Basic implementation of MachineFunctionPass wrapping one or more
+/// `ThunkInserter`s passed as type parameters.
+template <typename... Inserters>
+class ThunkInserterPass : public MachineFunctionPass {
+protected:
+ std::tuple<Inserters...> TIs;
+
+ ThunkInserterPass(char &ID) : MachineFunctionPass(ID) {}
+
+public:
+ bool doInitialization(Module &M) override {
+ initTIs(M, TIs);
+ return false;
+ }
+
+ bool runOnMachineFunction(MachineFunction &MF) override {
+ auto &MMI = getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
+ return runTIs(MMI, MF, TIs);
+ }
+
+private:
+ template <typename... ThunkInserterT>
+ static void initTIs(Module &M,
+ std::tuple<ThunkInserterT...> &ThunkInserters) {
+ (..., std::get<ThunkInserterT>(ThunkInserters).init(M));
+ }
+
+ template <typename... ThunkInserterT>
+ static bool runTIs(MachineModuleInfo &MMI, MachineFunction &MF,
+ std::tuple<ThunkInserterT...> &ThunkInserters) {
+ return (0 | ... | std::get<ThunkInserterT>(ThunkInserters).run(MMI, MF));
+ }
+};
+
} // namespace llvm
#endif
diff --git a/llvm/lib/Target/AArch64/AArch64SLSHardening.cpp b/llvm/lib/Target/AArch64/AArch64SLSHardening.cpp
index 41bbc003fd9bf..7660de5c082d1 100644
--- a/llvm/lib/Target/AArch64/AArch64SLSHardening.cpp
+++ b/llvm/lib/Target/AArch64/AArch64SLSHardening.cpp
@@ -183,15 +183,12 @@ static const struct ThunkNameAndReg {
namespace {
struct SLSBLRThunkInserter : ThunkInserter<SLSBLRThunkInserter> {
const char *getThunkPrefix() { return SLSBLRNamePrefix; }
- bool mayUseThunk(const MachineFunction &MF, bool InsertedThunks) {
- if (InsertedThunks)
- return false;
+ bool mayUseThunk(const MachineFunction &MF) {
ComdatThunks &= !MF.getSubtarget<AArch64Subtarget>().hardenSlsNoComdat();
- // FIXME: This could also check if there are any BLRs in the function
- // to more accurately reflect if a thunk will be needed.
return MF.getSubtarget<AArch64Subtarget>().hardenSlsBlr();
}
- bool insertThunks(MachineModuleInfo &MMI, MachineFunction &MF);
+ bool insertThunks(MachineModuleInfo &MMI, MachineFunction &MF,
+ bool ExistingThunks);
void populateThunk(MachineFunction &MF);
private:
@@ -200,7 +197,10 @@ struct SLSBLRThunkInserter : ThunkInserter<SLSBLRThunkInserter> {
} // namespace
bool SLSBLRThunkInserter::insertThunks(MachineModuleInfo &MMI,
- MachineFunction &MF) {
+ MachineFunction &MF,
+ bool ExistingThunks) {
+ if (ExistingThunks)
+ return false;
// FIXME: It probably would be possible to filter which thunks to produce
// based on which registers are actually used in BLR instructions in this
// function. But would that be a worthwhile optimization?
@@ -210,6 +210,8 @@ bool SLSBLRThunkInserter::insertThunks(MachineModuleInfo &MMI,
}
void SLSBLRThunkInserter::populateThunk(MachineFunction &MF) {
+ assert(MF.getFunction().hasComdat() == ComdatThunks &&
+ "ComdatThunks value changed since MF creation");
// FIXME: How to better communicate Register number, rather than through
// name and lookup table?
assert(MF.getName().starts_with(getThunkPrefix()));
@@ -411,30 +413,13 @@ FunctionPass *llvm::createAArch64SLSHardeningPass() {
}
namespace {
-class AArch64IndirectThunks : public MachineFunctionPass {
+class AArch64IndirectThunks : public ThunkInserterPass<SLSBLRThunkInserter> {
public:
static char ID;
- AArch64IndirectThunks() : MachineFunctionPass(ID) {}
+ AArch64IndirectThunks() : ThunkInserterPass(ID) {}
StringRef getPassName() const override { return "AArch64 Indirect Thunks"; }
-
- bool doInitialization(Module &M) override;
- bool runOnMachineFunction(MachineFunction &MF) override;
-
-private:
- std::tuple<SLSBLRThunkInserter> TIs;
-
- template <typename... ThunkInserterT>
- static void initTIs(Module &M,
- std::tuple<ThunkInserterT...> &ThunkInserters) {
- (..., std::get<ThunkInserterT>(ThunkInserters).init(M));
- }
- template <typename... ThunkInserterT>
- static bool runTIs(MachineModuleInfo &MMI, MachineFunction &MF,
- std::tuple<ThunkInserterT...> &ThunkInserters) {
- return (0 | ... | std::get<ThunkInserterT>(ThunkInserters).run(MMI, MF));
- }
};
} // end anonymous namespace
@@ -444,14 +429,3 @@ char AArch64IndirectThunks::ID = 0;
FunctionPass *llvm::createAArch64IndirectThunks() {
return new AArch64IndirectThunks();
}
-
-bool AArch64IndirectThunks::doInitialization(Module &M) {
- initTIs(M, TIs);
- return false;
-}
-
-bool AArch64IndirectThunks::runOnMachineFunction(MachineFunction &MF) {
- LLVM_DEBUG(dbgs() << getPassName() << '\n');
- auto &MMI = getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
- return runTIs(MMI, MF, TIs);
-}
diff --git a/llvm/lib/Target/ARM/ARMSLSHardening.cpp b/llvm/lib/Target/ARM/ARMSLSHardening.cpp
index d9ff14ead60e2..d77db17090feb 100644
--- a/llvm/lib/Target/ARM/ARMSLSHardening.cpp
+++ b/llvm/lib/Target/ARM/ARMSLSHardening.cpp
@@ -163,7 +163,7 @@ static const struct ThunkNameRegMode {
// An enum for tracking whether Arm and Thumb thunks have been inserted into the
// current module so far.
-enum ArmInsertedThunks { ArmThunk = 1, ThumbThunk = 2 };
+enum ArmInsertedThunks { NoThunk = 0, ArmThunk = 1, ThumbThunk = 2 };
inline ArmInsertedThunks &operator|=(ArmInsertedThunks &X,
ArmInsertedThunks Y) {
@@ -174,19 +174,12 @@ namespace {
struct SLSBLRThunkInserter
: ThunkInserter<SLSBLRThunkInserter, ArmInsertedThunks> {
const char *getThunkPrefix() { return SLSBLRNamePrefix; }
- bool mayUseThunk(const MachineFunction &MF,
- ArmInsertedThunks InsertedThunks) {
- if ((InsertedThunks & ArmThunk &&
- !MF.getSubtarget<ARMSubtarget>().isThumb()) ||
- (InsertedThunks & ThumbThunk &&
- MF.getSubtarget<ARMSubtarget>().isThumb()))
- return false;
+ bool mayUseThunk(const MachineFunction &MF) {
ComdatThunks &= !MF.getSubtarget<ARMSubtarget>().hardenSlsNoComdat();
- // FIXME: This could also check if there are any indirect calls in the
- // function to more accurately reflect if a thunk will be needed.
return MF.getSubtarget<ARMSubtarget>().hardenSlsBlr();
}
- ArmInsertedThunks insertThunks(MachineModuleInfo &MMI, MachineFunction &MF);
+ ArmInsertedThunks insertThunks(MachineModuleInfo &MMI, MachineFunction &MF,
+ ArmInsertedThunks InsertedThunks);
void populateThunk(MachineFunction &MF);
private:
@@ -194,8 +187,14 @@ struct SLSBLRThunkInserter
};
} // namespace
-ArmInsertedThunks SLSBLRThunkInserter::insertThunks(MachineModuleInfo &MMI,
- MachineFunction &MF) {
+ArmInsertedThunks
+SLSBLRThunkInserter::insertThunks(MachineModuleInfo &MMI, MachineFunction &MF,
+ ArmInsertedThunks InsertedThunks) {
+ if ((InsertedThunks & ArmThunk &&
+ !MF.getSubtarget<ARMSubtarget>().isThumb()) ||
+ (InsertedThunks & ThumbThunk &&
+ MF.getSubtarget<ARMSubtarget>().isThumb()))
+ return NoThunk;
// FIXME: It probably would be possible to filter which thunks to produce
// based on which registers are actually used in indirect calls in this
// function. But would that be a worthwhile optimization?
@@ -208,6 +207,8 @@ ArmInsertedThunks SLSBLRThunkInserter::insertThunks(MachineModuleInfo &MMI,
}
void SLSBLRThunkInserter::populateThunk(MachineFunction &MF) {
+ assert(MF.getFunction().hasComdat() == ComdatThunks &&
+ "ComdatThunks value changed since MF creation");
// FIXME: How to better communicate Register number, rather than through
// name and lookup table?
assert(MF.getName().starts_with(getThunkPrefix()));
@@ -384,38 +385,14 @@ FunctionPass *llvm::createARMSLSHardeningPass() {
}
namespace {
-class ARMIndirectThunks : public MachineFunctionPass {
+class ARMIndirectThunks : public ThunkInserterPass<SLSBLRThunkInserter> {
public:
static char ID;
- ARMIndirectThunks() : MachineFunctionPass(ID) {}
+ ARMIndirectThunks() : ThunkInserterPass(ID) {}
StringRef getPassName() const override { return "ARM Indirect Thunks"; }
-
- bool doInitialization(Module &M) override;
- bool runOnMachineFunction(MachineFunction &MF) override;
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- MachineFunctionPass::getAnalysisUsage(AU);
- AU.addRequired<MachineModuleInfoWrapperPass>();
- AU.addPreserved<MachineModuleInfoWrapperPass>();
- }
-
-private:
- std::tuple<SLSBLRThunkInserter> TIs;
-
- template <typename... ThunkInserterT>
- static void initTIs(Module &M,
- std::tuple<ThunkInserterT...> &ThunkInserters) {
- (..., std::get<ThunkInserterT>(ThunkInserters).init(M));
- }
- template <typename... ThunkInserterT>
- static bool runTIs(MachineModuleInfo &MMI, MachineFunction &MF,
- std::tuple<ThunkInserterT...> &ThunkInserters) {
- return (0 | ... | std::get<ThunkInserterT>(ThunkInserters).run(MMI, MF));
- }
};
-
} // end anonymous namespace
char ARMIndirectThunks::ID = 0;
@@ -423,14 +400,3 @@ char ARMIndirectThunks::ID = 0;
FunctionPass *llvm::createARMIndirectThunks() {
return new ARMIndirectThunks();
}
-
-bool ARMIndirectThunks::doInitialization(Module &M) {
- initTIs(M, TIs);
- return false;
-}
-
-bool ARMIndirectThunks::runOnMachineFunction(MachineFunction &MF) {
- LLVM_DEBUG(dbgs() << getPassName() << '\n');
- auto &MMI = getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
- return runTIs(MMI, MF, TIs);
-}
diff --git a/llvm/lib/Target/X86/X86IndirectThunks.cpp b/llvm/lib/Target/X86/X86IndirectThunks.cpp
index ecc52600f7593..4f4a8d8bd09d5 100644
--- a/llvm/lib/Target/X86/X86IndirectThunks.cpp
+++ b/llvm/lib/Target/X86/X86IndirectThunks.cpp
@@ -61,26 +61,26 @@ static const char R11LVIThunkName[] = "__llvm_lvi_thunk_r11";
namespace {
struct RetpolineThunkInserter : ThunkInserter<RetpolineThunkInserter> {
const char *getThunkPrefix() { return RetpolineNamePrefix; }
- bool mayUseThunk(const MachineFunction &MF, bool InsertedThunks) {
- if (InsertedThunks)
- return false;
+ bool mayUseThunk(const MachineFunction &MF) {
const auto &STI = MF.getSubtarget<X86Subtarget>();
return (STI.useRetpolineIndirectCalls() ||
STI.useRetpolineIndirectBranches()) &&
!STI.useRetpolineExternalThunk();
}
- bool insertThunks(MachineModuleInfo &MMI, MachineFunction &MF);
+ bool insertThunks(MachineModuleInfo &MMI, MachineFunction &MF,
+ bool ExistingThunks);
void populateThunk(MachineFunction &MF);
};
struct LVIThunkInserter : ThunkInserter<LVIThunkInserter> {
const char *getThunkPrefix() { return LVIThunkNamePrefix; }
- bool mayUseThunk(const MachineFunction &MF, bool InsertedThunks) {
- if (InsertedThunks)
- return false;
+ bool mayUseThunk(const MachineFunction &MF) {
return MF.getSubtarget<X86Subtarget>().useLVIControlFlowIntegrity();
}
- bool insertThunks(MachineModuleInfo &MMI, MachineFunction &MF) {
+ bool insertThunks(MachineModuleInfo &MMI, MachineFunction &MF,
+ bool ExistingThunks) {
+ if (ExistingThunks)
+ return false;
createThunkFunction(MMI, R11LVIThunkName);
return true;
}
@@ -104,36 +104,23 @@ struct LVIThunkInserter : ThunkInserter<LVIThunkInserter> {
}
};
-class X86IndirectThunks : public MachineFunctionPass {
+class X86IndirectThunks
+ : public ThunkInserterPass<RetpolineThunkInserter, LVIThunkInserter> {
public:
static char ID;
- X86IndirectThunks() : MachineFunctionPass(ID) {}
+ X86IndirectThunks() : ThunkInserterPass(ID) {}
StringRef getPassName() const override { return "X86 Indirect Thunks"; }
-
- bool doInitialization(Module &M) override;
- bool runOnMachineFunction(MachineFunction &MF) override;
-
-private:
- std::tuple<RetpolineThunkInserter, LVIThunkInserter> TIs;
-
- template <typename... ThunkInserterT>
- static void initTIs(Module &M,
- std::tuple<ThunkInserterT...> &ThunkInserters) {
- (..., std::get<ThunkInserterT>(ThunkInserters).init(M));
- }
- template <typename... ThunkInserterT>
- static bool runTIs(MachineModuleInfo &MMI, MachineFunction &MF,
- std::tuple<ThunkInserterT...> &ThunkInserters) {
- return (0 | ... | std::get<ThunkInserterT>(ThunkInserters).run(MMI, MF));
- }
};
} // end anonymous namespace
bool RetpolineThunkInserter::insertThunks(MachineModuleInfo &MMI,
- MachineFunction &MF) {
+ MachineFunction &MF,
+ bool ExistingThunks) {
+ if (ExistingThunks)
+ return false;
if (MMI.getTarget().getTargetTriple().getArch() == Triple::x86_64)
createThunkFunction(MMI, R11RetpolineName);
else
@@ -259,14 +246,3 @@ FunctionPass *llvm::createX86IndirectThunksPass() {
}
char X86IndirectThunks::ID = 0;
-
-bool X86IndirectThunks::doInitialization(Module &M) {
- initTIs(M, TIs);
- return false;
-}
-
-bool X86IndirectThunks::runOnMachineFunction(MachineFunction &MF) {
- LLVM_DEBUG(dbgs() << getPassName() << '\n');
- auto &MMI = getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
- return runTIs(MMI, MF, TIs);
-}
More information about the llvm-commits
mailing list