[llvm] aa78e94 - [Libomptarget] Support mapping indirect host calls to device functions

Joseph Huber via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 25 16:52:04 PDT 2023


Author: Joseph Huber
Date: 2023-08-25T18:51:56-05:00
New Revision: aa78e94b0bc66375de7f2383b4e39c07cd482104

URL: https://github.com/llvm/llvm-project/commit/aa78e94b0bc66375de7f2383b4e39c07cd482104
DIFF: https://github.com/llvm/llvm-project/commit/aa78e94b0bc66375de7f2383b4e39c07cd482104.diff

LOG: [Libomptarget] Support mapping indirect host calls to device functions

The changes in D157738 allowed for us to emit stub globals on the device
in the offloading entry section. These globals contain the addresses of
device functions and allow us to map host functions to their
corresponding device equivalent. This patch provides the initial support
required to build a table on the device to lookup the associated value.
This is done by finding these entries and creating a global table on the
device that can be searched with a simple binary search.

This requires a device allocation for the table. It is not freed explicitly
and is expected to be released automatically at plugin shutdown. This includes
a basic test which looks up device pointers via a host pointer using the added
function. This will need to be built upon to provide full support for these
calls in the runtime.

To support reverse offloading it would also be useful to provide a reverse table
that allows us to get host functions from device stubs.

Depends on D157738

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D157918

Added: 
    openmp/libomptarget/test/api/omp_indirect_call.c

Modified: 
    llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
    openmp/libomptarget/DeviceRTL/include/Configuration.h
    openmp/libomptarget/DeviceRTL/src/Configuration.cpp
    openmp/libomptarget/DeviceRTL/src/Misc.cpp
    openmp/libomptarget/include/Environment.h
    openmp/libomptarget/include/omptarget.h
    openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
    openmp/libomptarget/src/rtl.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 3641ce0dbe652c..ea1035f1907e49 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -327,7 +327,7 @@ class OffloadEntriesInfoManager {
     /// Mark the entry as having no declare target entry kind.
     OMPTargetGlobalVarEntryNone = 0x3,
     /// Mark the entry as a declare target indirect global.
-    OMPTargetGlobalVarEntryIndirect = 0x4,
+    OMPTargetGlobalVarEntryIndirect = 0x8,
   };
 
   /// Kind of device clause for declare target variables

diff  --git a/openmp/libomptarget/DeviceRTL/include/Configuration.h b/openmp/libomptarget/DeviceRTL/include/Configuration.h
index 068c0166845a74..508e2a55bd8e21 100644
--- a/openmp/libomptarget/DeviceRTL/include/Configuration.h
+++ b/openmp/libomptarget/DeviceRTL/include/Configuration.h
@@ -40,6 +40,12 @@ uint64_t getDynamicMemorySize();
 /// Returns the cycles per second of the device's fixed frequency clock.
 uint64_t getClockFrequency();
 
+/// Returns the device address of the indirect call table, or null if empty.
+void *getIndirectCallTablePtr();
+
+/// Returns the number of {host, device} pointer pairs in the table (not bytes).
+uint64_t getIndirectCallTableSize();
+
 /// Return if debugging is enabled for the given debug kind.
 bool isDebugMode(DebugKind Level);
 

diff  --git a/openmp/libomptarget/DeviceRTL/src/Configuration.cpp b/openmp/libomptarget/DeviceRTL/src/Configuration.cpp
index 82c1e0ba096f0f..da1e252fc07693 100644
--- a/openmp/libomptarget/DeviceRTL/src/Configuration.cpp
+++ b/openmp/libomptarget/DeviceRTL/src/Configuration.cpp
@@ -50,6 +50,15 @@ uint64_t config::getClockFrequency() {
   return __omp_rtl_device_environment.ClockFrequency;
 }
 
+void *config::getIndirectCallTablePtr() {
+  const auto TableAddr = __omp_rtl_device_environment.IndirectCallTable;
+  return reinterpret_cast<void *>(TableAddr); // Device address of the table.
+}
+
+uint64_t config::getIndirectCallTableSize() {
+  return __omp_rtl_device_environment.IndirectCallTableSize; // Entry count.
+}
+
 bool config::isDebugMode(config::DebugKind Kind) {
   return config::getDebugKind() & Kind;
 }

diff  --git a/openmp/libomptarget/DeviceRTL/src/Misc.cpp b/openmp/libomptarget/DeviceRTL/src/Misc.cpp
index 0c361fe61061de..87d568779b401e 100644
--- a/openmp/libomptarget/DeviceRTL/src/Misc.cpp
+++ b/openmp/libomptarget/DeviceRTL/src/Misc.cpp
@@ -69,6 +69,47 @@ double getWTime() {
 
 #pragma omp end declare variant
 
+/// Lookup a device-side function using a host pointer \p HstPtr using the table
+/// provided by the device plugin. The table is an array of {host, device}
+/// pointer pairs sorted by host pointer; unmatched pointers are returned as-is.
+void *indirectCallLookup(void *HstPtr) {
+  if (!HstPtr)
+    return nullptr;
+
+  struct IndirectCallTable {
+    void *HstPtr;
+    void *DevPtr;
+  };
+  IndirectCallTable *Table =
+      reinterpret_cast<IndirectCallTable *>(config::getIndirectCallTablePtr());
+  uint64_t TableSize = config::getIndirectCallTableSize();
+
+  // If the table is empty we assume this is a device pointer.
+  if (!Table || !TableSize)
+    return HstPtr;
+
+  uint64_t Left = 0;
+  uint64_t Right = TableSize;
+
+  // If the pointer is definitely not contained in the table we exit early.
+  if (HstPtr < Table[Left].HstPtr || HstPtr > Table[Right - 1].HstPtr)
+    return HstPtr;
+
+  while (Left < Right) { // Search the half-open range [Left, Right).
+    uint64_t Current = Left + (Right - Left) / 2;
+    if (Table[Current].HstPtr == HstPtr)
+      return Table[Current].DevPtr;
+
+    if (HstPtr < Table[Current].HstPtr)
+      Right = Current;
+    else
+      Left = Current + 1; // Current was already checked; always make progress.
+  }
+
+  // If we searched the whole table and found nothing this is a device pointer.
+  return HstPtr;
+}
+
 } // namespace impl
 } // namespace ompx
 
@@ -84,6 +125,10 @@ int32_t __kmpc_cancel(IdentTy *, int32_t, int32_t) { return 0; }
 double omp_get_wtick(void) { return ompx::impl::getWTick(); }
 
 double omp_get_wtime(void) { return ompx::impl::getWTime(); }
+
+void *__llvm_omp_indirect_call_lookup(void *HstPtr) {
+  return ompx::impl::indirectCallLookup(HstPtr); // Falls back to HstPtr.
+}
 }
 
 ///}

diff  --git a/openmp/libomptarget/include/Environment.h b/openmp/libomptarget/include/Environment.h
index 094ad107461c93..2d291c4505a1fe 100644
--- a/openmp/libomptarget/include/Environment.h
+++ b/openmp/libomptarget/include/Environment.h
@@ -31,6 +31,8 @@ struct DeviceEnvironmentTy {
   uint32_t DeviceNum;
   uint32_t DynamicMemSize;
   uint64_t ClockFrequency;
+  uintptr_t IndirectCallTable;
+  uint64_t IndirectCallTableSize;
 };
 
 // NOTE: Please don't change the order of those members as their indices are

diff  --git a/openmp/libomptarget/include/omptarget.h b/openmp/libomptarget/include/omptarget.h
index f05c4015da5f2c..f87557a69eff27 100644
--- a/openmp/libomptarget/include/omptarget.h
+++ b/openmp/libomptarget/include/omptarget.h
@@ -83,13 +83,16 @@ enum tgt_map_type {
   OMP_TGT_MAPTYPE_MEMBER_OF       = 0xffff000000000000
 };
 
+/// Flags for offload entries; size-0 entries are kernels, others are globals.
 enum OpenMPOffloadingDeclareTargetFlags {
-  /// Mark the entry as having a 'link' attribute.
+  /// Mark the entry global as having a 'link' attribute.
   OMP_DECLARE_TARGET_LINK = 0x01,
-  /// Mark the entry as being a global constructor.
+  /// Mark the entry kernel as being a global constructor.
   OMP_DECLARE_TARGET_CTOR = 0x02,
-  /// Mark the entry as being a global destructor.
-  OMP_DECLARE_TARGET_DTOR = 0x04
+  /// Mark the entry kernel as being a global destructor.
+  OMP_DECLARE_TARGET_DTOR = 0x04,
+  /// Indirectly callable function global; keep in sync with OMPIRBuilder.
+  OMP_DECLARE_TARGET_INDIRECT = 0x08
 };
 
 enum OpenMPOffloadingRequiresDirFlags {

diff  --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
index 0f2bb07818039f..0e61a49433a6d2 100644
--- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
+++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
@@ -267,6 +267,53 @@ struct RecordReplayTy {
 
 } RecordReplay;
 
+// Build the table mapping host function pointers to device function pointers
+// from the image's 'indirect' offloading entries, copy it to device memory and
+// return {device table address, entry count}, or {nullptr, 0} if there are no
+// such entries. Each entry names a device global holding the device pointer.
+static Expected<std::pair<void *, uint64_t>>
+setupIndirectCallTable(GenericPluginTy &Plugin, GenericDeviceTy &Device,
+                       DeviceImageTy &Image) {
+  GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler();
+
+  llvm::ArrayRef<__tgt_offload_entry> Entries(Image.getTgtImage()->EntriesBegin,
+                                              Image.getTgtImage()->EntriesEnd);
+  llvm::SmallVector<std::pair<void *, void *>> IndirectCallTable;
+  for (const auto &Entry : Entries) {
+    if (Entry.size == 0 || !(Entry.flags & OMP_DECLARE_TARGET_INDIRECT))
+      continue;
+    // Entry.size bytes are copied into one pointer slot below; check it.
+    if (Entry.size != sizeof(void *))
+      return Plugin::error("Invalid indirect global %s", Entry.name);
+
+    GlobalTy DeviceGlobal(Entry.name, Entry.size);
+    if (auto Err =
+            Handler.getGlobalMetadataFromDevice(Device, Image, DeviceGlobal))
+      return std::move(Err);
+
+    void *&DevPtr = IndirectCallTable.emplace_back(Entry.addr, nullptr).second;
+    if (auto Err = Device.dataRetrieve(&DevPtr, DeviceGlobal.getPtr(),
+                                       Entry.size, nullptr))
+      return std::move(Err);
+  }
+
+  // If we do not have any indirect globals we exit early.
+  if (IndirectCallTable.empty())
+    return std::pair{nullptr, 0};
+
+  // Sort on the host pointer so the device can binary search the table.
+  llvm::sort(IndirectCallTable, llvm::less_first());
+
+  uint64_t TableSize = IndirectCallTable.size() * sizeof(IndirectCallTable[0]);
+  void *DevicePtr = Device.allocate(TableSize, nullptr, TARGET_ALLOC_DEVICE);
+  if (!DevicePtr)
+    return Plugin::error("Failed to allocate indirect call table");
+  if (auto Err = Device.dataSubmit(DevicePtr, IndirectCallTable.data(),
+                                   TableSize, nullptr))
+    return std::move(Err);
+  return std::pair<void *, uint64_t>(DevicePtr, IndirectCallTable.size());
+}
+
 AsyncInfoWrapperTy::AsyncInfoWrapperTy(GenericDeviceTy &Device,
                                        __tgt_async_info *AsyncInfoPtr)
     : Device(Device),
@@ -626,6 +673,11 @@ Error GenericDeviceTy::setupDeviceEnvironment(GenericPluginTy &Plugin,
   if (!shouldSetupDeviceEnvironment())
     return Plugin::success();
 
+  // Obtain a table mapping host function pointers to device function pointers.
+  auto CallTablePairOrErr = setupIndirectCallTable(Plugin, *this, Image);
+  if (!CallTablePairOrErr)
+    return CallTablePairOrErr.takeError();
+
   DeviceEnvironmentTy DeviceEnvironment;
   DeviceEnvironment.DebugKind = OMPX_DebugKind;
   DeviceEnvironment.NumDevices = Plugin.getNumDevices();
@@ -633,6 +685,9 @@ Error GenericDeviceTy::setupDeviceEnvironment(GenericPluginTy &Plugin,
   DeviceEnvironment.DeviceNum = DeviceId;
   DeviceEnvironment.DynamicMemSize = OMPX_SharedMemorySize;
   DeviceEnvironment.ClockFrequency = getClockFrequency();
+  DeviceEnvironment.IndirectCallTable =
+      reinterpret_cast<uintptr_t>(CallTablePairOrErr->first);
+  DeviceEnvironment.IndirectCallTableSize = CallTablePairOrErr->second;
 
   // Create the metainfo of the device environment global.
   GlobalTy DevEnvGlobal("__omp_rtl_device_environment",

diff  --git a/openmp/libomptarget/src/rtl.cpp b/openmp/libomptarget/src/rtl.cpp
index ed3e86075f8582..6623057f394b08 100644
--- a/openmp/libomptarget/src/rtl.cpp
+++ b/openmp/libomptarget/src/rtl.cpp
@@ -303,6 +303,10 @@ static void registerGlobalCtorsDtorsForImage(__tgt_bin_desc *Desc,
     Device.HasPendingGlobals = true;
     for (__tgt_offload_entry *Entry = Img->EntriesBegin;
          Entry != Img->EntriesEnd; ++Entry) {
+      // Globals are not callable and use a different set of flags.
+      if (Entry->size != 0)
+        continue;
+
       if (Entry->flags & OMP_DECLARE_TARGET_CTOR) {
         DP("Adding ctor " DPxMOD " to the pending list.\n",
            DPxPTR(Entry->addr));

diff  --git a/openmp/libomptarget/test/api/omp_indirect_call.c b/openmp/libomptarget/test/api/omp_indirect_call.c
new file mode 100644
index 00000000000000..ac0febf7854dad
--- /dev/null
+++ b/openmp/libomptarget/test/api/omp_indirect_call.c
@@ -0,0 +1,47 @@
+// RUN: %libomptarget-compile-run-and-check-generic
+
+#include <assert.h>
+#include <stdio.h>
+
+#pragma omp begin declare variant match(device = {kind(gpu)})
+// Provided by the device runtime: maps a host function pointer to the device's.
+void *__llvm_omp_indirect_call_lookup(void *host_ptr);
+#pragma omp declare target to(__llvm_omp_indirect_call_lookup)                 \
+    device_type(nohost)
+#pragma omp end declare variant
+
+#pragma omp begin declare variant match(device = {kind(cpu)})
+// Assume host and device share addresses on the CPU target: identity lookup.
+void *__llvm_omp_indirect_call_lookup(void *host_ptr) { return host_ptr; }
+#pragma omp end declare variant
+
+#pragma omp begin declare target indirect
+void foo(int *x) { *x += 1; }
+void bar(int *x) { *x += 1; }
+void baz(int *x) { *x += 1; }
+#pragma omp end declare target
+
+int main() {
+  void *foo_ptr = foo;
+  void *bar_ptr = bar;
+  void *baz_ptr = baz;
+  // Inside the table's address range but not an entry; must map to itself.
+  void *mid_ptr = (char *)(foo_ptr > bar_ptr ? foo_ptr : bar_ptr) - 1;
+  int count = 0;
+
+#pragma omp target map(to : foo_ptr, bar_ptr, baz_ptr, mid_ptr)                \
+    map(tofrom : count)
+  {
+    void *foo_res = __llvm_omp_indirect_call_lookup(foo_ptr);
+    ((void (*)(int *))foo_res)(&count);
+    void *bar_res = __llvm_omp_indirect_call_lookup(bar_ptr);
+    ((void (*)(int *))bar_res)(&count);
+    void *baz_res = __llvm_omp_indirect_call_lookup(baz_ptr);
+    ((void (*)(int *))baz_res)(&count);
+    if (__llvm_omp_indirect_call_lookup(mid_ptr) == mid_ptr)
+      ++count;
+  }
+  assert(count == 4 && "Calling failed");
+  // CHECK: PASS
+  printf("PASS\n");
+}


        


More information about the llvm-commits mailing list