[llvm] ae3e825 - [SandboxIR] Implement GlobalIFunc (#108622)

via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 13 13:03:13 PDT 2024


Author: vporpo
Date: 2024-09-13T13:03:10-07:00
New Revision: ae3e82585e61eca1ee6100e81cde68b608faf0a8

URL: https://github.com/llvm/llvm-project/commit/ae3e82585e61eca1ee6100e81cde68b608faf0a8
DIFF: https://github.com/llvm/llvm-project/commit/ae3e82585e61eca1ee6100e81cde68b608faf0a8.diff

LOG: [SandboxIR] Implement GlobalIFunc (#108622)

This patch implements sandboxir::GlobalIFunc mirroring
llvm::GlobalIFunc.

Added: 
    

Modified: 
    llvm/include/llvm/SandboxIR/SandboxIR.h
    llvm/lib/SandboxIR/SandboxIR.cpp
    llvm/unittests/SandboxIR/SandboxIRTest.cpp
    llvm/unittests/SandboxIR/TrackerTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 24c34466b4415e..624309def4df9f 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -128,6 +128,7 @@ class DSOLocalEquivalent;
 class ConstantTokenNone;
 class GlobalValue;
 class GlobalObject;
+class GlobalIFunc;
 class Context;
 class Function;
 class Instruction;
@@ -332,6 +333,7 @@ class Value {
   friend class GlobalValue;           // For `Val`.
   friend class DSOLocalEquivalent;    // For `Val`.
   friend class GlobalObject;          // For `Val`.
+  friend class GlobalIFunc;           // For `Val`.
 
   /// All values point to the context.
   Context &Ctx;
@@ -1128,6 +1130,7 @@ class GlobalValue : public Constant {
   friend class Context; // For constructor.
 
 public:
+  using LinkageTypes = llvm::GlobalValue::LinkageTypes;
   /// For isa/dyn_cast.
   static bool classof(const sandboxir::Value *From) {
     switch (From->getSubclassID()) {
@@ -1285,6 +1288,88 @@ class GlobalObject : public GlobalValue {
   }
 };
 
+/// Provides API functions, like getIterator() and getReverseIterator(), to
+/// GlobalIFunc, Function, GlobalVariable and GlobalAlias. In LLVM IR these are
+/// provided by ilist_node; here they wrap the LLVM iterators.
+template <typename GlobalT, typename LLVMGlobalT, typename ParentT,
+          typename LLVMParentT>
+class GlobalWithNodeAPI : public ParentT {
+  /// mapped_iterator functor: LLVM global -> sandboxir global.
+  struct LLVMGVToGV {
+    Context &Ctx;
+    explicit LLVMGVToGV(Context &Ctx) : Ctx(Ctx) {}
+    GlobalT &operator()(LLVMGlobalT &LLVMGV) const;
+  };
+
+public:
+  GlobalWithNodeAPI(Value::ClassID ID, LLVMParentT *C, Context &Ctx)
+      : ParentT(ID, C, Ctx) {}
+
+  // TODO: Missing getParent(). Should be added once Module is available.
+
+  using iterator = mapped_iterator<
+      decltype(static_cast<LLVMGlobalT *>(nullptr)->getIterator()), LLVMGVToGV>;
+  using reverse_iterator = mapped_iterator<
+      decltype(static_cast<LLVMGlobalT *>(nullptr)->getReverseIterator()),
+      LLVMGVToGV>;
+  iterator getIterator() const {
+    auto *LLVMGV = cast<LLVMGlobalT>(this->Val);
+    LLVMGVToGV ToGV(this->Ctx);
+    return map_iterator(LLVMGV->getIterator(), ToGV);
+  }
+  reverse_iterator getReverseIterator() const {
+    auto *LLVMGV = cast<LLVMGlobalT>(this->Val);
+    LLVMGVToGV ToGV(this->Ctx);
+    return map_iterator(LLVMGV->getReverseIterator(), ToGV);
+  }
+};
+
+class GlobalIFunc final
+    : public GlobalWithNodeAPI<GlobalIFunc, llvm::GlobalIFunc, GlobalObject,
+                               llvm::GlobalObject> {
+  GlobalIFunc(llvm::GlobalIFunc *C, Context &Ctx)
+      : GlobalWithNodeAPI(ClassID::GlobalIFunc, C, Ctx) {}
+  friend class Context; // For constructor.
+
+public:
+  /// For isa/dyn_cast.
+  static bool classof(const sandboxir::Value *From) {
+    return From->getSubclassID() == ClassID::GlobalIFunc;
+  }
+
+  // TODO: Missing create() because we don't have a sandboxir::Module yet.
+
+  // TODO: Missing copyAttributesFrom(), removeFromParent(), eraseFromParent().
+
+  /// Sets the resolver. The change is tracked, so Context::revert() undoes it.
+  void setResolver(Constant *Resolver);
+  /// Returns the current resolver as a sandboxir Constant.
+  Constant *getResolver() const;
+
+  /// Return the resolver function after peeling off potential ConstantExpr
+  /// indirection.
+  Function *getResolverFunction();
+  const Function *getResolverFunction() const {
+    return const_cast<GlobalIFunc *>(this)->getResolverFunction();
+  }
+
+  static bool isValidLinkage(LinkageTypes L) {
+    return llvm::GlobalIFunc::isValidLinkage(L);
+  }
+
+  // TODO: Missing applyAlongResolverPath().
+
+#ifndef NDEBUG
+  void verify() const override {
+    assert(isa<llvm::GlobalIFunc>(Val) && "Expected a GlobalIFunc!");
+  }
+  void dumpOS(raw_ostream &OS) const override {
+    dumpCommonPrefix(OS);
+    dumpCommonSuffix(OS);
+  }
+#endif
+};
+
 class BlockAddress final : public Constant {
   BlockAddress(llvm::BlockAddress *C, Context &Ctx)
       : Constant(ClassID::BlockAddress, C, Ctx) {}
@@ -4219,7 +4304,8 @@ class Context {
   size_t getNumValues() const { return LLVMValueToValueMap.size(); }
 };
 
-class Function : public GlobalObject {
+class Function : public GlobalWithNodeAPI<Function, llvm::Function,
+                                          GlobalObject, llvm::GlobalObject> {
   /// Helper for mapped_iterator.
   struct LLVMBBToBB {
     Context &Ctx;
@@ -4230,7 +4316,7 @@ class Function : public GlobalObject {
   };
   /// Use Context::createFunction() instead.
   Function(llvm::Function *F, sandboxir::Context &Ctx)
-      : GlobalObject(ClassID::Function, F, Ctx) {}
+      : GlobalWithNodeAPI(ClassID::Function, F, Ctx) {}
   friend class Context; // For constructor.
 
 public:

diff  --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 2f20fd3ff1dcc9..03d3e9e607f01a 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2519,6 +2519,39 @@ void GlobalObject::setSection(StringRef S) {
   cast<llvm::GlobalObject>(Val)->setSection(S);
 }
 
+template <typename GlobalT, typename LLVMGlobalT, typename ParentT,
+          typename LLVMParentT>
+GlobalT &GlobalWithNodeAPI<GlobalT, LLVMGlobalT, ParentT, LLVMParentT>::
+    LLVMGVToGV::operator()(LLVMGlobalT &LLVMGV) const {
+  return *cast<GlobalT>(Ctx.getValue(&LLVMGV)); // cast<> asserts on null.
+}
+
+namespace llvm::sandboxir {
+// Explicit instantiations: LLVMGVToGV::operator() is only defined here.
+template class GlobalWithNodeAPI<Function, llvm::Function, GlobalObject,
+                                 llvm::GlobalObject>;
+template class GlobalWithNodeAPI<GlobalIFunc, llvm::GlobalIFunc, GlobalObject,
+                                 llvm::GlobalObject>;
+} // namespace llvm::sandboxir
+
+void GlobalIFunc::setResolver(Constant *Resolver) {
+  // Record an undo entry (if tracking) so revert() restores the resolver.
+  using ResolverSetter =
+      GenericSetter<&GlobalIFunc::getResolver, &GlobalIFunc::setResolver>;
+  Ctx.getTracker().emplaceIfTracking<ResolverSetter>(this);
+  auto *NewLLVMResolver = cast<llvm::Constant>(Resolver->Val);
+  cast<llvm::GlobalIFunc>(Val)->setResolver(NewLLVMResolver);
+}
+
+Constant *GlobalIFunc::getResolver() const {
+  return Ctx.getOrCreateConstant(cast<llvm::GlobalIFunc>(Val)->getResolver());
+}
+// Mirror llvm::GlobalIFunc: return nullptr if the resolver isn't a Function.
+Function *GlobalIFunc::getResolverFunction() {
+  auto *LLVMF = cast<llvm::GlobalIFunc>(Val)->getResolverFunction();
+  return LLVMF ? cast<Function>(Ctx.getOrCreateConstant(LLVMF)) : nullptr;
+}
+
 void GlobalValue::setUnnamedAddr(UnnamedAddr V) {
   Ctx.getTracker()
       .emplaceIfTracking<GenericSetter<&GlobalValue::getUnnamedAddr,
@@ -2727,6 +2760,10 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
       It->second = std::unique_ptr<Function>(
           new Function(cast<llvm::Function>(C), *this));
       break;
+    case llvm::Value::GlobalIFuncVal:
+      It->second = std::unique_ptr<GlobalIFunc>(
+          new GlobalIFunc(cast<llvm::GlobalIFunc>(C), *this));
+      break;
     default:
       It->second = std::unique_ptr<Constant>(new Constant(C, *this));
       break;

diff  --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index b1f3a6c0cf550a..3b80dbd8fb66e8 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -859,6 +859,84 @@ define void @foo() {
   EXPECT_EQ(GO->canIncreaseAlignment(), LLVMGO->canIncreaseAlignment());
 }
 
+TEST_F(SandboxIRTest, GlobalIFunc) {
+  parseIR(C, R"IR(
+declare external void @bar()
+@ifunc0 = ifunc void(), ptr @foo
+@ifunc1 = ifunc void(), ptr @foo
+define void @foo() {
+  call void @ifunc0()
+  call void @ifunc1()
+  call void @bar()
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  auto *LLVMBB = &*LLVMF.begin();
+  auto LLVMIt = LLVMBB->begin();
+  auto *LLVMCall0 = cast<llvm::CallInst>(&*LLVMIt++);
+  auto *LLVMIFunc0 = cast<llvm::GlobalIFunc>(LLVMCall0->getCalledOperand());
+
+  sandboxir::Context Ctx(C);
+
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto *BB = &*F.begin();
+  auto It = BB->begin();
+  auto *Call0 = cast<sandboxir::CallInst>(&*It++);
+  auto *Call1 = cast<sandboxir::CallInst>(&*It++);
+  auto *CallBar = cast<sandboxir::CallInst>(&*It++);
+  // Check classof(), creation.
+  auto *IFunc0 = cast<sandboxir::GlobalIFunc>(Call0->getCalledOperand());
+  auto *IFunc1 = cast<sandboxir::GlobalIFunc>(Call1->getCalledOperand());
+  auto *Bar = cast<sandboxir::Function>(CallBar->getCalledOperand());
+
+  // Check getIterator().
+  {
+    auto It0 = IFunc0->getIterator();
+    auto It1 = IFunc1->getIterator();
+    EXPECT_EQ(&*It0, IFunc0);
+    EXPECT_EQ(&*It1, IFunc1);
+    EXPECT_EQ(std::next(It0), It1);
+    EXPECT_EQ(std::prev(It1), It0);
+    EXPECT_EQ(&*std::next(It0), IFunc1);
+    EXPECT_EQ(&*std::prev(It1), IFunc0);
+  }
+  // Check getReverseIterator().
+  {
+    auto RevIt0 = IFunc0->getReverseIterator();
+    auto RevIt1 = IFunc1->getReverseIterator();
+    EXPECT_EQ(&*RevIt0, IFunc0);
+    EXPECT_EQ(&*RevIt1, IFunc1);
+    EXPECT_EQ(std::prev(RevIt0), RevIt1);
+    EXPECT_EQ(std::next(RevIt1), RevIt0);
+    EXPECT_EQ(&*std::prev(RevIt0), IFunc1);
+    EXPECT_EQ(&*std::next(RevIt1), IFunc0);
+  }
+
+  // Check setResolver(), getResolver().
+  EXPECT_EQ(IFunc0->getResolver(), Ctx.getValue(LLVMIFunc0->getResolver()));
+  auto *OrigResolver = IFunc0->getResolver();
+  auto *NewResolver = Bar;
+  EXPECT_NE(NewResolver, OrigResolver);
+  IFunc0->setResolver(NewResolver);
+  EXPECT_EQ(IFunc0->getResolver(), NewResolver);
+  IFunc0->setResolver(OrigResolver);
+  EXPECT_EQ(IFunc0->getResolver(), OrigResolver);
+  // Check getResolverFunction().
+  EXPECT_EQ(IFunc0->getResolverFunction(),
+            Ctx.getValue(LLVMIFunc0->getResolverFunction()));
+  // Check isValidLinkage().
+  for (auto L :
+       {GlobalValue::ExternalLinkage, GlobalValue::AvailableExternallyLinkage,
+        GlobalValue::LinkOnceAnyLinkage, GlobalValue::LinkOnceODRLinkage,
+        GlobalValue::WeakAnyLinkage, GlobalValue::WeakODRLinkage,
+        GlobalValue::AppendingLinkage, GlobalValue::InternalLinkage,
+        GlobalValue::PrivateLinkage, GlobalValue::ExternalWeakLinkage,
+        GlobalValue::CommonLinkage}) {
+    EXPECT_EQ(IFunc0->isValidLinkage(L), LLVMIFunc0->isValidLinkage(L));
+  }
+}
+
 TEST_F(SandboxIRTest, BlockAddress) {
   parseIR(C, R"IR(
 define void @foo(ptr %ptr) {
@@ -1200,29 +1278,58 @@ define void @foo(i8 %v) {
 
 TEST_F(SandboxIRTest, Function) {
   parseIR(C, R"IR(
-define void @foo(i32 %arg0, i32 %arg1) {
+define void @foo0(i32 %arg0, i32 %arg1) {
 bb0:
   br label %bb1
 bb1:
   ret void
 }
+define void @foo1() {
+  ret void
+}
+
 )IR");
-  llvm::Function *LLVMF = &*M->getFunction("foo");
-  llvm::Argument *LLVMArg0 = LLVMF->getArg(0);
-  llvm::Argument *LLVMArg1 = LLVMF->getArg(1);
+  llvm::Function *LLVMF0 = &*M->getFunction("foo0");
+  llvm::Function *LLVMF1 = &*M->getFunction("foo1");
+  llvm::Argument *LLVMArg0 = LLVMF0->getArg(0);
+  llvm::Argument *LLVMArg1 = LLVMF0->getArg(1);
 
   sandboxir::Context Ctx(C);
-  sandboxir::Function *F = Ctx.createFunction(LLVMF);
+  sandboxir::Function *F0 = Ctx.createFunction(LLVMF0);
+  sandboxir::Function *F1 = Ctx.createFunction(LLVMF1);
+
+  // Check getIterator().
+  {
+    auto It0 = F0->getIterator();
+    auto It1 = F1->getIterator();
+    EXPECT_EQ(&*It0, F0);
+    EXPECT_EQ(&*It1, F1);
+    EXPECT_EQ(std::next(It0), It1);
+    EXPECT_EQ(std::prev(It1), It0);
+    EXPECT_EQ(&*std::next(It0), F1);
+    EXPECT_EQ(&*std::prev(It1), F0);
+  }
+  // Check getReverseIterator().
+  {
+    auto RevIt0 = F0->getReverseIterator();
+    auto RevIt1 = F1->getReverseIterator();
+    EXPECT_EQ(&*RevIt0, F0);
+    EXPECT_EQ(&*RevIt1, F1);
+    EXPECT_EQ(std::prev(RevIt0), RevIt1);
+    EXPECT_EQ(std::next(RevIt1), RevIt0);
+    EXPECT_EQ(&*std::prev(RevIt0), F1);
+    EXPECT_EQ(&*std::next(RevIt1), F0);
+  }
 
   // Check F arguments
-  EXPECT_EQ(F->arg_size(), 2u);
-  EXPECT_FALSE(F->arg_empty());
-  EXPECT_EQ(F->getArg(0), Ctx.getValue(LLVMArg0));
-  EXPECT_EQ(F->getArg(1), Ctx.getValue(LLVMArg1));
+  EXPECT_EQ(F0->arg_size(), 2u);
+  EXPECT_FALSE(F0->arg_empty());
+  EXPECT_EQ(F0->getArg(0), Ctx.getValue(LLVMArg0));
+  EXPECT_EQ(F0->getArg(1), Ctx.getValue(LLVMArg1));
 
   // Check F.begin(), F.end(), Function::iterator
-  llvm::BasicBlock *LLVMBB = &*LLVMF->begin();
-  for (sandboxir::BasicBlock &BB : *F) {
+  llvm::BasicBlock *LLVMBB = &*LLVMF0->begin();
+  for (sandboxir::BasicBlock &BB : *F0) {
     EXPECT_EQ(&BB, Ctx.getValue(LLVMBB));
     LLVMBB = LLVMBB->getNextNode();
   }
@@ -1232,17 +1339,17 @@ define void @foo(i32 %arg0, i32 %arg1) {
     // Check F.dumpNameAndArgs()
     std::string Buff;
     raw_string_ostream BS(Buff);
-    F->dumpNameAndArgs(BS);
-    EXPECT_EQ(Buff, "void @foo(i32 %arg0, i32 %arg1)");
+    F0->dumpNameAndArgs(BS);
+    EXPECT_EQ(Buff, "void @foo0(i32 %arg0, i32 %arg1)");
   }
   {
     // Check F.dump()
     std::string Buff;
     raw_string_ostream BS(Buff);
     BS << "\n";
-    F->dumpOS(BS);
+    F0->dumpOS(BS);
     EXPECT_EQ(Buff, R"IR(
-void @foo(i32 %arg0, i32 %arg1) {
+void @foo0(i32 %arg0, i32 %arg1) {
 bb0:
   br label %bb1 ; SB4. (Br)
 

diff  --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp
index 6454c54336e6aa..d4ff4fd6464e5c 100644
--- a/llvm/unittests/SandboxIR/TrackerTest.cpp
+++ b/llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -1558,6 +1558,38 @@ define void @foo() {
   EXPECT_EQ(GV->getVisibility(), OrigVisibility);
 }
 
+TEST_F(TrackerTest, GlobalIFuncSetters) {
+  parseIR(C, R"IR(
+declare external void @bar()
+@ifunc = ifunc void(), ptr @foo
+define void @foo() {
+  call void @ifunc()
+  call void @bar()
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto *BB = &*F.begin();
+  auto It = BB->begin();
+  auto *Call0 = cast<sandboxir::CallInst>(&*It++);
+  auto *Call1 = cast<sandboxir::CallInst>(&*It++);
+  // Get the sandboxir ifunc and @bar from the callees (checks classof()).
+  auto *IFunc = cast<sandboxir::GlobalIFunc>(Call0->getCalledOperand());
+  auto *Bar = cast<sandboxir::Function>(Call1->getCalledOperand());
+  // Check setResolver().
+  auto *OrigResolver = IFunc->getResolver();
+  auto *NewResolver = Bar;
+  EXPECT_NE(NewResolver, OrigResolver);
+  Ctx.save();
+  IFunc->setResolver(NewResolver);
+  EXPECT_EQ(IFunc->getResolver(), NewResolver);
+  Ctx.revert();
+  EXPECT_EQ(IFunc->getResolver(), OrigResolver);
+}
+
 TEST_F(TrackerTest, SetVolatile) {
   parseIR(C, R"IR(
 define void @foo(ptr %arg0, i8 %val) {


        


More information about the llvm-commits mailing list