[llvm-branch-commits] [ModuleUtils] Add updateGlobalCtors/updateGlobalDtors (PR #101757)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Aug 2 14:52:27 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Vitaly Buka (vitalybuka)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/101757.diff
3 Files Affected:
- (modified) llvm/include/llvm/Transforms/Utils/ModuleUtils.h (+9)
- (modified) llvm/lib/Transforms/Utils/ModuleUtils.cpp (+44)
- (modified) llvm/unittests/Transforms/Utils/ModuleUtilsTest.cpp (+43-3)
``````````diff
diff --git a/llvm/include/llvm/Transforms/Utils/ModuleUtils.h b/llvm/include/llvm/Transforms/Utils/ModuleUtils.h
index 1ec87505544f8..37d6a3e33315a 100644
--- a/llvm/include/llvm/Transforms/Utils/ModuleUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/ModuleUtils.h
@@ -30,6 +30,7 @@ class FunctionCallee;
class GlobalIFunc;
class GlobalValue;
class Constant;
+class ConstantStruct;
class Value;
class Type;
@@ -44,6 +45,14 @@ void appendToGlobalCtors(Module &M, Function *F, int Priority,
void appendToGlobalDtors(Module &M, Function *F, int Priority,
Constant *Data = nullptr);
+/// Apply \p Fn to every record of the llvm.global_ctors array
+/// (updateGlobalCtors) or the llvm.global_dtors array (updateGlobalDtors) of
+/// module \p M, and replace each record with the one returned by \p Fn. If
+/// \p Fn returns nullptr, the corresponding record is removed from the array.
+/// If the array does not exist, or \p Fn returns every record unchanged, the
+/// module is not modified; otherwise the array global is recreated, so any
+/// pointer to the old GlobalVariable must not be used afterwards. For the
+/// record layout see
+/// https://llvm.org/docs/LangRef.html#the-llvm-global-ctors-global-variable
+using GlobalCtorUpdateFn = llvm::function_ref<Constant *(Constant *)>;
+void updateGlobalCtors(Module &M, const GlobalCtorUpdateFn &Fn);
+void updateGlobalDtors(Module &M, const GlobalCtorUpdateFn &Fn);
+
/// Sets the KCFI type for the function. Used for compiler-generated functions
/// that are indirectly called in instrumented code.
void setKCFIType(Module &M, Function &F, StringRef MangledType);
diff --git a/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
index 122279160cc7e..e443d820a2256 100644
--- a/llvm/lib/Transforms/Utils/ModuleUtils.cpp
+++ b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
@@ -79,6 +79,50 @@ void llvm::appendToGlobalDtors(Module &M, Function *F, int Priority, Constant *D
appendToGlobalArray("llvm.global_dtors", M, F, Priority, Data);
}
+/// Rewrite the appending global array \p ArrayName (llvm.global_ctors or
+/// llvm.global_dtors) by passing each element through \p Fn: a non-null
+/// result replaces the element, nullptr drops it. If no element changes the
+/// module is left untouched. Otherwise the old global is erased and a new one
+/// with the same name, element type and appending linkage is created, which
+/// invalidates any GlobalVariable pointer to the old array.
+/// NOTE(review): eraseFromParent() requires the old array to have no uses;
+/// that is assumed here (as for appendToGlobalArray) rather than checked.
+static void updateGlobalArray(StringRef ArrayName, Module &M,
+                              const GlobalCtorUpdateFn &Fn) {
+  GlobalVariable *GVCtor = M.getNamedGlobal(ArrayName);
+  // Nothing to update for a missing array or a mere declaration.
+  // getInitializer() asserts on a declaration instead of returning null, so
+  // hasInitializer() has to be checked first.
+  if (!GVCtor || !GVCtor->hasInitializer())
+    return;
+
+  StructType *EltTy =
+      cast<StructType>(GVCtor->getValueType()->getArrayElementType());
+  Constant *Init = GVCtor->getInitializer();
+
+  SmallVector<Constant *, 16> NewCtors;
+  NewCtors.reserve(Init->getNumOperands());
+  bool Changed = false;
+  for (Value *Op : Init->operands()) {
+    Constant *C = cast<Constant>(Op);
+    Constant *NewC = Fn(C);
+    // C is never null, so this also records a removal (NewC == nullptr).
+    Changed |= NewC != C;
+    if (NewC)
+      NewCtors.push_back(NewC);
+  }
+  if (!Changed)
+    return;
+
+  // Erase the old array before creating its replacement so the new global
+  // gets exactly ArrayName instead of a renamed (uniqued) variant of it.
+  GVCtor->eraseFromParent();
+
+  ArrayType *AT = ArrayType::get(EltTy, NewCtors.size());
+  Constant *NewInit = ConstantArray::get(AT, NewCtors);
+  (void)new GlobalVariable(M, NewInit->getType(), /*isConstant=*/false,
+                           GlobalValue::AppendingLinkage, NewInit, ArrayName);
+}
+
+// Public entry points: each one only selects which special appending array
+// updateGlobalArray operates on.
+void llvm::updateGlobalCtors(Module &M, const GlobalCtorUpdateFn &Fn) {
+  const char *CtorArrayName = "llvm.global_ctors";
+  updateGlobalArray(CtorArrayName, M, Fn);
+}
+
+void llvm::updateGlobalDtors(Module &M, const GlobalCtorUpdateFn &Fn) {
+  const char *DtorArrayName = "llvm.global_dtors";
+  updateGlobalArray(DtorArrayName, M, Fn);
+}
+
static void collectUsedGlobals(GlobalVariable *GV,
SmallSetVector<Constant *, 16> &Init) {
if (!GV || !GV->hasInitializer())
diff --git a/llvm/unittests/Transforms/Utils/ModuleUtilsTest.cpp b/llvm/unittests/Transforms/Utils/ModuleUtilsTest.cpp
index 0ed7be9620a6f..582448a14ba8a 100644
--- a/llvm/unittests/Transforms/Utils/ModuleUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Utils/ModuleUtilsTest.cpp
@@ -70,17 +70,21 @@ TEST(ModuleUtils, AppendToUsedList2) {
}
using AppendFnType = decltype(&appendToGlobalCtors);
-using ParamType = std::tuple<StringRef, AppendFnType>;
+// updateGlobalCtors and updateGlobalDtors have the same signature, so one
+// function-pointer type covers both instantiations below.
+using UpdateFnType = decltype(&updateGlobalCtors);
+// Per-instantiation parameters: (array name, append function, update function).
+using ParamType = std::tuple<StringRef, AppendFnType, UpdateFnType>;
class ModuleUtilsTest : public testing::TestWithParam<ParamType> {
public:
StringRef arrayName() const { return std::get<0>(GetParam()); }
AppendFnType appendFn() const { return std::get<AppendFnType>(GetParam()); }
+ // std::get by type is unambiguous because AppendFnType and UpdateFnType are
+ // distinct function-pointer types.
+ UpdateFnType updateFn() const { return std::get<UpdateFnType>(GetParam()); }
};
+// Run every TEST_P once against llvm.global_ctors and once against
+// llvm.global_dtors.
INSTANTIATE_TEST_SUITE_P(
ModuleUtilsTestCtors, ModuleUtilsTest,
- ::testing::Values(ParamType{"llvm.global_ctors", &appendToGlobalCtors},
- ParamType{"llvm.global_dtors", &appendToGlobalDtors}));
+ ::testing::Values(ParamType{"llvm.global_ctors", &appendToGlobalCtors,
+ &updateGlobalCtors},
+ ParamType{"llvm.global_dtors", &appendToGlobalDtors,
+ &updateGlobalDtors}));
TEST_P(ModuleUtilsTest, AppendToMissingArray) {
LLVMContext C;
@@ -124,3 +128,39 @@ TEST_P(ModuleUtilsTest, AppendToArray) {
11, nullptr);
EXPECT_EQ(3, getListSize(*M, arrayName()));
}
+
+TEST_P(ModuleUtilsTest, UpdateArray) {
+  LLVMContext C;
+
+  std::unique_ptr<Module> M =
+      parseIR(C, (R"(@)" + arrayName() +
+                  R"( = appending global [2 x { i32, ptr, ptr }] [
+      { i32, ptr, ptr } { i32 65535, ptr null, ptr null },
+      { i32, ptr, ptr } { i32 0, ptr null, ptr null }]
+  )")
+                     .str());
+
+  EXPECT_EQ(2, getListSize(*M, arrayName()));
+
+  // An update that returns every record unchanged must leave the module
+  // alone: the array global is not recreated and keeps both records.
+  GlobalVariable *Original = M->getGlobalVariable(arrayName());
+  updateFn()(*M, [](Constant *C) { return C; });
+  EXPECT_EQ(Original, M->getGlobalVariable(arrayName()));
+  EXPECT_EQ(2, getListSize(*M, arrayName()));
+
+  // Set the priority of every ConstantStruct record to 12 and drop all other
+  // records. LLVM represents the all-null second record as a zeroinitializer
+  // rather than a ConstantStruct, which is why only one record survives.
+  updateFn()(*M, [](Constant *C) -> Constant * {
+    auto *CS = dyn_cast<ConstantStruct>(C);
+    if (!CS)
+      return nullptr;
+    Constant *CSVals[3] = {
+        ConstantInt::getSigned(CS->getOperand(0)->getType(), 12),
+        CS->getOperand(1),
+        CS->getOperand(2),
+    };
+    // Pass the whole array so ConstantStruct::get checks the element count
+    // against the struct type instead of reading past the end of CSVals.
+    return ConstantStruct::get(CS->getType(), CSVals);
+  });
+  EXPECT_EQ(1, getListSize(*M, arrayName()));
+  ConstantArray *CA = dyn_cast<ConstantArray>(
+      M->getGlobalVariable(arrayName())->getInitializer());
+  ASSERT_NE(nullptr, CA);
+  ConstantStruct *CS = dyn_cast<ConstantStruct>(CA->getOperand(0));
+  ASSERT_NE(nullptr, CS);
+  ConstantInt *Pri = dyn_cast<ConstantInt>(CS->getOperand(0));
+  ASSERT_NE(nullptr, Pri);
+  EXPECT_EQ(12u, Pri->getLimitedValue());
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/101757
More information about the llvm-branch-commits
mailing list