Skip to content

Commit 048cf88

Browse files
authored
[ModuleUtils] Add transformGlobal{C,D}tors (#101757)
For #101772
1 parent a0a9bf5 commit 048cf88

File tree

3 files changed

+98
-3
lines changed

3 files changed

+98
-3
lines changed

llvm/include/llvm/Transforms/Utils/ModuleUtils.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class FunctionCallee;
3030
class GlobalIFunc;
3131
class GlobalValue;
3232
class Constant;
33+
class ConstantStruct;
3334
class Value;
3435
class Type;
3536

@@ -44,6 +45,14 @@ void appendToGlobalCtors(Module &M, Function *F, int Priority,
4445
void appendToGlobalDtors(Module &M, Function *F, int Priority,
4546
Constant *Data = nullptr);
4647

48+
/// Apply 'Fn' to the list of global ctors of module M and replace contructor
49+
/// record with the one returned by `Fn`. If `nullptr` was returned, the
50+
/// corresponding constructor will be removed from the array. For details see
51+
/// https://llvm.org/docs/LangRef.html#the-llvm-global-ctors-global-variable
52+
using GlobalCtorTransformFn = llvm::function_ref<Constant *(Constant *)>;
53+
void transformGlobalCtors(Module &M, const GlobalCtorTransformFn &Fn);
54+
void transformGlobalDtors(Module &M, const GlobalCtorTransformFn &Fn);
55+
4756
/// Sets the KCFI type for the function. Used for compiler-generated functions
4857
/// that are indirectly called in instrumented code.
4958
void setKCFIType(Module &M, Function &F, StringRef MangledType);

llvm/lib/Transforms/Utils/ModuleUtils.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,50 @@ void llvm::appendToGlobalDtors(Module &M, Function *F, int Priority, Constant *D
7979
appendToGlobalArray("llvm.global_dtors", M, F, Priority, Data);
8080
}
8181

82+
static void transformGlobalArray(StringRef ArrayName, Module &M,
83+
const GlobalCtorTransformFn &Fn) {
84+
GlobalVariable *GVCtor = M.getNamedGlobal(ArrayName);
85+
if (!GVCtor)
86+
return;
87+
88+
IRBuilder<> IRB(M.getContext());
89+
SmallVector<Constant *, 16> CurrentCtors;
90+
bool Changed = false;
91+
StructType *EltTy =
92+
cast<StructType>(GVCtor->getValueType()->getArrayElementType());
93+
if (Constant *Init = GVCtor->getInitializer()) {
94+
CurrentCtors.reserve(Init->getNumOperands());
95+
for (Value *OP : Init->operands()) {
96+
Constant *C = cast<Constant>(OP);
97+
Constant *NewC = Fn(C);
98+
Changed |= (!NewC || NewC != C);
99+
if (NewC)
100+
CurrentCtors.push_back(NewC);
101+
}
102+
}
103+
if (!Changed)
104+
return;
105+
106+
GVCtor->eraseFromParent();
107+
108+
// Create a new initializer.
109+
ArrayType *AT = ArrayType::get(EltTy, CurrentCtors.size());
110+
Constant *NewInit = ConstantArray::get(AT, CurrentCtors);
111+
112+
// Create the new global variable and replace all uses of
113+
// the old global variable with the new one.
114+
(void)new GlobalVariable(M, NewInit->getType(), false,
115+
GlobalValue::AppendingLinkage, NewInit, ArrayName);
116+
}
117+
118+
void llvm::transformGlobalCtors(Module &M, const GlobalCtorTransformFn &Fn) {
119+
transformGlobalArray("llvm.global_ctors", M, Fn);
120+
}
121+
122+
void llvm::transformGlobalDtors(Module &M, const GlobalCtorTransformFn &Fn) {
123+
transformGlobalArray("llvm.global_dtors", M, Fn);
124+
}
125+
82126
static void collectUsedGlobals(GlobalVariable *GV,
83127
SmallSetVector<Constant *, 16> &Init) {
84128
if (!GV || !GV->hasInitializer())

llvm/unittests/Transforms/Utils/ModuleUtilsTest.cpp

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,23 @@ TEST(ModuleUtils, AppendToUsedList2) {
7070
}
7171

7272
using AppendFnType = decltype(&appendToGlobalCtors);
73-
using ParamType = std::tuple<StringRef, AppendFnType>;
73+
using TransformFnType = decltype(&transformGlobalCtors);
74+
using ParamType = std::tuple<StringRef, AppendFnType, TransformFnType>;
7475
class ModuleUtilsTest : public testing::TestWithParam<ParamType> {
7576
public:
7677
StringRef arrayName() const { return std::get<0>(GetParam()); }
7778
AppendFnType appendFn() const { return std::get<AppendFnType>(GetParam()); }
79+
TransformFnType transformFn() const {
80+
return std::get<TransformFnType>(GetParam());
81+
}
7882
};
7983

8084
INSTANTIATE_TEST_SUITE_P(
8185
ModuleUtilsTestCtors, ModuleUtilsTest,
82-
::testing::Values(ParamType{"llvm.global_ctors", &appendToGlobalCtors},
83-
ParamType{"llvm.global_dtors", &appendToGlobalDtors}));
86+
::testing::Values(ParamType{"llvm.global_ctors", &appendToGlobalCtors,
87+
&transformGlobalCtors},
88+
ParamType{"llvm.global_dtors", &appendToGlobalDtors,
89+
&transformGlobalDtors}));
8490

8591
TEST_P(ModuleUtilsTest, AppendToMissingArray) {
8692
LLVMContext C;
@@ -124,3 +130,39 @@ TEST_P(ModuleUtilsTest, AppendToArray) {
124130
11, nullptr);
125131
EXPECT_EQ(3, getListSize(*M, arrayName()));
126132
}
133+
134+
TEST_P(ModuleUtilsTest, UpdateArray) {
135+
LLVMContext C;
136+
137+
std::unique_ptr<Module> M =
138+
parseIR(C, (R"(@)" + arrayName() +
139+
R"( = appending global [2 x { i32, ptr, ptr }] [
140+
{ i32, ptr, ptr } { i32 65535, ptr null, ptr null },
141+
{ i32, ptr, ptr } { i32 0, ptr null, ptr null }]
142+
)")
143+
.str());
144+
145+
EXPECT_EQ(2, getListSize(*M, arrayName()));
146+
transformFn()(*M, [](Constant *C) -> Constant * {
147+
ConstantStruct *CS = dyn_cast<ConstantStruct>(C);
148+
if (!CS)
149+
return nullptr;
150+
StructType *EltTy = cast<StructType>(C->getType());
151+
Constant *CSVals[3] = {
152+
ConstantInt::getSigned(CS->getOperand(0)->getType(), 12),
153+
CS->getOperand(1),
154+
CS->getOperand(2),
155+
};
156+
return ConstantStruct::get(EltTy,
157+
ArrayRef(CSVals, EltTy->getNumElements()));
158+
});
159+
EXPECT_EQ(1, getListSize(*M, arrayName()));
160+
ConstantArray *CA = dyn_cast<ConstantArray>(
161+
M->getGlobalVariable(arrayName())->getInitializer());
162+
ASSERT_NE(nullptr, CA);
163+
ConstantStruct *CS = dyn_cast<ConstantStruct>(CA->getOperand(0));
164+
ASSERT_NE(nullptr, CS);
165+
ConstantInt *Pri = dyn_cast<ConstantInt>(CS->getOperand(0));
166+
ASSERT_NE(nullptr, Pri);
167+
EXPECT_EQ(12u, Pri->getLimitedValue());
168+
}

0 commit comments

Comments
 (0)