[clang] [llvm] [OpenMP][clang] Register Vtables on device for indirect calls (PR #159856)

Joseph Huber via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 19 15:47:12 PDT 2025


================
@@ -112,13 +112,39 @@ setupIndirectCallTable(DeviceTy &Device, __tgt_device_image *Image,
   llvm::SmallVector<std::pair<void *, void *>> IndirectCallTable;
   for (const auto &Entry : Entries) {
     if (Entry.Kind != llvm::object::OffloadKind::OFK_OpenMP ||
-        Entry.Size == 0 || !(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT))
+        Entry.Size == 0 ||
+        !(Entry.Flags &
+          (OMP_DECLARE_TARGET_INDIRECT | OMP_DECLARE_TARGET_INDIRECT_VTABLE)))
       continue;
 
-    assert(Entry.Size == sizeof(void *) && "Global not a function pointer?");
-    auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back();
-
-    void *Ptr;
+    size_t PtrSize = sizeof(void *);
+    if (Entry.Flags & OMP_DECLARE_TARGET_INDIRECT_VTABLE) {
+      // This is a VTable entry, the current entry is the first index of the
+      // VTable and Entry.Size is the total size of the VTable. Unlike the
+      // indirect function case below, the Global is not of size Entry.Size and
+      // is instead of size PtrSize (sizeof(void*)).
+      void *Vtable;
+      void *res;
+      if (Device.RTL->get_global(Binary, PtrSize, Entry.SymbolName, &Vtable))
+        return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
+                                         "failed to load %s", Entry.SymbolName);
+
+      // HstPtr = Entry.Address;
+      if (Device.retrieveData(&res, Vtable, PtrSize, AsyncInfo))
+        return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
+                                         "failed to load %s", Entry.SymbolName);
+      // Calculate and emplace entire Vtable from first Vtable byte
+      for (uint64_t i = 0; i < Entry.Size / PtrSize; ++i) {
+        auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back();
+        HstPtr = (void *)((uintptr_t)Entry.Address + i * PtrSize);
+        DevPtr = (void *)((uintptr_t)res + i * PtrSize);
----------------
jhuber6 wrote:

Please use C++-style casts (e.g. `reinterpret_cast`) here instead of C-style casts.

https://github.com/llvm/llvm-project/pull/159856


More information about the llvm-commits mailing list