[llvm] ae3e825 - [SandboxIR] Implement GlobalIFunc (#108622)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Sep 13 13:03:13 PDT 2024
Author: vporpo
Date: 2024-09-13T13:03:10-07:00
New Revision: ae3e82585e61eca1ee6100e81cde68b608faf0a8
URL: https://github.com/llvm/llvm-project/commit/ae3e82585e61eca1ee6100e81cde68b608faf0a8
DIFF: https://github.com/llvm/llvm-project/commit/ae3e82585e61eca1ee6100e81cde68b608faf0a8.diff
LOG: [SandboxIR] Implement GlobalIFunc (#108622)
This patch implements sandboxir::GlobalIFunc mirroring
llvm::GlobalIFunc.
Added:
Modified:
llvm/include/llvm/SandboxIR/SandboxIR.h
llvm/lib/SandboxIR/SandboxIR.cpp
llvm/unittests/SandboxIR/SandboxIRTest.cpp
llvm/unittests/SandboxIR/TrackerTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 24c34466b4415e..624309def4df9f 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -128,6 +128,7 @@ class DSOLocalEquivalent;
class ConstantTokenNone;
class GlobalValue;
class GlobalObject;
+class GlobalIFunc;
class Context;
class Function;
class Instruction;
@@ -332,6 +333,7 @@ class Value {
friend class GlobalValue; // For `Val`.
friend class DSOLocalEquivalent; // For `Val`.
friend class GlobalObject; // For `Val`.
+ friend class GlobalIFunc; // For `Val`.
/// All values point to the context.
Context &Ctx;
@@ -1128,6 +1130,7 @@ class GlobalValue : public Constant {
friend class Context; // For constructor.
public:
+ using LinkageTypes = llvm::GlobalValue::LinkageTypes;
/// For isa/dyn_cast.
static bool classof(const sandboxir::Value *From) {
switch (From->getSubclassID()) {
@@ -1285,6 +1288,88 @@ class GlobalObject : public GlobalValue {
}
};
+/// Provides API functions, like getIterator() and getReverseIterator() to
+/// GlobalIFunc, Function, GlobalVariable and GlobalAlias. In LLVM IR these are
+/// provided by ilist_node.
+///
+/// \tparam GlobalT     the sandboxir class that derives from this helper.
+/// \tparam LLVMGlobalT the LLVM IR class wrapped by GlobalT.
+/// \tparam ParentT     the sandboxir base class of GlobalT.
+/// \tparam LLVMParentT the LLVM IR class accepted by ParentT's constructor.
+template <typename GlobalT, typename LLVMGlobalT, typename ParentT,
+          typename LLVMParentT>
+class GlobalWithNodeAPI : public ParentT {
+  /// Helper for mapped_iterator: maps an LLVM global to its sandboxir
+  /// counterpart in \p Ctx. The call operator is defined out of line in
+  /// SandboxIR.cpp.
+  struct LLVMGVToGV {
+    Context &Ctx;
+    // Explicit: a Context must never silently convert into a mapping functor.
+    explicit LLVMGVToGV(Context &Ctx) : Ctx(Ctx) {}
+    GlobalT &operator()(LLVMGlobalT &LLVMGV) const;
+  };
+
+public:
+  GlobalWithNodeAPI(Value::ClassID ID, LLVMParentT *C, Context &Ctx)
+      : ParentT(ID, C, Ctx) {}
+
+  // TODO: Missing getParent(). Should be added once Module is available.
+
+  /// Iterates over the same list as LLVMGlobalT::getIterator(), yielding the
+  /// sandboxir GlobalT for each element.
+  using iterator = mapped_iterator<
+      decltype(static_cast<LLVMGlobalT *>(nullptr)->getIterator()), LLVMGVToGV>;
+  /// Reverse counterpart of `iterator`.
+  using reverse_iterator = mapped_iterator<
+      decltype(static_cast<LLVMGlobalT *>(nullptr)->getReverseIterator()),
+      LLVMGVToGV>;
+
+  /// \Returns an iterator positioned at this global, built from the wrapped
+  /// LLVM global's getIterator().
+  iterator getIterator() const {
+    auto *LLVMGV = cast<LLVMGlobalT>(this->Val);
+    LLVMGVToGV ToGV(this->Ctx);
+    return map_iterator(LLVMGV->getIterator(), ToGV);
+  }
+  /// \Returns a reverse iterator positioned at this global, built from the
+  /// wrapped LLVM global's getReverseIterator().
+  reverse_iterator getReverseIterator() const {
+    auto *LLVMGV = cast<LLVMGlobalT>(this->Val);
+    LLVMGVToGV ToGV(this->Ctx);
+    return map_iterator(LLVMGV->getReverseIterator(), ToGV);
+  }
+};
+
+/// SandboxIR counterpart of llvm::GlobalIFunc. Accessors forward to the
+/// wrapped llvm::GlobalIFunc; setResolver() is recorded by the Context's
+/// tracker so that it can be reverted.
+class GlobalIFunc final
+    : public GlobalWithNodeAPI<GlobalIFunc, llvm::GlobalIFunc, GlobalObject,
+                               llvm::GlobalObject> {
+  /// Only the Context creates GlobalIFunc objects, when it first encounters an
+  /// llvm::GlobalIFunc. Taking llvm::GlobalIFunc (rather than any
+  /// llvm::GlobalObject) rejects mismatched wrappers at compile time, matching
+  /// the way Function's constructor takes an llvm::Function.
+  GlobalIFunc(llvm::GlobalIFunc *GIF, Context &Ctx)
+      : GlobalWithNodeAPI(ClassID::GlobalIFunc, GIF, Ctx) {}
+  friend class Context; // For constructor.
+
+public:
+  /// For isa/dyn_cast.
+  static bool classof(const sandboxir::Value *From) {
+    return From->getSubclassID() == ClassID::GlobalIFunc;
+  }
+
+  // TODO: Missing create() because we don't have a sandboxir::Module yet.
+
+  // TODO: Missing functions: copyAttributesFrom(), removeFromParent(),
+  // eraseFromParent()
+
+  /// Replaces the resolver of this IFunc with \p Resolver.
+  void setResolver(Constant *Resolver);
+
+  /// \Returns the resolver constant of this IFunc.
+  Constant *getResolver() const;
+
+  /// \Returns the resolver function after peeling off potential ConstantExpr
+  /// indirection. Mirrors llvm::GlobalIFunc::getResolverFunction().
+  Function *getResolverFunction();
+  const Function *getResolverFunction() const {
+    return const_cast<GlobalIFunc *>(this)->getResolverFunction();
+  }
+
+  /// \Returns true if \p L is a valid linkage for an IFunc; forwards to
+  /// llvm::GlobalIFunc::isValidLinkage().
+  static bool isValidLinkage(LinkageTypes L) {
+    return llvm::GlobalIFunc::isValidLinkage(L);
+  }
+
+  // TODO: Missing applyAlongResolverPath().
+
+#ifndef NDEBUG
+  void verify() const override {
+    assert(isa<llvm::GlobalIFunc>(Val) && "Expected a GlobalIFunc!");
+  }
+  void dumpOS(raw_ostream &OS) const override {
+    dumpCommonPrefix(OS);
+    dumpCommonSuffix(OS);
+  }
+#endif
+};
+
class BlockAddress final : public Constant {
BlockAddress(llvm::BlockAddress *C, Context &Ctx)
: Constant(ClassID::BlockAddress, C, Ctx) {}
@@ -4219,7 +4304,8 @@ class Context {
size_t getNumValues() const { return LLVMValueToValueMap.size(); }
};
-class Function : public GlobalObject {
+class Function : public GlobalWithNodeAPI<Function, llvm::Function,
+ GlobalObject, llvm::GlobalObject> {
/// Helper for mapped_iterator.
struct LLVMBBToBB {
Context &Ctx;
@@ -4230,7 +4316,7 @@ class Function : public GlobalObject {
};
/// Use Context::createFunction() instead.
Function(llvm::Function *F, sandboxir::Context &Ctx)
- : GlobalObject(ClassID::Function, F, Ctx) {}
+ : GlobalWithNodeAPI(ClassID::Function, F, Ctx) {}
friend class Context; // For constructor.
public:
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 2f20fd3ff1dcc9..03d3e9e607f01a 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2519,6 +2519,39 @@ void GlobalObject::setSection(StringRef S) {
cast<llvm::GlobalObject>(Val)->setSection(S);
}
+/// Maps \p LLVMGV to the sandboxir global that wraps it in this Context.
+/// Context::getValue() only looks up existing values and does not create
+/// them, so a global that was never mapped into this Context would come back
+/// null. Catch that in debug builds instead of dereferencing it.
+template <typename GlobalT, typename LLVMGlobalT, typename ParentT,
+          typename LLVMParentT>
+GlobalT &GlobalWithNodeAPI<GlobalT, LLVMGlobalT, ParentT, LLVMParentT>::
+    LLVMGVToGV::operator()(LLVMGlobalT &LLVMGV) const {
+  auto *GV = Ctx.getValue(&LLVMGV);
+  assert(GV != nullptr &&
+         "Iterated to a global with no sandboxir::Value in this Context!");
+  return cast<GlobalT>(*GV);
+}
+
+namespace llvm::sandboxir {
+// LLVMGVToGV::operator() is defined out of line in this file, so each
+// GlobalWithNodeAPI specialization used by SandboxIR.h is instantiated here.
+template class GlobalWithNodeAPI<Function, llvm::Function, GlobalObject,
+                                 llvm::GlobalObject>;
+template class GlobalWithNodeAPI<GlobalIFunc, llvm::GlobalIFunc, GlobalObject,
+                                 llvm::GlobalObject>;
+} // namespace llvm::sandboxir
+
+void GlobalIFunc::setResolver(Constant *Resolver) {
+  // Record the current resolver (via getResolver()) before changing it, so
+  // the tracker can restore it on revert.
+  using ResolverSetter =
+      GenericSetter<&GlobalIFunc::getResolver, &GlobalIFunc::setResolver>;
+  auto &Trk = Ctx.getTracker();
+  Trk.emplaceIfTracking<ResolverSetter>(this);
+  auto *LLVMIFunc = cast<llvm::GlobalIFunc>(Val);
+  auto *LLVMResolver = cast<llvm::Constant>(Resolver->Val);
+  LLVMIFunc->setResolver(LLVMResolver);
+}
+
+Constant *GlobalIFunc::getResolver() const {
+  // Fetch the LLVM resolver and return its sandboxir wrapper, creating it on
+  // first use.
+  llvm::Constant *LLVMResolver = cast<llvm::GlobalIFunc>(Val)->getResolver();
+  return Ctx.getOrCreateConstant(LLVMResolver);
+}
+
+Function *GlobalIFunc::getResolverFunction() {
+  // llvm::GlobalIFunc::getResolverFunction() may return nullptr when the
+  // resolver, after stripping pointer casts and aliases, is not a Function
+  // (e.g. after setResolver() with some other constant). Propagate that
+  // instead of handing nullptr to getOrCreateConstant() and cast<>.
+  llvm::Function *LLVMResolverFn =
+      cast<llvm::GlobalIFunc>(Val)->getResolverFunction();
+  if (LLVMResolverFn == nullptr)
+    return nullptr;
+  return cast<Function>(Ctx.getOrCreateConstant(LLVMResolverFn));
+}
+
void GlobalValue::setUnnamedAddr(UnnamedAddr V) {
Ctx.getTracker()
.emplaceIfTracking<GenericSetter<&GlobalValue::getUnnamedAddr,
@@ -2727,6 +2760,10 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<Function>(
new Function(cast<llvm::Function>(C), *this));
break;
+ case llvm::Value::GlobalIFuncVal:
+ It->second = std::unique_ptr<GlobalIFunc>(
+ new GlobalIFunc(cast<llvm::GlobalIFunc>(C), *this));
+ break;
default:
It->second = std::unique_ptr<Constant>(new Constant(C, *this));
break;
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index b1f3a6c0cf550a..3b80dbd8fb66e8 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -859,6 +859,84 @@ define void @foo() {
EXPECT_EQ(GO->canIncreaseAlignment(), LLVMGO->canIncreaseAlignment());
}
+TEST_F(SandboxIRTest, GlobalIFunc) {
+  parseIR(C, R"IR(
+declare external void @bar()
+@ifunc0 = ifunc void(), ptr @foo
+@ifunc1 = ifunc void(), ptr @foo
+define void @foo() {
+  call void @ifunc0()
+  call void @ifunc1()
+  call void @bar()
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  auto *LLVMBB = &*LLVMF.begin();
+  auto LLVMIt = LLVMBB->begin();
+  auto *LLVMCall0 = cast<llvm::CallInst>(&*LLVMIt++);
+  auto *LLVMIFunc0 = cast<llvm::GlobalIFunc>(LLVMCall0->getCalledOperand());
+
+  sandboxir::Context Ctx(C);
+
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto *BB = &*F.begin();
+  auto It = BB->begin();
+  auto *Call0 = cast<sandboxir::CallInst>(&*It++);
+  auto *Call1 = cast<sandboxir::CallInst>(&*It++);
+  auto *CallBar = cast<sandboxir::CallInst>(&*It++);
+  // Check classof(), creation.
+  auto *IFunc0 = cast<sandboxir::GlobalIFunc>(Call0->getCalledOperand());
+  auto *IFunc1 = cast<sandboxir::GlobalIFunc>(Call1->getCalledOperand());
+  auto *Bar = cast<sandboxir::Function>(CallBar->getCalledOperand());
+  EXPECT_FALSE(isa<sandboxir::GlobalIFunc>(Bar));
+
+  // Check getIterator().
+  {
+    auto It0 = IFunc0->getIterator();
+    auto It1 = IFunc1->getIterator();
+    EXPECT_EQ(&*It0, IFunc0);
+    EXPECT_EQ(&*It1, IFunc1);
+    EXPECT_EQ(std::next(It0), It1);
+    EXPECT_EQ(std::prev(It1), It0);
+    EXPECT_EQ(&*std::next(It0), IFunc1);
+    EXPECT_EQ(&*std::prev(It1), IFunc0);
+  }
+  // Check getReverseIterator().
+  {
+    auto RevIt0 = IFunc0->getReverseIterator();
+    auto RevIt1 = IFunc1->getReverseIterator();
+    EXPECT_EQ(&*RevIt0, IFunc0);
+    EXPECT_EQ(&*RevIt1, IFunc1);
+    EXPECT_EQ(std::prev(RevIt0), RevIt1);
+    EXPECT_EQ(std::next(RevIt1), RevIt0);
+    EXPECT_EQ(&*std::prev(RevIt0), IFunc1);
+    EXPECT_EQ(&*std::next(RevIt1), IFunc0);
+  }
+
+  // Check setResolver(), getResolver().
+  EXPECT_EQ(IFunc0->getResolver(), Ctx.getValue(LLVMIFunc0->getResolver()));
+  auto *OrigResolver = IFunc0->getResolver();
+  auto *NewResolver = Bar;
+  EXPECT_NE(NewResolver, OrigResolver);
+  IFunc0->setResolver(NewResolver);
+  EXPECT_EQ(IFunc0->getResolver(), NewResolver);
+  IFunc0->setResolver(OrigResolver);
+  EXPECT_EQ(IFunc0->getResolver(), OrigResolver);
+  // Check getResolverFunction(): the resolver of @ifunc0 is @foo itself.
+  EXPECT_EQ(IFunc0->getResolverFunction(),
+            Ctx.getValue(LLVMIFunc0->getResolverFunction()));
+  EXPECT_EQ(IFunc0->getResolverFunction(), &F);
+  // Check isValidLinkage().
+  for (auto L :
+       {GlobalValue::ExternalLinkage, GlobalValue::AvailableExternallyLinkage,
+        GlobalValue::LinkOnceAnyLinkage, GlobalValue::LinkOnceODRLinkage,
+        GlobalValue::WeakAnyLinkage, GlobalValue::WeakODRLinkage,
+        GlobalValue::AppendingLinkage, GlobalValue::InternalLinkage,
+        GlobalValue::PrivateLinkage, GlobalValue::ExternalWeakLinkage,
+        GlobalValue::CommonLinkage}) {
+    EXPECT_EQ(IFunc0->isValidLinkage(L), LLVMIFunc0->isValidLinkage(L));
+  }
+}
+
TEST_F(SandboxIRTest, BlockAddress) {
parseIR(C, R"IR(
define void @foo(ptr %ptr) {
@@ -1200,29 +1278,58 @@ define void @foo(i8 %v) {
TEST_F(SandboxIRTest, Function) {
parseIR(C, R"IR(
-define void @foo(i32 %arg0, i32 %arg1) {
+define void @foo0(i32 %arg0, i32 %arg1) {
bb0:
br label %bb1
bb1:
ret void
}
+define void @foo1() {
+ ret void
+}
+
)IR");
- llvm::Function *LLVMF = &*M->getFunction("foo");
- llvm::Argument *LLVMArg0 = LLVMF->getArg(0);
- llvm::Argument *LLVMArg1 = LLVMF->getArg(1);
+ llvm::Function *LLVMF0 = &*M->getFunction("foo0");
+ llvm::Function *LLVMF1 = &*M->getFunction("foo1");
+ llvm::Argument *LLVMArg0 = LLVMF0->getArg(0);
+ llvm::Argument *LLVMArg1 = LLVMF0->getArg(1);
sandboxir::Context Ctx(C);
- sandboxir::Function *F = Ctx.createFunction(LLVMF);
+ sandboxir::Function *F0 = Ctx.createFunction(LLVMF0);
+ sandboxir::Function *F1 = Ctx.createFunction(LLVMF1);
+
+ // Check getIterator().
+ {
+ auto It0 = F0->getIterator();
+ auto It1 = F1->getIterator();
+ EXPECT_EQ(&*It0, F0);
+ EXPECT_EQ(&*It1, F1);
+ EXPECT_EQ(std::next(It0), It1);
+ EXPECT_EQ(std::prev(It1), It0);
+ EXPECT_EQ(&*std::next(It0), F1);
+ EXPECT_EQ(&*std::prev(It1), F0);
+ }
+ // Check getReverseIterator().
+ {
+ auto RevIt0 = F0->getReverseIterator();
+ auto RevIt1 = F1->getReverseIterator();
+ EXPECT_EQ(&*RevIt0, F0);
+ EXPECT_EQ(&*RevIt1, F1);
+ EXPECT_EQ(std::prev(RevIt0), RevIt1);
+ EXPECT_EQ(std::next(RevIt1), RevIt0);
+ EXPECT_EQ(&*std::prev(RevIt0), F1);
+ EXPECT_EQ(&*std::next(RevIt1), F0);
+ }
// Check F arguments
- EXPECT_EQ(F->arg_size(), 2u);
- EXPECT_FALSE(F->arg_empty());
- EXPECT_EQ(F->getArg(0), Ctx.getValue(LLVMArg0));
- EXPECT_EQ(F->getArg(1), Ctx.getValue(LLVMArg1));
+ EXPECT_EQ(F0->arg_size(), 2u);
+ EXPECT_FALSE(F0->arg_empty());
+ EXPECT_EQ(F0->getArg(0), Ctx.getValue(LLVMArg0));
+ EXPECT_EQ(F0->getArg(1), Ctx.getValue(LLVMArg1));
// Check F.begin(), F.end(), Function::iterator
- llvm::BasicBlock *LLVMBB = &*LLVMF->begin();
- for (sandboxir::BasicBlock &BB : *F) {
+ llvm::BasicBlock *LLVMBB = &*LLVMF0->begin();
+ for (sandboxir::BasicBlock &BB : *F0) {
EXPECT_EQ(&BB, Ctx.getValue(LLVMBB));
LLVMBB = LLVMBB->getNextNode();
}
@@ -1232,17 +1339,17 @@ define void @foo(i32 %arg0, i32 %arg1) {
// Check F.dumpNameAndArgs()
std::string Buff;
raw_string_ostream BS(Buff);
- F->dumpNameAndArgs(BS);
- EXPECT_EQ(Buff, "void @foo(i32 %arg0, i32 %arg1)");
+ F0->dumpNameAndArgs(BS);
+ EXPECT_EQ(Buff, "void @foo0(i32 %arg0, i32 %arg1)");
}
{
// Check F.dump()
std::string Buff;
raw_string_ostream BS(Buff);
BS << "\n";
- F->dumpOS(BS);
+ F0->dumpOS(BS);
EXPECT_EQ(Buff, R"IR(
-void @foo(i32 %arg0, i32 %arg1) {
+void @foo0(i32 %arg0, i32 %arg1) {
bb0:
br label %bb1 ; SB4. (Br)
diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index 6454c54336e6aa..d4ff4fd6464e5c 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -1558,6 +1558,38 @@ define void @foo() {
EXPECT_EQ(GV->getVisibility(), OrigVisibility);
}
+TEST_F(TrackerTest, GlobalIFuncSetters) {
+  parseIR(C, R"IR(
+declare external void @bar()
+@ifunc = ifunc void(), ptr @foo
+define void @foo() {
+  call void @ifunc()
+  call void @bar()
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto *BB = &*F.begin();
+  auto It = BB->begin();
+  auto *Call0 = cast<sandboxir::CallInst>(&*It++);
+  auto *Call1 = cast<sandboxir::CallInst>(&*It++);
+  // Check classof(), creation.
+  auto *IFunc = cast<sandboxir::GlobalIFunc>(Call0->getCalledOperand());
+  auto *Bar = cast<sandboxir::Function>(Call1->getCalledOperand());
+  // Check that setResolver() is tracked and reverted.
+  auto *OrigResolver = IFunc->getResolver();
+  auto *NewResolver = Bar;
+  EXPECT_NE(NewResolver, OrigResolver);
+  Ctx.save();
+  IFunc->setResolver(NewResolver);
+  EXPECT_EQ(IFunc->getResolver(), NewResolver);
+  Ctx.revert();
+  EXPECT_EQ(IFunc->getResolver(), OrigResolver);
+}
+
TEST_F(TrackerTest, SetVolatile) {
parseIR(C, R"IR(
define void @foo(ptr %arg0, i8 %val) {
More information about the llvm-commits
mailing list