[llvm] [Offload] Add MPI Proxy Plugin (PR #114574)
Jhonatan Cléto via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 1 10:00:42 PDT 2024
https://github.com/cl3to created https://github.com/llvm/llvm-project/pull/114574
This patch adds a new Offload plugin, built on the existing plugin interface, that offloads computational tasks to remote accelerator devices through an MPI proxy layer. To make kernel launches and data transfers efficient, it uses an event-driven architecture in which non-blocking MPI communication and C++20 coroutines keep operations asynchronous.
With the MPI Plugin, users can offload OpenMP target regions to remote devices as if they were local. Any remote device supported by an Offload plugin can be used through the MPI Plugin. So far we have tested it with x86_64 and CUDA devices; it is expected to work with AMD GPUs as well.
Currently, the plugin lacks support for the following features:
- Unified/shared memory allocation/free operations
- Device operations that depend on host function calls outside target regions, such as:
  - RPC calls for user-defined functions
  - OMPT callbacks
Programs using the MPI Plugin are compiled like standard OpenMP target programs with clang, as shown in this example:
```sh
clang -fopenmp -fopenmp-targets=nvptx64 -o app app.c
```
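For context, the kind of code that could live in `app.c` is sketched below. This is an illustrative example only and is not taken from the patch; the function name `scale_on_device` is hypothetical. Nothing in it is specific to MPI, which is the point: the target region is written exactly as it would be for a local device.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical example function, not part of the patch.
 * Multiplies each element of data by factor inside an OpenMP target region.
 * If the file is built without -fopenmp, the pragma is ignored and the loop
 * simply runs on the host. */
void scale_on_device(double *data, int n, double factor) {
#pragma omp target teams distribute parallel for map(tofrom : data[0:n])
  for (int i = 0; i < n; ++i)
    data[i] *= factor;
}
```

Calling `scale_on_device` from `main` and building with the command above should yield a binary whose target region can be offloaded to whichever device the runtime selects.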
The MPI Plugin uses a separate binary, `llvm-offload-mpi-proxy-device`, to execute target operations on the remote device. To offload to an MPI device, the program must therefore be started with the Multiple Program Multiple Data (MPMD) mode of an MPI launcher, which launches the proxy binary and the application together, as shown here:
```sh
mpirun -np N llvm-offload-mpi-proxy-device : -np 1 ./app
```
**Note**: Launch exactly one instance of the OpenMP program (`-np 1 ./app`); the plugin does not function correctly with multiple instances. In addition, due to a design constraint, the host process (`app`) must have rank `WorldSize - 1` for MPI communication to work. MPI launchers typically assign ranks in the order the programs appear on the command line, so `./app` must come last, as in the example above.
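The rank constraint can be illustrated with a small sketch. The helpers below are hypothetical and do not appear in the patch; they only model the layout that results from `mpirun -np N llvm-offload-mpi-proxy-device : -np 1 ./app`, assuming the launcher numbers ranks in command-line order.

```c
#include <assert.h>
#include <stdbool.h>

/* Hypothetical helper: world size for N proxy processes plus one host. */
int world_size_for(int num_proxies) { return num_proxies + 1; }

/* Hypothetical helper: true when rank is the one the host process is
 * required to hold, i.e. the last rank in the communicator. */
bool is_host_rank(int rank, int world_size) { return rank == world_size - 1; }
```

With `N = 3`, for example, ranks 0 to 2 would be proxy devices and rank 3 would be the host application.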
At runtime, the number of devices returned by the `omp_get_num_devices()` call will be the sum of local devices and all devices available in each `llvm-offload-mpi-proxy-device` instance.
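As a rough model of that count (a sketch under the assumption that each proxy instance simply reports how many devices it sees; `total_devices` is a hypothetical name, not a function from the patch):

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical model of the device count seen by the application:
 * local_devices is what the host's own plugins report, and
 * proxy_devices[i] is what the i-th proxy instance reports. */
int total_devices(int local_devices, const int *proxy_devices,
                  size_t num_proxies) {
  int total = local_devices;
  for (size_t i = 0; i < num_proxies; ++i)
    total += proxy_devices[i];
  return total;
}
```

Under this model, a host with one local device and two proxies exposing 2 and 4 devices would see 7 devices in total.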
To compile the plugin and run the test suite, an environment with an installed MPI implementation (such as OpenMPI or MPICH) is required.
We currently lack resources to add a dedicated Buildbot for this plugin, so we request that existing Buildbots be updated to support it.
From f354f6293493ec5b703e609617578f5ea587d14d Mon Sep 17 00:00:00 2001
From: Jhonatan Cléto <j256444 at dac.unicamp.br>
Date: Fri, 24 May 2024 11:16:07 -0300
Subject: [PATCH] [Offload] Add MPI Proxy Plugin
Co-authored-by: Guilherme Valarini <guilherme.a.valarini at gmail.com>
---
offload/CMakeLists.txt | 6 +-
.../common/include/PluginInterface.h | 115 +-
offload/plugins-nextgen/host/src/rtl.cpp | 2 +-
offload/plugins-nextgen/mpi/CMakeLists.txt | 134 ++
.../mpi/event_system/CMakeLists.txt | 29 +
.../mpi/event_system/EventSystem.cpp | 848 +++++++++++
.../mpi/event_system/EventSystem.h | 556 +++++++
.../plugins-nextgen/mpi/src/ProxyDevice.cpp | 1071 ++++++++++++++
.../mpi/src/RemotePluginManager.cpp | 104 ++
.../mpi/src/RemotePluginManager.h | 123 ++
.../mpi/src/RemoteTargets.def.in | 20 +
offload/plugins-nextgen/mpi/src/rtl.cpp | 1309 +++++++++++++++++
offload/src/PluginManager.cpp | 7 +
offload/test/api/omp_device_managed_memory.c | 2 +
.../api/omp_device_managed_memory_alloc.c | 2 +
offload/test/libc/host_call.c | 2 +
offload/test/lit.cfg | 12 +-
.../target_derefence_array_pointrs.cpp | 1 +
offload/test/mapping/target_has_device_addr.c | 1 +
offload/test/mapping/target_uses_allocator.c | 1 +
offload/test/offloading/bug49334.cpp | 2 +-
offload/test/offloading/bug64959.c | 1 +
.../struct_mapping_with_pointers.cpp | 1 +
.../offloading/target_critical_region.cpp | 1 +
offload/test/offloading/thread_limit.c | 1 +
offload/test/sanitizer/kernel_crash.c | 1 +
offload/test/sanitizer/kernel_crash_async.c | 1 +
offload/test/sanitizer/kernel_crash_many.c | 1 +
offload/test/sanitizer/kernel_crash_single.c | 1 +
offload/test/sanitizer/kernel_trap.c | 1 +
offload/test/sanitizer/kernel_trap.cpp | 14 +-
offload/test/sanitizer/kernel_trap_async.c | 1 +
offload/test/sanitizer/kernel_trap_many.c | 1 +
33 files changed, 4306 insertions(+), 66 deletions(-)
create mode 100644 offload/plugins-nextgen/mpi/CMakeLists.txt
create mode 100644 offload/plugins-nextgen/mpi/event_system/CMakeLists.txt
create mode 100644 offload/plugins-nextgen/mpi/event_system/EventSystem.cpp
create mode 100644 offload/plugins-nextgen/mpi/event_system/EventSystem.h
create mode 100644 offload/plugins-nextgen/mpi/src/ProxyDevice.cpp
create mode 100644 offload/plugins-nextgen/mpi/src/RemotePluginManager.cpp
create mode 100644 offload/plugins-nextgen/mpi/src/RemotePluginManager.h
create mode 100644 offload/plugins-nextgen/mpi/src/RemoteTargets.def.in
create mode 100644 offload/plugins-nextgen/mpi/src/rtl.cpp
diff --git a/offload/CMakeLists.txt b/offload/CMakeLists.txt
index 9b771d1116ee38..e01070cca652df 100644
--- a/offload/CMakeLists.txt
+++ b/offload/CMakeLists.txt
@@ -139,7 +139,7 @@ if(DEFINED LIBOMPTARGET_BUILD_CUDA_PLUGIN OR
message(WARNING "Option removed, use 'LIBOMPTARGET_PLUGINS_TO_BUILD' instead")
endif()
-set(LIBOMPTARGET_ALL_PLUGIN_TARGETS amdgpu cuda host)
+set(LIBOMPTARGET_ALL_PLUGIN_TARGETS mpi amdgpu cuda host)
set(LIBOMPTARGET_PLUGINS_TO_BUILD "all" CACHE STRING
"Semicolon-separated list of plugins to use: cuda, amdgpu, host or \"all\".")
@@ -194,8 +194,10 @@ set (LIBOMPTARGET_ALL_TARGETS "${LIBOMPTARGET_ALL_TARGETS} powerpc64-ibm-linux-g
set (LIBOMPTARGET_ALL_TARGETS "${LIBOMPTARGET_ALL_TARGETS} powerpc64-ibm-linux-gnu-LTO")
set (LIBOMPTARGET_ALL_TARGETS "${LIBOMPTARGET_ALL_TARGETS} x86_64-unknown-linux-gnu")
set (LIBOMPTARGET_ALL_TARGETS "${LIBOMPTARGET_ALL_TARGETS} x86_64-unknown-linux-gnu-LTO")
+set (LIBOMPTARGET_ALL_TARGETS "${LIBOMPTARGET_ALL_TARGETS} x86_64-unknown-linux-gnu-mpi")
set (LIBOMPTARGET_ALL_TARGETS "${LIBOMPTARGET_ALL_TARGETS} nvptx64-nvidia-cuda")
set (LIBOMPTARGET_ALL_TARGETS "${LIBOMPTARGET_ALL_TARGETS} nvptx64-nvidia-cuda-LTO")
+set (LIBOMPTARGET_ALL_TARGETS "${LIBOMPTARGET_ALL_TARGETS} nvptx64-nvidia-cuda-mpi")
set (LIBOMPTARGET_ALL_TARGETS "${LIBOMPTARGET_ALL_TARGETS} nvptx64-nvidia-cuda-JIT-LTO")
set (LIBOMPTARGET_ALL_TARGETS "${LIBOMPTARGET_ALL_TARGETS} s390x-ibm-linux-gnu")
set (LIBOMPTARGET_ALL_TARGETS "${LIBOMPTARGET_ALL_TARGETS} s390x-ibm-linux-gnu-LTO")
@@ -341,6 +343,8 @@ set(LIBOMPTARGET_LLVM_LIBRARY_DIR "${LLVM_LIBRARY_DIR}" CACHE STRING
set(LIBOMPTARGET_LLVM_LIBRARY_INTDIR "${LIBOMPTARGET_INTDIR}" CACHE STRING
"Path to folder where intermediate libraries will be output")
+set(LIBOMPTARGET_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src)
+
# Build offloading plugins and device RTLs if they are available.
add_subdirectory(plugins-nextgen)
add_subdirectory(DeviceRTL)
diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h
index 41cc0f286a581f..75fec516de9b88 100644
--- a/offload/plugins-nextgen/common/include/PluginInterface.h
+++ b/offload/plugins-nextgen/common/include/PluginInterface.h
@@ -1208,130 +1208,141 @@ struct GenericPluginTy {
/// Returns non-zero if the \p Image is compatible with the plugin. This
/// function does not require the plugin to be initialized before use.
- int32_t is_plugin_compatible(__tgt_device_image *Image);
+ virtual int32_t is_plugin_compatible(__tgt_device_image *Image);
/// Returns non-zero if the \p Image is compatible with the device.
- int32_t is_device_compatible(int32_t DeviceId, __tgt_device_image *Image);
+ virtual int32_t is_device_compatible(int32_t DeviceId,
+ __tgt_device_image *Image);
/// Returns non-zero if the plugin device has been initialized.
- int32_t is_device_initialized(int32_t DeviceId) const;
+ virtual int32_t is_device_initialized(int32_t DeviceId) const;
/// Initialize the device inside of the plugin.
- int32_t init_device(int32_t DeviceId);
+ virtual int32_t init_device(int32_t DeviceId);
/// Return the number of devices this plugin can support.
- int32_t number_of_devices();
+ virtual int32_t number_of_devices();
/// Returns non-zero if the data can be exchanged between the two devices.
- int32_t is_data_exchangable(int32_t SrcDeviceId, int32_t DstDeviceId);
+ virtual int32_t is_data_exchangable(int32_t SrcDeviceId, int32_t DstDeviceId);
/// Initializes the record and replay mechanism inside the plugin.
- int32_t initialize_record_replay(int32_t DeviceId, int64_t MemorySize,
- void *VAddr, bool isRecord, bool SaveOutput,
- uint64_t &ReqPtrArgOffset);
+ virtual int32_t initialize_record_replay(int32_t DeviceId, int64_t MemorySize,
+ void *VAddr, bool isRecord,
+ bool SaveOutput,
+ uint64_t &ReqPtrArgOffset);
/// Loads the associated binary into the plugin and returns a handle to it.
- int32_t load_binary(int32_t DeviceId, __tgt_device_image *TgtImage,
- __tgt_device_binary *Binary);
+ virtual int32_t load_binary(int32_t DeviceId, __tgt_device_image *TgtImage,
+ __tgt_device_binary *Binary);
/// Allocates memory that is accessively to the given device.
- void *data_alloc(int32_t DeviceId, int64_t Size, void *HostPtr, int32_t Kind);
+ virtual void *data_alloc(int32_t DeviceId, int64_t Size, void *HostPtr,
+ int32_t Kind);
/// Deallocates memory on the given device.
- int32_t data_delete(int32_t DeviceId, void *TgtPtr, int32_t Kind);
+ virtual int32_t data_delete(int32_t DeviceId, void *TgtPtr, int32_t Kind);
/// Locks / pins host memory using the plugin runtime.
- int32_t data_lock(int32_t DeviceId, void *Ptr, int64_t Size,
- void **LockedPtr);
+ virtual int32_t data_lock(int32_t DeviceId, void *Ptr, int64_t Size,
+ void **LockedPtr);
/// Unlocks / unpins host memory using the plugin runtime.
- int32_t data_unlock(int32_t DeviceId, void *Ptr);
+ virtual int32_t data_unlock(int32_t DeviceId, void *Ptr);
/// Notify the runtime about a new mapping that has been created outside.
- int32_t data_notify_mapped(int32_t DeviceId, void *HstPtr, int64_t Size);
+ virtual int32_t data_notify_mapped(int32_t DeviceId, void *HstPtr,
+ int64_t Size);
/// Notify t he runtime about a mapping that has been deleted.
- int32_t data_notify_unmapped(int32_t DeviceId, void *HstPtr);
+ virtual int32_t data_notify_unmapped(int32_t DeviceId, void *HstPtr);
/// Copy data to the given device.
- int32_t data_submit(int32_t DeviceId, void *TgtPtr, void *HstPtr,
- int64_t Size);
+ virtual int32_t data_submit(int32_t DeviceId, void *TgtPtr, void *HstPtr,
+ int64_t Size);
/// Copy data to the given device asynchronously.
- int32_t data_submit_async(int32_t DeviceId, void *TgtPtr, void *HstPtr,
- int64_t Size, __tgt_async_info *AsyncInfoPtr);
+ virtual int32_t data_submit_async(int32_t DeviceId, void *TgtPtr,
+ void *HstPtr, int64_t Size,
+ __tgt_async_info *AsyncInfoPtr);
/// Copy data from the given device.
- int32_t data_retrieve(int32_t DeviceId, void *HstPtr, void *TgtPtr,
- int64_t Size);
+ virtual int32_t data_retrieve(int32_t DeviceId, void *HstPtr, void *TgtPtr,
+ int64_t Size);
/// Copy data from the given device asynchornously.
- int32_t data_retrieve_async(int32_t DeviceId, void *HstPtr, void *TgtPtr,
- int64_t Size, __tgt_async_info *AsyncInfoPtr);
+ virtual int32_t data_retrieve_async(int32_t DeviceId, void *HstPtr,
+ void *TgtPtr, int64_t Size,
+ __tgt_async_info *AsyncInfoPtr);
/// Exchange memory addresses between two devices.
- int32_t data_exchange(int32_t SrcDeviceId, void *SrcPtr, int32_t DstDeviceId,
- void *DstPtr, int64_t Size);
+ virtual int32_t data_exchange(int32_t SrcDeviceId, void *SrcPtr,
+ int32_t DstDeviceId, void *DstPtr,
+ int64_t Size);
/// Exchange memory addresses between two devices asynchronously.
- int32_t data_exchange_async(int32_t SrcDeviceId, void *SrcPtr,
- int DstDeviceId, void *DstPtr, int64_t Size,
- __tgt_async_info *AsyncInfo);
+ virtual int32_t data_exchange_async(int32_t SrcDeviceId, void *SrcPtr,
+ int DstDeviceId, void *DstPtr,
+ int64_t Size,
+ __tgt_async_info *AsyncInfo);
/// Begin executing a kernel on the given device.
- int32_t launch_kernel(int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs,
- ptrdiff_t *TgtOffsets, KernelArgsTy *KernelArgs,
- __tgt_async_info *AsyncInfoPtr);
+ virtual int32_t launch_kernel(int32_t DeviceId, void *TgtEntryPtr,
+ void **TgtArgs, ptrdiff_t *TgtOffsets,
+ KernelArgsTy *KernelArgs,
+ __tgt_async_info *AsyncInfoPtr);
/// Synchronize an asyncrhonous queue with the plugin runtime.
- int32_t synchronize(int32_t DeviceId, __tgt_async_info *AsyncInfoPtr);
+ virtual int32_t synchronize(int32_t DeviceId, __tgt_async_info *AsyncInfoPtr);
/// Query the current state of an asynchronous queue.
- int32_t query_async(int32_t DeviceId, __tgt_async_info *AsyncInfoPtr);
+ virtual int32_t query_async(int32_t DeviceId, __tgt_async_info *AsyncInfoPtr);
/// Prints information about the given devices supported by the plugin.
- void print_device_info(int32_t DeviceId);
+ virtual void print_device_info(int32_t DeviceId);
/// Creates an event in the given plugin if supported.
- int32_t create_event(int32_t DeviceId, void **EventPtr);
+ virtual int32_t create_event(int32_t DeviceId, void **EventPtr);
/// Records an event that has occurred.
- int32_t record_event(int32_t DeviceId, void *EventPtr,
- __tgt_async_info *AsyncInfoPtr);
+ virtual int32_t record_event(int32_t DeviceId, void *EventPtr,
+ __tgt_async_info *AsyncInfoPtr);
/// Wait until an event has occurred.
- int32_t wait_event(int32_t DeviceId, void *EventPtr,
- __tgt_async_info *AsyncInfoPtr);
+ virtual int32_t wait_event(int32_t DeviceId, void *EventPtr,
+ __tgt_async_info *AsyncInfoPtr);
/// Syncrhonize execution until an event is done.
- int32_t sync_event(int32_t DeviceId, void *EventPtr);
+ virtual int32_t sync_event(int32_t DeviceId, void *EventPtr);
/// Remove the event from the plugin.
- int32_t destroy_event(int32_t DeviceId, void *EventPtr);
+ virtual int32_t destroy_event(int32_t DeviceId, void *EventPtr);
/// Remove the event from the plugin.
void set_info_flag(uint32_t NewInfoLevel);
/// Creates an asynchronous queue for the given plugin.
- int32_t init_async_info(int32_t DeviceId, __tgt_async_info **AsyncInfoPtr);
+ virtual int32_t init_async_info(int32_t DeviceId,
+ __tgt_async_info **AsyncInfoPtr);
/// Creates device information to be used for diagnostics.
- int32_t init_device_info(int32_t DeviceId, __tgt_device_info *DeviceInfo,
- const char **ErrStr);
+ virtual int32_t init_device_info(int32_t DeviceId,
+ __tgt_device_info *DeviceInfo,
+ const char **ErrStr);
/// Sets the offset into the devices for use by OMPT.
int32_t set_device_identifier(int32_t UserId, int32_t DeviceId);
/// Returns if the plugin can support auotmatic copy.
- int32_t use_auto_zero_copy(int32_t DeviceId);
+ virtual int32_t use_auto_zero_copy(int32_t DeviceId);
/// Look up a global symbol in the given binary.
- int32_t get_global(__tgt_device_binary Binary, uint64_t Size,
- const char *Name, void **DevicePtr);
+ virtual int32_t get_global(__tgt_device_binary Binary, uint64_t Size,
+ const char *Name, void **DevicePtr);
/// Look up a kernel function in the given binary.
- int32_t get_function(__tgt_device_binary Binary, const char *Name,
- void **KernelPtr);
+ virtual int32_t get_function(__tgt_device_binary Binary, const char *Name,
+ void **KernelPtr);
private:
/// Indicates if the platform runtime has been fully initialized.
diff --git a/offload/plugins-nextgen/host/src/rtl.cpp b/offload/plugins-nextgen/host/src/rtl.cpp
index fe296b77c7d557..c72a0770af23cf 100644
--- a/offload/plugins-nextgen/host/src/rtl.cpp
+++ b/offload/plugins-nextgen/host/src/rtl.cpp
@@ -43,7 +43,7 @@
#endif
// The number of devices in this plugin.
-#define NUM_DEVICES 4
+#define NUM_DEVICES 1
namespace llvm {
namespace omp {
diff --git a/offload/plugins-nextgen/mpi/CMakeLists.txt b/offload/plugins-nextgen/mpi/CMakeLists.txt
new file mode 100644
index 00000000000000..b64b2218048aa8
--- /dev/null
+++ b/offload/plugins-nextgen/mpi/CMakeLists.txt
@@ -0,0 +1,134 @@
+# Looking for MPI...
+find_package(MPI QUIET)
+
+if(NOT(CMAKE_SYSTEM_PROCESSOR MATCHES "(x86_64)|(ppc64le)$" AND CMAKE_SYSTEM_NAME MATCHES "Linux"))
+ message(STATUS "Not building MPI offloading plugin: only support MPI in Linux x86_64 or ppc64le hosts.")
+ return()
+elseif(NOT MPI_CXX_FOUND)
+ message(STATUS "Not building MPI offloading plugin: MPI not found in system.")
+ return()
+endif()
+
+message(STATUS "Building MPI Proxy offloading plugin.")
+
+# Event System
+add_subdirectory(event_system)
+
+# MPI Plugin
+
+# Create the library and add the default arguments.
+add_target_library(omptarget.rtl.mpi MPI)
+
+target_sources(omptarget.rtl.mpi PRIVATE
+ src/rtl.cpp
+)
+
+target_link_libraries(omptarget.rtl.mpi PRIVATE
+ EventSystem
+)
+
+# Add include directories
+target_include_directories(omptarget.rtl.mpi PRIVATE
+ ${LIBOMPTARGET_INCLUDE_DIR})
+
+# Set C++20 as the target standard for this plugin.
+set_target_properties(omptarget.rtl.mpi
+ PROPERTIES
+ CXX_STANDARD 20
+ CXX_STANDARD_REQUIRED ON)
+
+
+# Configure testing for the MPI plugin.
+list(APPEND LIBOMPTARGET_TESTED_PLUGINS "omptarget.rtl.mpi")
+# Report to the parent scope that we are building a plugin for MPI.
+set(LIBOMPTARGET_TESTED_PLUGINS "${LIBOMPTARGET_TESTED_PLUGINS}" PARENT_SCOPE)
+
+# Define the target specific triples and ELF machine values.
+set(LIBOMPTARGET_SYSTEM_TARGETS
+ "${LIBOMPTARGET_SYSTEM_TARGETS} x86_64-pc-linux-gnu-mpi nvptx64-nvidia-cuda-mpi" PARENT_SCOPE)
+
+# Remote Plugin Manager
+message(STATUS "Building the llvm-offload-mpi-proxy-device")
+
+set(LIBOMPTARGET_ALL_REMOTE_PLUGIN_TARGETS amdgpu cuda host)
+set(LIBOMPTARGET_REMOTE_PLUGINS_TO_BUILD "all" CACHE STRING
+ "Semicolon-separated list of plugins to use: cuda, amdgpu, host or \"all\".")
+
+if(LIBOMPTARGET_REMOTE_PLUGINS_TO_BUILD STREQUAL "all")
+ set(LIBOMPTARGET_REMOTE_PLUGINS_TO_BUILD ${LIBOMPTARGET_ALL_REMOTE_PLUGIN_TARGETS})
+endif()
+
+if(NOT CMAKE_SYSTEM_NAME MATCHES "Linux" AND
+ "host" IN_LIST LIBOMPTARGET_REMOTE_PLUGINS_TO_BUILD)
+ message(STATUS "Not building remote host plugin: only Linux systems are supported")
+ list(REMOVE_ITEM LIBOMPTARGET_REMOTE_PLUGINS_TO_BUILD "host")
+endif()
+if(NOT (CMAKE_SYSTEM_PROCESSOR MATCHES "(x86_64)|(ppc64le)|(aarch64)$"
+ AND CMAKE_SYSTEM_NAME MATCHES "Linux"))
+ if("amdgpu" IN_LIST LIBOMPTARGET_REMOTE_PLUGINS_TO_BUILD)
+ message(STATUS "Not building remote AMDGPU plugin: only support AMDGPU in "
+ "Linux x86_64, ppc64le, or aarch64 hosts")
+ list(REMOVE_ITEM LIBOMPTARGET_REMOTE_PLUGINS_TO_BUILD "amdgpu")
+ endif()
+ if("cuda" IN_LIST LIBOMPTARGET_REMOTE_PLUGINS_TO_BUILD)
+ message(STATUS "Not building remote CUDA plugin: only support CUDA in "
+ "Linux x86_64, ppc64le, or aarch64 hosts")
+ list(REMOVE_ITEM LIBOMPTARGET_REMOTE_PLUGINS_TO_BUILD "cuda")
+ endif()
+endif()
+if("mpi" IN_LIST LIBOMPTARGET_REMOTE_PLUGINS_TO_BUILD)
+ message(STATUS "It is not possible to build the mpi plugin inside "
+ "the remote proxy device")
+ list(REMOVE_ITEM LIBOMPTARGET_REMOTE_PLUGINS_TO_BUILD "mpi")
+endif()
+
+message(STATUS "Building the MPI Plugin with support for remote offloading to "
+ "the \"${LIBOMPTARGET_REMOTE_PLUGINS_TO_BUILD}\" plugins")
+
+set(REMOTE_MPI_ENUM_PLUGIN_TARGETS "")
+foreach(plugin IN LISTS LIBOMPTARGET_REMOTE_PLUGINS_TO_BUILD)
+ set(REMOTE_MPI_ENUM_PLUGIN_TARGETS
+ "${REMOTE_MPI_ENUM_PLUGIN_TARGETS}PLUGIN_TARGET(${plugin})\n")
+endforeach()
+string(STRIP ${REMOTE_MPI_ENUM_PLUGIN_TARGETS} REMOTE_MPI_ENUM_PLUGIN_TARGETS)
+configure_file(
+ ${CMAKE_CURRENT_SOURCE_DIR}/src/RemoteTargets.def.in
+ ${LIBOMPTARGET_BINARY_INCLUDE_DIR}/Shared/RemoteTargets.def
+)
+
+llvm_add_tool(OPENMP llvm-offload-mpi-proxy-device
+ src/ProxyDevice.cpp
+ src/RemotePluginManager.cpp
+ ${LIBOMPTARGET_SRC_DIR}/OpenMP/OMPT/Callback.cpp
+)
+
+llvm_update_compile_flags(llvm-offload-mpi-proxy-device)
+
+target_link_libraries(llvm-offload-mpi-proxy-device PRIVATE
+ EventSystem
+ LLVMSupport
+ omp
+)
+
+add_dependencies(llvm-offload-mpi-proxy-device omp)
+
+target_include_directories(llvm-offload-mpi-proxy-device PRIVATE
+ ${LIBOMPTARGET_INCLUDE_DIR}
+ ${LIBOMPTARGET_LLVM_INCLUDE_DIRS}
+ ${LIBOMPTARGET_BINARY_INCLUDE_DIR}
+)
+
+foreach(plugin IN LISTS LIBOMPTARGET_REMOTE_PLUGINS_TO_BUILD)
+ target_link_libraries(llvm-offload-mpi-proxy-device PRIVATE omptarget.rtl.${plugin})
+ add_dependencies(llvm-offload-mpi-proxy-device omptarget.rtl.${plugin})
+endforeach()
+
+# Set C++20 as the target standard for this plugin.
+set_target_properties(llvm-offload-mpi-proxy-device
+ PROPERTIES
+ CXX_STANDARD 20
+ CXX_STANDARD_REQUIRED ON)
+
+target_compile_definitions(llvm-offload-mpi-proxy-device PRIVATE
+ TARGET_NAME=llvm-offload-mpi-proxy-device
+ DEBUG_PREFIX="MPIProxyDevice")
diff --git a/offload/plugins-nextgen/mpi/event_system/CMakeLists.txt b/offload/plugins-nextgen/mpi/event_system/CMakeLists.txt
new file mode 100644
index 00000000000000..32a9f9b79423e1
--- /dev/null
+++ b/offload/plugins-nextgen/mpi/event_system/CMakeLists.txt
@@ -0,0 +1,29 @@
+# Build EventSystem
+add_library(EventSystem OBJECT
+ EventSystem.cpp
+)
+
+target_include_directories(EventSystem PUBLIC
+ ${CMAKE_CURRENT_SOURCE_DIR}
+ ${LIBOMPTARGET_BINARY_INCLUDE_DIR}
+ ${LIBOMPTARGET_INCLUDE_DIR}
+)
+
+target_link_libraries(EventSystem PRIVATE
+ MPI::MPI_CXX
+ LLVMSupport
+)
+
+target_compile_options(EventSystem PUBLIC ${offload_compile_flags})
+target_link_options(EventSystem PUBLIC ${offload_link_flags})
+
+set_target_properties(EventSystem PROPERTIES POSITION_INDEPENDENT_CODE ON)
+
+# Set C++20 as the target standard for this plugin.
+set_target_properties(EventSystem
+ PROPERTIES
+ CXX_STANDARD 20
+ CXX_STANDARD_REQUIRED ON)
+
+target_compile_definitions(EventSystem PRIVATE
+ DEBUG_PREFIX="EventSystem")
\ No newline at end of file
diff --git a/offload/plugins-nextgen/mpi/event_system/EventSystem.cpp b/offload/plugins-nextgen/mpi/event_system/EventSystem.cpp
new file mode 100644
index 00000000000000..ab59e3da837fa5
--- /dev/null
+++ b/offload/plugins-nextgen/mpi/event_system/EventSystem.cpp
@@ -0,0 +1,848 @@
+//===------ event_system.cpp - Concurrent MPI communication -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the implementation of the MPI Event System used by the MPI
+// target runtime for concurrent communication.
+//
+//===----------------------------------------------------------------------===//
+
+#include "EventSystem.h"
+
+#include <algorithm>
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <functional>
+#include <memory>
+
+#include <mpi.h>
+#include <unistd.h>
+
+#include "Shared/APITypes.h"
+#include "Shared/Debug.h"
+#include "Shared/EnvironmentVar.h"
+#include "Shared/Utils.h"
+#include "omptarget.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Error.h"
+
+#include "llvm/Support/DynamicLibrary.h"
+
+#define CHECK(expr, msg, ...) \
+ if (!(expr)) { \
+ REPORT(msg, ##__VA_ARGS__); \
+ return false; \
+ }
+
+/// Resumes the most recent incomplete coroutine in the list.
+void EventTy::resume() {
+ // Acquire first handle not done.
+ const CoHandleTy &RootHandle = getHandle().promise().RootHandle;
+ auto &ResumableHandle = RootHandle.promise().PrevHandle;
+ while (ResumableHandle.done()) {
+ ResumableHandle = ResumableHandle.promise().PrevHandle;
+
+ if (ResumableHandle == RootHandle)
+ break;
+ }
+
+ if (!ResumableHandle.done())
+ ResumableHandle.resume();
+}
+
+/// Wait until event completes.
+void EventTy::wait() {
+ // Advance the event progress until it is completed.
+ while (!done()) {
+ resume();
+
+ std::this_thread::sleep_for(
+ std::chrono::microseconds(EventPollingRate.get()));
+ }
+}
+
+/// Check if the event has completed.
+bool EventTy::done() const { return getHandle().done(); }
+
+/// Check if it is an empty event.
+bool EventTy::empty() const { return !getHandle(); }
+
+/// Get the coroutine error from the Handle.
+llvm::Error EventTy::getError() const {
+ auto &Error = getHandle().promise().CoroutineError;
+ if (Error)
+ return std::move(*Error);
+
+ return llvm::Error::success();
+}
+
+/// MPI Request Manager Destructor. The Manager cannot be destroyed until all
+/// the requests it manages have been completed.
+MPIRequestManagerTy::~MPIRequestManagerTy() {
+ assert(Requests.empty() && "Requests must be fulfilled and emptied before "
+ "destruction. Did you co_await on it?");
+}
+
+/// Send a message to \p OtherRank asynchronously.
+void MPIRequestManagerTy::send(const void *Buffer, int Size,
+ MPI_Datatype Datatype) {
+ MPI_Isend(Buffer, Size, Datatype, OtherRank, Tag, Comm,
+ &Requests.emplace_back(MPI_REQUEST_NULL));
+}
+
+/// Divide the \p Buffer into fragments of size \p MPIFragmentSize and send them
+/// to \p OtherRank asynchronously.
+void MPIRequestManagerTy::sendInBatchs(void *Buffer, int64_t Size) {
+ // Operates over many fragments of the original buffer of at most
+ // MPI_FRAGMENT_SIZE bytes.
+ char *BufferByteArray = reinterpret_cast<char *>(Buffer);
+ int64_t RemainingBytes = Size;
+ while (RemainingBytes > 0) {
+ send(&BufferByteArray[Size - RemainingBytes],
+ static_cast<int>(std::min(RemainingBytes, MPIFragmentSize.get())),
+ MPI_BYTE);
+ RemainingBytes -= MPIFragmentSize.get();
+ }
+}
+
+/// Receive a message from \p OtherRank asynchronously.
+void MPIRequestManagerTy::receive(void *Buffer, int Size,
+ MPI_Datatype Datatype) {
+ MPI_Irecv(Buffer, Size, Datatype, OtherRank, Tag, Comm,
+ &Requests.emplace_back(MPI_REQUEST_NULL));
+}
+
+/// Asynchronously receive message fragments from \p OtherRank and reconstruct
+/// them into \p Buffer.
+void MPIRequestManagerTy::receiveInBatchs(void *Buffer, int64_t Size) {
+ // Operates over many fragments of the original buffer of at most
+ // MPI_FRAGMENT_SIZE bytes.
+ char *BufferByteArray = reinterpret_cast<char *>(Buffer);
+ int64_t RemainingBytes = Size;
+ while (RemainingBytes > 0) {
+ receive(&BufferByteArray[Size - RemainingBytes],
+ static_cast<int>(std::min(RemainingBytes, MPIFragmentSize.get())),
+ MPI_BYTE);
+ RemainingBytes -= MPIFragmentSize.get();
+ }
+}
+
+/// Coroutine that waits until all pending requests finish.
+EventTy MPIRequestManagerTy::wait() {
+ int RequestsCompleted = false;
+
+ while (!RequestsCompleted) {
+ int MPIError = MPI_Testall(Requests.size(), Requests.data(),
+ &RequestsCompleted, MPI_STATUSES_IGNORE);
+
+ if (MPIError != MPI_SUCCESS)
+ co_return createError("Waiting of MPI requests failed with code %d",
+ MPIError);
+
+ co_await std::suspend_always{};
+ }
+
+ Requests.clear();
+
+ co_return llvm::Error::success();
+}
+
+EventTy operator co_await(MPIRequestManagerTy &RequestManager) {
+ return RequestManager.wait();
+}
+
+void *memAllocHost(int64_t Size) {
+ void *HstPtr = nullptr;
+ int MPIError = MPI_Alloc_mem(Size, MPI_INFO_NULL, &HstPtr);
+ if (MPIError != MPI_SUCCESS)
+ return nullptr;
+ return HstPtr;
+}
+
+int memFreeHost(void *HstPtr) {
+ int MPIError = MPI_Free_mem(HstPtr);
+ if (MPIError != MPI_SUCCESS)
+ return OFFLOAD_FAIL;
+ return OFFLOAD_SUCCESS;
+}
+
+/// Event implementations on Host side.
+namespace OriginEvents {
+
+EventTy retrieveNumDevices(MPIRequestManagerTy RequestManager,
+ int32_t *NumDevices) {
+ RequestManager.receive(NumDevices, 1, MPI_INT32_T);
+ co_return (co_await RequestManager);
+}
+
+EventTy isPluginCompatible(MPIRequestManagerTy RequestManager,
+ __tgt_device_image *Image, bool *QueryResult) {
+ uint64_t Size = utils::getPtrDiff(Image->ImageEnd, Image->ImageStart);
+
+ void *Buffer = memAllocHost(Size);
+ if (Buffer != nullptr)
+ memcpy(Buffer, Image->ImageStart, Size);
+ else
+ Buffer = Image->ImageStart;
+
+ RequestManager.send(&Size, 1, MPI_UINT64_T);
+ RequestManager.send(Buffer, Size, MPI_BYTE);
+
+ if (auto Err = co_await RequestManager; Err)
+ co_return Err;
+
+ if (Buffer != Image->ImageStart)
+ memFreeHost(Buffer);
+
+ RequestManager.receive(QueryResult, sizeof(bool), MPI_BYTE);
+ co_return (co_await RequestManager);
+}
+
+EventTy isDeviceCompatible(MPIRequestManagerTy RequestManager,
+ __tgt_device_image *Image, bool *QueryResult) {
+ uint64_t Size = utils::getPtrDiff(Image->ImageEnd, Image->ImageStart);
+
+ void *Buffer = memAllocHost(Size);
+ if (Buffer != nullptr)
+ memcpy(Buffer, Image->ImageStart, Size);
+ else
+ Buffer = Image->ImageStart;
+
+ RequestManager.send(&Size, 1, MPI_UINT64_T);
+ RequestManager.send(Buffer, Size, MPI_BYTE);
+
+ if (auto Err = co_await RequestManager; Err)
+ co_return Err;
+
+ if (Buffer != Image->ImageStart)
+ memFreeHost(Buffer);
+
+ RequestManager.receive(QueryResult, sizeof(bool), MPI_BYTE);
+ co_return (co_await RequestManager);
+}
+
+EventTy initDevice(MPIRequestManagerTy RequestManager, void **DevicePtr) {
+ // Wait the complete notification
+ RequestManager.receive(DevicePtr, sizeof(void *), MPI_BYTE);
+ co_return (co_await RequestManager);
+}
+
+EventTy initRecordReplay(MPIRequestManagerTy RequestManager, int64_t MemorySize,
+ void *VAddr, bool IsRecord, bool SaveOutput,
+ uint64_t *ReqPtrArgOffset) {
+ RequestManager.send(&MemorySize, 1, MPI_INT64_T);
+ RequestManager.send(&VAddr, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&IsRecord, sizeof(bool), MPI_BYTE);
+ RequestManager.send(&SaveOutput, sizeof(bool), MPI_BYTE);
+ RequestManager.receive(&ReqPtrArgOffset, 1, MPI_UINT64_T);
+ co_return (co_await RequestManager);
+}
+
+EventTy isDataExchangable(MPIRequestManagerTy RequestManager,
+ int32_t DstDeviceId, bool *QueryResult) {
+ RequestManager.send(&DstDeviceId, 1, MPI_INT32_T);
+ RequestManager.receive(QueryResult, sizeof(bool), MPI_BYTE);
+ co_return (co_await RequestManager);
+}
+
+EventTy allocateBuffer(MPIRequestManagerTy RequestManager, int64_t Size,
+ int32_t Kind, void **Buffer) {
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+ RequestManager.send(&Kind, 1, MPI_INT32_T);
+
+ // Event completion notification
+ RequestManager.receive(Buffer, sizeof(void *), MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy deleteBuffer(MPIRequestManagerTy RequestManager, void *Buffer,
+ int32_t Kind) {
+ RequestManager.send(&Buffer, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Kind, 1, MPI_INT32_T);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy submit(MPIRequestManagerTy RequestManager, void *TgtPtr,
+ EventDataHandleTy HstPtr, int64_t Size,
+ __tgt_async_info *AsyncInfoPtr) {
+ RequestManager.send(&AsyncInfoPtr, sizeof(void *), MPI_BYTE);
+
+ RequestManager.send(&TgtPtr, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+
+ RequestManager.sendInBatchs(HstPtr.get(), Size);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy retrieve(MPIRequestManagerTy RequestManager, int64_t Size, void *HstPtr,
+ void *TgtPtr, __tgt_async_info *AsyncInfoPtr) {
+ bool DeviceOpStatus = true;
+
+ RequestManager.send(&AsyncInfoPtr, sizeof(void *), MPI_BYTE);
+
+ RequestManager.send(&TgtPtr, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+
+ RequestManager.receive(&DeviceOpStatus, sizeof(bool), MPI_BYTE);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ if (!DeviceOpStatus)
+ co_return createError("Failed to retrieve %p TgtPtr.", TgtPtr);
+
+ RequestManager.receiveInBatchs(HstPtr, Size);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy localExchange(MPIRequestManagerTy RequestManager, void *SrcPtr,
+ int DstDeviceId, void *DstPtr, int64_t Size,
+ __tgt_async_info *AsyncInfoPtr) {
+ RequestManager.send(&SrcPtr, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&DstDeviceId, 1, MPI_INT);
+ RequestManager.send(&DstPtr, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+ RequestManager.send(&AsyncInfoPtr, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
+}
+
+EventTy exchange(MPIRequestManagerTy RequestManager, int SrcRank,
+ const void *OrgBuffer, int DstRank, void *DstBuffer,
+ int64_t Size, __tgt_async_info *AsyncInfoPtr) {
+ // Send data to SrcRank
+ RequestManager.send(&OrgBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+ RequestManager.send(&DstRank, 1, MPI_INT);
+ RequestManager.send(&AsyncInfoPtr, sizeof(void *), MPI_BYTE);
+
+ // Send data to DstRank
+ RequestManager.OtherRank = DstRank;
+ RequestManager.send(&DstBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+ RequestManager.send(&SrcRank, 1, MPI_INT);
+ RequestManager.send(&AsyncInfoPtr, sizeof(void *), MPI_BYTE);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+ RequestManager.OtherRank = SrcRank;
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy launchKernel(MPIRequestManagerTy RequestManager, void *TgtEntryPtr,
+ EventDataHandleTy TgtArgs, EventDataHandleTy TgtOffsets,
+ EventDataHandleTy KernelArgsHandle,
+ __tgt_async_info *AsyncInfoPtr) {
+ KernelArgsTy *KernelArgs =
+ static_cast<KernelArgsTy *>(KernelArgsHandle.get());
+
+ RequestManager.send(&KernelArgs->NumArgs, 1, MPI_UINT32_T);
+ RequestManager.send(&AsyncInfoPtr, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&TgtEntryPtr, sizeof(void *), MPI_BYTE);
+ RequestManager.send(TgtArgs.get(), KernelArgs->NumArgs * sizeof(void *),
+ MPI_BYTE);
+ RequestManager.send(TgtOffsets.get(), KernelArgs->NumArgs * sizeof(ptrdiff_t),
+ MPI_BYTE);
+
+ RequestManager.send(KernelArgs, sizeof(KernelArgsTy), MPI_BYTE);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy getGlobal(MPIRequestManagerTy RequestManager,
+ __tgt_device_binary Binary, uint64_t Size, const char *Name,
+ void **DevicePtr) {
+ uint32_t NameSize = strlen(Name) + 1;
+ RequestManager.send(&Binary.handle, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_UINT64_T);
+ RequestManager.send(&NameSize, 1, MPI_UINT32_T);
+ RequestManager.send(Name, NameSize, MPI_CHAR);
+
+ RequestManager.receive(DevicePtr, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
+}
+
+EventTy getFunction(MPIRequestManagerTy RequestManager,
+ __tgt_device_binary Binary, const char *Name,
+ void **KernelPtr) {
+ RequestManager.send(&Binary.handle, sizeof(void *), MPI_BYTE);
+ uint32_t Size = strlen(Name) + 1;
+ RequestManager.send(&Size, 1, MPI_UINT32_T);
+ RequestManager.send(Name, Size, MPI_CHAR);
+
+ RequestManager.receive(KernelPtr, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
+}
+
+EventTy synchronize(MPIRequestManagerTy RequestManager,
+ __tgt_async_info *AsyncInfoPtr) {
+ bool DeviceOpStatus = false;
+ RequestManager.send(&AsyncInfoPtr, sizeof(void *), MPI_BYTE);
+
+ RequestManager.receive(&DeviceOpStatus, sizeof(bool), MPI_BYTE);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ if (!DeviceOpStatus)
+ co_return createError("Failed to synchronize device.");
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
+}
+
+EventTy sync(EventTy Event) {
+ while (!Event.done())
+ co_await std::suspend_always{};
+
+ co_return llvm::Error::success();
+}
+
+EventTy loadBinary(MPIRequestManagerTy RequestManager,
+ const __tgt_device_image *Image,
+ __tgt_device_binary *Binary) {
+ auto &[ImageStart, ImageEnd, EntriesBegin, EntriesEnd] = *Image;
+
+ // Send the target table sizes.
+ size_t ImageSize = (size_t)ImageEnd - (size_t)ImageStart;
+ size_t EntryCount = EntriesEnd - EntriesBegin;
+ llvm::SmallVector<size_t> EntryNameSizes(EntryCount);
+
+ for (size_t I = 0; I < EntryCount; I++) {
+ // Note: +1 for the terminator.
+ EntryNameSizes[I] = std::strlen(EntriesBegin[I].name) + 1;
+ }
+
+ RequestManager.send(&ImageSize, 1, MPI_UINT64_T);
+ RequestManager.send(&EntryCount, 1, MPI_UINT64_T);
+ RequestManager.send(EntryNameSizes.begin(), EntryCount, MPI_UINT64_T);
+
+ void *Buffer = memAllocHost(ImageSize);
+ if (Buffer != nullptr)
+ memcpy(Buffer, ImageStart, ImageSize);
+ else
+ Buffer = ImageStart;
+
+ // Send the image bytes and the table entries.
+ RequestManager.send(Buffer, ImageSize, MPI_BYTE);
+
+ for (size_t I = 0; I < EntryCount; I++) {
+ RequestManager.send(&EntriesBegin[I].addr, 1, MPI_UINT64_T);
+ RequestManager.send(EntriesBegin[I].name, EntryNameSizes[I], MPI_CHAR);
+ RequestManager.send(&EntriesBegin[I].size, 1, MPI_UINT64_T);
+ RequestManager.send(&EntriesBegin[I].flags, 1, MPI_INT32_T);
+ RequestManager.send(&EntriesBegin[I].data, 1, MPI_INT32_T);
+ }
+
+ auto Err = co_await RequestManager;
+ if (Buffer != ImageStart)
+ memFreeHost(Buffer);
+ if (Err)
+ co_return Err;
+
+ RequestManager.receive(&Binary->handle, sizeof(void *), MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy queryAsync(MPIRequestManagerTy RequestManager,
+ __tgt_async_info *AsyncInfoPtr) {
+ RequestManager.send(&AsyncInfoPtr, sizeof(void *), MPI_BYTE);
+
+ if (auto Err = co_await RequestManager; Err)
+ co_return Err;
+
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy printDeviceInfo(MPIRequestManagerTy RequestManager) {
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
+}
+
+EventTy initAsyncInfo(MPIRequestManagerTy RequestManager,
+ __tgt_async_info **AsyncInfoPtr) {
+ RequestManager.receive(AsyncInfoPtr, sizeof(void *), MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy initDeviceInfo(MPIRequestManagerTy RequestManager,
+ __tgt_device_info *DeviceInfo) {
+ RequestManager.send(DeviceInfo, sizeof(__tgt_device_info), MPI_BYTE);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ RequestManager.receive(DeviceInfo, sizeof(__tgt_device_info), MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy dataLock(MPIRequestManagerTy RequestManager, void *Ptr, int64_t Size,
+ void **LockedPtr) {
+ RequestManager.send(&Ptr, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+ RequestManager.receive(LockedPtr, sizeof(void *), MPI_BYTE);
+ co_return (co_await RequestManager);
+}
+
+EventTy dataUnlock(MPIRequestManagerTy RequestManager, void *Ptr) {
+ RequestManager.send(&Ptr, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
+}
+
+EventTy dataNotifyMapped(MPIRequestManagerTy RequestManager, void *HstPtr,
+ int64_t Size) {
+ RequestManager.send(&HstPtr, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
+}
+
+EventTy dataNotifyUnmapped(MPIRequestManagerTy RequestManager, void *HstPtr) {
+ RequestManager.send(&HstPtr, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
+}
+
+EventTy exit(MPIRequestManagerTy RequestManager) {
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
+}
+
+} // namespace OriginEvents
+
+/// Event Queue implementation
+EventQueue::EventQueue() : Queue(), QueueMtx(), CanPopCv() {}
+
+size_t EventQueue::size() {
+ std::lock_guard<std::mutex> Lock(QueueMtx);
+ return Queue.size();
+}
+
+void EventQueue::push(EventTy &&Event) {
+ {
+ std::unique_lock<std::mutex> Lock(QueueMtx);
+ Queue.emplace(std::move(Event));
+ }
+
+ // Notifies a thread possibly blocked by an empty queue.
+ CanPopCv.notify_one();
+}
+
+EventTy EventQueue::pop(std::stop_token &Stop) {
+ std::unique_lock<std::mutex> Lock(QueueMtx);
+
+ // Waits for at least one item to be pushed.
+ const bool HasNewEvent =
+ CanPopCv.wait(Lock, Stop, [&] { return !Queue.empty(); });
+
+ if (!HasNewEvent) {
+ assert(Stop.stop_requested() && "Queue was empty while running.");
+ return EventTy();
+ }
+
+ EventTy TargetEvent = std::move(Queue.front());
+ Queue.pop();
+ return TargetEvent;
+}
+
+/// Event System implementation
+EventSystemTy::EventSystemTy()
+ : EventSystemState(EventSystemStateTy::CREATED),
+ NumMPIComms("OMPTARGET_NUM_MPI_COMMS", 10) {}
+
+EventSystemTy::~EventSystemTy() {
+ if (EventSystemState == EventSystemStateTy::FINALIZED)
+ return;
+
+ REPORT("Destructing internal event system before deinitializing it.\n");
+ deinitialize();
+}
+
+bool EventSystemTy::initialize() {
+ if (EventSystemState >= EventSystemStateTy::INITIALIZED) {
+ REPORT("Trying to initialize event system twice.\n");
+ return false;
+ }
+
+ if (!createLocalMPIContext())
+ return false;
+
+ EventSystemState = EventSystemStateTy::INITIALIZED;
+
+ return true;
+}
+
+bool EventSystemTy::is_initialized() {
+ return EventSystemState == EventSystemStateTy::INITIALIZED;
+}
+
+bool EventSystemTy::deinitialize() {
+ if (EventSystemState == EventSystemStateTy::FINALIZED) {
+ REPORT("Trying to deinitialize event system twice.\n");
+ return false;
+ }
+
+ if (EventSystemState == EventSystemStateTy::RUNNING) {
+ REPORT("Trying to deinitialize event system while it is running.\n");
+ return false;
+ }
+
+ // Only send exit events from the host side
+ if (isHost() && WorldSize > 1) {
+ const int NumWorkers = WorldSize - 1;
+ llvm::SmallVector<EventTy> ExitEvents(NumWorkers);
+ for (int WorkerRank = 0; WorkerRank < NumWorkers; WorkerRank++) {
+ ExitEvents[WorkerRank] =
+ createEvent(OriginEvents::exit, EventTypeTy::EXIT, WorkerRank);
+ ExitEvents[WorkerRank].resume();
+ }
+
+ bool SuccessfullyExited = true;
+ for (int WorkerRank = 0; WorkerRank < NumWorkers; WorkerRank++) {
+ ExitEvents[WorkerRank].wait();
+ SuccessfullyExited &= ExitEvents[WorkerRank].done();
+ auto Error = ExitEvents[WorkerRank].getError();
+ if (Error)
+ REPORT("Exit event failed with msg: %s\n",
+ toString(std::move(Error)).data());
+ }
+
+ if (!SuccessfullyExited) {
+ REPORT("Failed to stop worker processes.\n");
+ return false;
+ }
+ }
+
+ if (!destroyLocalMPIContext())
+ return false;
+
+ EventSystemState = EventSystemStateTy::FINALIZED;
+
+ return true;
+}
+
+EventTy EventSystemTy::createExchangeEvent(int SrcDevice, const void *SrcBuffer,
+ int DstDevice, void *DstBuffer,
+ int64_t Size,
+ __tgt_async_info *AsyncInfo) {
+ const int EventTag = createNewEventTag();
+ auto &EventComm = getNewEventComm(EventTag);
+
+ int32_t SrcRank, SrcDeviceId, DstRank, DstDeviceId;
+
+ std::tie(SrcRank, SrcDeviceId) = mapDeviceId(SrcDevice);
+ std::tie(DstRank, DstDeviceId) = mapDeviceId(DstDevice);
+
+ int SrcEventNotificationInfo[] = {static_cast<int>(EventTypeTy::EXCHANGE_SRC),
+ EventTag, SrcDeviceId};
+
+ int DstEventNotificationInfo[] = {static_cast<int>(EventTypeTy::EXCHANGE_DST),
+ EventTag, DstDeviceId};
+
+ MPI_Request SrcRequest = MPI_REQUEST_NULL;
+ MPI_Request DstRequest = MPI_REQUEST_NULL;
+
+ int MPIError = MPI_Isend(SrcEventNotificationInfo, 3, MPI_INT, SrcRank,
+ static_cast<int>(ControlTagsTy::EVENT_REQUEST),
+ GateThreadComm, &SrcRequest);
+
+ MPIError |= MPI_Isend(DstEventNotificationInfo, 3, MPI_INT, DstRank,
+ static_cast<int>(ControlTagsTy::EVENT_REQUEST),
+ GateThreadComm, &DstRequest);
+
+ if (MPIError != MPI_SUCCESS)
+ co_return createError(
+ "MPI failed during exchange event notification with error %d",
+ MPIError);
+
+ MPIRequestManagerTy RequestManager(EventComm, EventTag, SrcRank, SrcDeviceId,
+ {SrcRequest, DstRequest});
+
+ co_return (co_await OriginEvents::exchange(std::move(RequestManager), SrcRank,
+ SrcBuffer, DstRank, DstBuffer,
+ Size, AsyncInfo));
+}
+
+/// Creates a new event tag of at least 'FIRST_EVENT' value.
+/// Tag values smaller than 'FIRST_EVENT' are reserved for control
+/// messages between the event systems of different MPI processes.
+int EventSystemTy::createNewEventTag() {
+ int tag = 0;
+
+ do {
+ tag = EventCounter.fetch_add(1) % MPITagMaxValue;
+ } while (tag < static_cast<int>(ControlTagsTy::FIRST_EVENT));
+
+ return tag;
+}
+
+MPI_Comm &EventSystemTy::getNewEventComm(int MPITag) {
+ // Retrieve a comm using a round-robin strategy around the event's mpi tag.
+ return EventCommPool[MPITag % EventCommPool.size()];
+}
+
+static const char *threadLevelToString(int ThreadLevel) {
+ switch (ThreadLevel) {
+ case MPI_THREAD_SINGLE:
+ return "MPI_THREAD_SINGLE";
+ case MPI_THREAD_SERIALIZED:
+ return "MPI_THREAD_SERIALIZED";
+ case MPI_THREAD_FUNNELED:
+ return "MPI_THREAD_FUNNELED";
+ case MPI_THREAD_MULTIPLE:
+ return "MPI_THREAD_MULTIPLE";
+ default:
+ return "unknown";
+ }
+}
+
+/// Initialize the MPI context.
+bool EventSystemTy::createLocalMPIContext() {
+ int MPIError = MPI_SUCCESS;
+ int IsMPIInitialized = 0;
+ int ThreadLevel = MPI_THREAD_SINGLE;
+
+ MPI_Initialized(&IsMPIInitialized);
+
+ if (IsMPIInitialized)
+ MPI_Query_thread(&ThreadLevel);
+ else
+ MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &ThreadLevel);
+
+ CHECK(ThreadLevel == MPI_THREAD_MULTIPLE,
+ "MPI plugin requires an MPI implementation with %s thread level. "
+ "Implementation only supports up to %s.\n",
+ threadLevelToString(MPI_THREAD_MULTIPLE),
+ threadLevelToString(ThreadLevel));
+
+ if (IsMPIInitialized && ThreadLevel == MPI_THREAD_MULTIPLE) {
+ MPI_Comm_rank(MPI_COMM_WORLD, &LocalRank);
+ MPI_Comm_size(MPI_COMM_WORLD, &WorldSize);
+ return true;
+ }
+
+ // Create gate thread comm.
+ MPIError = MPI_Comm_dup(MPI_COMM_WORLD, &GateThreadComm);
+ CHECK(MPIError == MPI_SUCCESS,
+ "Failed to create gate thread MPI comm with error %d\n", MPIError);
+
+ // Create event comm pool.
+ EventCommPool.resize(NumMPIComms.get(), MPI_COMM_NULL);
+ for (auto &Comm : EventCommPool) {
+ MPIError = MPI_Comm_dup(MPI_COMM_WORLD, &Comm);
+ CHECK(MPIError == MPI_SUCCESS,
+ "Failed to create MPI comm pool with error %d\n", MPIError);
+ }
+
+ // Get local MPI process description.
+ MPIError = MPI_Comm_rank(GateThreadComm, &LocalRank);
+ CHECK(MPIError == MPI_SUCCESS,
+ "Failed to acquire the local MPI rank with error %d\n", MPIError);
+
+ MPIError = MPI_Comm_size(GateThreadComm, &WorldSize);
+ CHECK(MPIError == MPI_SUCCESS,
+ "Failed to acquire the world size with error %d\n", MPIError);
+
+ // Get max value for MPI tags.
+ int *Value = nullptr;
+ int Flag = 0;
+ MPIError = MPI_Comm_get_attr(GateThreadComm, MPI_TAG_UB, &Value, &Flag);
+ CHECK(Flag == 1 && MPIError == MPI_SUCCESS,
+ "Failed to acquire the MPI max tag value with error %d\n", MPIError);
+ MPITagMaxValue = *Value;
+
+ return true;
+}
+
+/// Destroy the MPI context.
+bool EventSystemTy::destroyLocalMPIContext() {
+ int MPIError = MPI_SUCCESS;
+
+ if (GateThreadComm == MPI_COMM_NULL) {
+ return true;
+ }
+
+ // Note: We don't need to assert here since the application part of the
+ // program has finished.
+ // Free gate thread comm.
+ MPIError = MPI_Comm_free(&GateThreadComm);
+ CHECK(MPIError == MPI_SUCCESS,
+ "Failed to destroy the gate thread MPI comm with error %d\n", MPIError);
+
+ // Free event comm pool.
+ for (auto &comm : EventCommPool) {
+ MPIError = MPI_Comm_free(&comm);
+ CHECK(MPIError == MPI_SUCCESS,
+ "Failed to destroy the event MPI comm with error %d\n", MPIError);
+ }
+ EventCommPool.resize(0);
+
+ // Finalize the global MPI session.
+ int IsFinalized = false;
+ MPIError = MPI_Finalized(&IsFinalized);
+
+ if (IsFinalized) {
+ DP("MPI was already finalized externally.\n");
+ } else {
+ MPIError = MPI_Finalize();
+ CHECK(MPIError == MPI_SUCCESS, "Failed to finalize MPI with error: %d\n",
+ MPIError);
+ }
+
+ return true;
+}
+
+int EventSystemTy::getNumWorkers() const {
+ if (isHost())
+ return WorldSize - 1;
+ return 0;
+}
+
+int EventSystemTy::isHost() const { return LocalRank == WorldSize - 1; }
+
+/// Map DeviceId to the pair <RemoteRank, RemoteDeviceId>
+RemoteDeviceId EventSystemTy::mapDeviceId(int32_t DeviceId) {
+ int32_t NumRanks = DevicesPerRemote.size();
+ for (int32_t RemoteRank = 0; RemoteRank < NumRanks; RemoteRank++) {
+ if (DeviceId < DevicesPerRemote[RemoteRank])
+ return {RemoteRank, DeviceId};
+ DeviceId -= DevicesPerRemote[RemoteRank];
+ }
+ return {-1, -1};
+}
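
For readers following the device numbering: `mapDeviceId` above turns a global device id into a `<rank, local id>` pair by walking the per-rank device counts. The standalone sketch below illustrates that mapping outside the plugin. It assumes the counts are passed in as a plain `std::vector` (in the patch they live in the `DevicesPerRemote` member), and it adds a guard for negative ids that the patch does not have.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Stand-in for the plugin's RemoteDeviceId alias: <rank, device id local to rank>.
using RemoteDeviceId = std::pair<int32_t, int32_t>;

// Maps a global device id to the rank that owns it and the id local to that
// rank. DevicesPerRemote[R] is the number of devices exposed by rank R.
// Returns {-1, -1} when the id is out of range.
RemoteDeviceId mapDeviceId(const std::vector<int32_t> &DevicesPerRemote,
                           int32_t DeviceId) {
  // Assumption for this sketch only: reject negative ids up front.
  if (DeviceId < 0)
    return {-1, -1};

  const int32_t NumRanks = static_cast<int32_t>(DevicesPerRemote.size());
  for (int32_t Rank = 0; Rank < NumRanks; ++Rank) {
    if (DeviceId < DevicesPerRemote[Rank])
      return {Rank, DeviceId};
    // Skip the devices owned by this rank and keep searching.
    DeviceId -= DevicesPerRemote[Rank];
  }
  return {-1, -1};
}
```

With counts `{2, 0, 3}`, global ids 0 and 1 belong to rank 0, ids 2 through 4 belong to rank 2, and rank 1 exposes no devices, so no id maps to it.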
diff --git a/offload/plugins-nextgen/mpi/event_system/EventSystem.h b/offload/plugins-nextgen/mpi/event_system/EventSystem.h
new file mode 100644
index 00000000000000..b4e3b56dc8b0e8
--- /dev/null
+++ b/offload/plugins-nextgen/mpi/event_system/EventSystem.h
@@ -0,0 +1,556 @@
+//===-------- EventSystem.h - Concurrent MPI communication ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the declarations of the MPI Event System used by the MPI
+// target.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _MPI_PROXY_EVENT_SYSTEM_H_
+#define _MPI_PROXY_EVENT_SYSTEM_H_
+
+#include <atomic>
+#include <cassert>
+#include <concepts>
+#include <condition_variable>
+#include <coroutine>
+#include <cstddef>
+#include <cstdint>
+#include <exception>
+#include <memory>
+#include <mutex>
+#include <optional>
+#include <queue>
+#include <thread>
+#include <type_traits>
+#include <utility>
+
+#define MPICH_SKIP_MPICXX
+#include <mpi.h>
+
+#include "llvm/ADT/SmallVector.h"
+
+#include "Shared/APITypes.h"
+#include "Shared/EnvironmentVar.h"
+#include "Shared/Utils.h"
+
+/// External forward declarations.
+struct __tgt_device_image;
+struct ProxyDevice;
+
+/// Template helper for generating llvm::Error instances from events.
+template <typename... ArgsTy>
+static llvm::Error createError(const char *ErrFmt, ArgsTy... Args) {
+ return llvm::createStringError(llvm::inconvertibleErrorCode(), ErrFmt,
+ Args...);
+}
+
+/// The event type (the type of action it will perform).
+///
+/// Enumerates the available events. Each enum item should be accompanied by an
+/// event class derived from BaseEvent. All the events are executed at the
+/// remote MPI process targeted by the event.
+enum class EventTypeTy : unsigned int {
+ // Remote device management
+ RETRIEVE_NUM_DEVICES, // Receives the number of devices from a remote process.
+ INIT_DEVICE, // Initializes the remote device.
+ INIT_RECORD_REPLAY, // Initializes the record and replay mechanism.
+ IS_PLUGIN_COMPATIBLE, // Check if the Image can be executed by the remote
+ // plugin.
+ IS_DEVICE_COMPATIBLE, // Check if the Image can be executed by a device in the
+ // remote plugin.
+ IS_DATA_EXCHANGABLE, // Check if the plugin supports exchanging data.
+ LOAD_BINARY, // Transmits the binary descriptor to all workers
+ GET_GLOBAL, // Look up a global symbol in the given binary
+ GET_FUNCTION, // Look up a kernel function in the given binary.
+ SYNCHRONIZE, // Sync all events in the device.
+ INIT_ASYNC_INFO,
+ INIT_DEVICE_INFO,
+ QUERY_ASYNC,
+ PRINT_DEVICE_INFO,
+ DATA_LOCK,
+ DATA_UNLOCK,
+ DATA_NOTIFY_MAPPED,
+ DATA_NOTIFY_UNMAPPED,
+
+ // Memory management.
+ ALLOC, // Allocates a buffer at the remote process.
+ DELETE, // Deletes a buffer at the remote process.
+
+ // Data movement.
+ SUBMIT, // Sends buffer data to a remote process.
+ RETRIEVE, // Receives buffer data from a remote process.
+ LOCAL_EXCHANGE, // Data exchange between two devices in one remote process.
+ EXCHANGE_SRC, // SRC side of the exchange event between two remote processes.
+ EXCHANGE_DST, // DST side of the exchange event between two remote processes.
+
+ // Target region execution.
+ LAUNCH_KERNEL, // Executes a target region at the remote process.
+
+ // Local event used to wait on other events.
+ SYNC,
+
+ // Internal event system commands.
+ EXIT // Stops the event system execution at the remote process.
+};
+
+/// Coroutine events
+///
+/// Return object for the event system coroutines. This class works as an
+/// external handle for the coroutine execution, allowing anyone to: query for
+/// the coroutine completion, resume the coroutine and check its state.
+/// Moreover, this class allows for coroutines to be chainable, meaning a
+/// coroutine function may wait on the completion of another one by using the
+/// co_await operator, all through a single external handle.
+struct EventTy {
+ /// Internal event handle to access C++ coroutine states.
+ struct promise_type;
+ using CoHandleTy = std::coroutine_handle<promise_type>;
+ std::shared_ptr<void> HandlePtr;
+
+ /// Polling rate period (us) used by event handlers.
+ IntEnvar EventPollingRate;
+
+ /// Internal (and required) promise type. Allows for customization of the
+ /// coroutines behavior and to store custom data inside the coroutine itself.
+ struct promise_type {
+ /// Coroutines are chained as a reverse linked-list. The most-recent
+ /// coroutine in a chain points to the previous one and so on, until the
+ /// root (and first) coroutine, which then points to the most-recent one.
+ /// The root always refers to the coroutine stored in the external handle,
+ /// the only handle an external user has access to.
+ CoHandleTy PrevHandle;
+ CoHandleTy RootHandle;
+
+ /// Indicates if the coroutine was completed successfully. Contains the
+ /// appropriate error otherwise.
+ std::optional<llvm::Error> CoroutineError;
+
+ promise_type() : CoroutineError(std::nullopt) {
+ PrevHandle = RootHandle = CoHandleTy::from_promise(*this);
+ }
+
+ /// Event coroutines should always suspend upon creation and finalization.
+ std::suspend_always initial_suspend() { return {}; }
+ std::suspend_always final_suspend() noexcept { return {}; }
+
+ /// Coroutines should return llvm::Error::success() or an appropriate error
+ /// message.
+ void return_value(llvm::Error &&GivenError) noexcept {
+ CoroutineError = std::move(GivenError);
+ }
+
+ /// Any unhandled exception should create an externally visible error.
+ void unhandled_exception() {
+ assert(std::uncaught_exceptions() > 0 &&
+ "Function should only be called if an uncaught exception is "
+ "generated inside the coroutine");
+ CoroutineError = createError("Event generated an unhandled exception");
+ }
+
+ /// Returns the external coroutine handle from the promise object.
+ EventTy get_return_object() {
+ void *HandlePtr = CoHandleTy::from_promise(*this).address();
+ return {
+ std::shared_ptr<void>(HandlePtr,
+ [](void *HandlePtr) {
+ assert(HandlePtr);
+ CoHandleTy::from_address(HandlePtr).destroy();
+ }),
+ IntEnvar("OMPTARGET_EVENT_POLLING_RATE", 1)};
+ }
+ };
+
+ /// Returns the external coroutine handle from the event.
+ CoHandleTy getHandle() const {
+ return CoHandleTy::from_address(HandlePtr.get());
+ }
+
+ /// Execution handling.
+ /// Resume the coroutine execution up until the next suspension point.
+ void resume();
+
+ /// Blocks the caller thread until the coroutine is completed.
+ void wait();
+
+ /// Checks if the coroutine is completed or not.
+ bool done() const;
+
+ /// Coroutine state handling.
+ /// Checks if the coroutine is valid.
+ bool empty() const;
+
+ /// Get the returned error from the coroutine.
+ llvm::Error getError() const;
+
+ /// EventTy instances are also awaitables. This means one can link multiple
+ /// EventTy together by calling the co_await operator on one another. For this
+ /// to work, EventTy must implement the following three functions.
+ ///
+ /// Called on the new coroutine before suspending the current one on co_await.
+ /// If it returns true, the new coroutine is already completed, thus it should
+ /// not be linked against the current one and the current coroutine can
+ /// continue without suspending.
+ bool await_ready() { return getHandle().done(); }
+
+ /// Called on the new coroutine when the current one is suspended. It is
+ /// responsible for chaining coroutines together.
+ void await_suspend(CoHandleTy SuspendedHandle) {
+ auto Handle = getHandle();
+ auto &CurrPromise = Handle.promise();
+ auto &SuspendedPromise = SuspendedHandle.promise();
+ auto &RootPromise = SuspendedPromise.RootHandle.promise();
+
+ CurrPromise.PrevHandle = SuspendedHandle;
+ CurrPromise.RootHandle = SuspendedPromise.RootHandle;
+
+ RootPromise.PrevHandle = Handle;
+ }
+
+ /// Called on the new coroutine when the current one is resumed. Used to
+ /// return errors when co_awaiting on other EventTy.
+ llvm::Error await_resume() {
+ auto &Error = getHandle().promise().CoroutineError;
+
+ if (Error) {
+ return std::move(*Error);
+ }
+
+ return llvm::Error::success();
+ }
+};
+
+/// Coroutine-like manager for many non-blocking MPI calls. Allows coroutines
+/// to co_await on the registered MPI requests.
+class MPIRequestManagerTy {
+ /// Target specification for the MPI messages.
+ const MPI_Comm Comm;
+ const int Tag;
+ /// Pending MPI requests.
+ llvm::SmallVector<MPI_Request> Requests;
+ /// Maximum buffer Size to use during data transfer.
+ Int64Envar MPIFragmentSize;
+
+public:
+ /// Target peer to send and receive messages
+ int OtherRank;
+
+ /// Target device in OtherRank
+ int DeviceId;
+
+ MPIRequestManagerTy(MPI_Comm Comm, int Tag, int OtherRank, int DeviceId,
+ llvm::SmallVector<MPI_Request> InitialRequests =
+ {}) // TODO: Change to initializer_list
+ : Comm(Comm), Tag(Tag), Requests(InitialRequests),
+ MPIFragmentSize("OMPTARGET_MPI_FRAGMENT_SIZE", 100e6),
+ OtherRank(OtherRank), DeviceId(DeviceId) {}
+
+ /// This class should not be copied.
+ MPIRequestManagerTy(const MPIRequestManagerTy &) = delete;
+ MPIRequestManagerTy &operator=(const MPIRequestManagerTy &) = delete;
+
+ MPIRequestManagerTy(MPIRequestManagerTy &&Other) noexcept
+ : Comm(Other.Comm), Tag(Other.Tag), Requests(Other.Requests),
+ MPIFragmentSize(Other.MPIFragmentSize), OtherRank(Other.OtherRank),
+ DeviceId(Other.DeviceId) {
+ Other.Requests = {};
+ }
+
+ MPIRequestManagerTy &operator=(MPIRequestManagerTy &&Other) = delete;
+
+ ~MPIRequestManagerTy();
+
+ /// Sends a buffer of given datatype items with determined size to target.
+ void send(const void *Buffer, int Size, MPI_Datatype Datatype);
+
+ /// Sends a buffer with determined size to target in batches.
+ void sendInBatchs(void *Buffer, int64_t Size);
+
+ /// Receives a buffer of given datatype items with determined size from
+ /// target.
+ void receive(void *Buffer, int Size, MPI_Datatype Datatype);
+
+ /// Receives a buffer with determined size from target in batches.
+ void receiveInBatchs(void *Buffer, int64_t Size);
+
+ /// Coroutine that waits on all internal pending requests.
+ EventTy wait();
+};
+
+EventTy operator co_await(MPIRequestManagerTy &RequestManager);
+
+/// Data handle for host buffers in event. It keeps the host data even if the
+/// original buffer is deallocated before the event happens.
+using EventDataHandleTy = std::shared_ptr<void>;
+
+/// Index Pair used to identify the remote device
+using RemoteDeviceId = std::pair<int32_t, int32_t>;
+
+/// Routines to alloc/dealloc pinned host memory.
+///
+/// Allocate \p Size of host memory and returns its ptr.
+void *memAllocHost(int64_t Size);
+
+/// Deallocate the host memory pointed to by \p HstPtr.
+int memFreeHost(void *HstPtr);
+
+/// Coroutine events created at the origin rank of the event.
+namespace OriginEvents {
+
+EventTy retrieveNumDevices(MPIRequestManagerTy RequestManager,
+ int32_t *NumDevices);
+EventTy isPluginCompatible(MPIRequestManagerTy RequestManager,
+ __tgt_device_image *Image, bool *QueryResult);
+EventTy isDeviceCompatible(MPIRequestManagerTy RequestManager,
+ __tgt_device_image *Image, bool *QueryResult);
+EventTy initDevice(MPIRequestManagerTy RequestManager, void **DevicePtr);
+EventTy initRecordReplay(MPIRequestManagerTy RequestManager, int64_t MemorySize,
+ void *VAddr, bool IsRecord, bool SaveOutput,
+ uint64_t *ReqPtrArgOffset);
+EventTy isDataExchangable(MPIRequestManagerTy RequestManager,
+ int32_t DstDeviceId, bool *QueryResult);
+EventTy allocateBuffer(MPIRequestManagerTy RequestManager, int64_t Size,
+ int32_t Kind, void **Buffer);
+EventTy deleteBuffer(MPIRequestManagerTy RequestManager, void *Buffer,
+ int32_t Kind);
+EventTy submit(MPIRequestManagerTy RequestManager, void *TgtPtr,
+ EventDataHandleTy HstPtr, int64_t Size,
+ __tgt_async_info *AsyncInfoPtr);
+EventTy retrieve(MPIRequestManagerTy RequestManager, int64_t Size, void *HstPtr,
+ void *TgtPtr, __tgt_async_info *AsyncInfoPtr);
+EventTy localExchange(MPIRequestManagerTy RequestManager, void *SrcPtr,
+ int DstDeviceId, void *DstPtr, int64_t Size,
+ __tgt_async_info *AsyncInfoPtr);
+EventTy exchange(MPIRequestManagerTy RequestManager, int SrcRank,
+ const void *OrgBuffer, int DstRank, void *DstBuffer,
+ int64_t Size, __tgt_async_info *AsyncInfoPtr);
+EventTy synchronize(MPIRequestManagerTy RequestManager,
+ __tgt_async_info *AsyncInfoPtr);
+EventTy sync(EventTy Event);
+EventTy loadBinary(MPIRequestManagerTy RequestManager,
+ const __tgt_device_image *Image,
+ __tgt_device_binary *Binary);
+EventTy getGlobal(MPIRequestManagerTy RequestManager,
+ __tgt_device_binary Binary, uint64_t Size, const char *Name,
+ void **DevicePtr);
+EventTy getFunction(MPIRequestManagerTy RequestManager,
+ __tgt_device_binary Binary, const char *Name,
+ void **KernelPtr);
+EventTy launchKernel(MPIRequestManagerTy RequestManager, void *TgtEntryPtr,
+ EventDataHandleTy TgtArgs, EventDataHandleTy TgtOffsets,
+ EventDataHandleTy KernelArgsHandle,
+ __tgt_async_info *AsyncInfoPtr);
+EventTy initAsyncInfo(MPIRequestManagerTy RequestManager,
+ __tgt_async_info **AsyncInfoPtr);
+EventTy initDeviceInfo(MPIRequestManagerTy RequestManager,
+ __tgt_device_info *DeviceInfo);
+EventTy queryAsync(MPIRequestManagerTy RequestManager,
+ __tgt_async_info *AsyncInfoPtr);
+EventTy printDeviceInfo(MPIRequestManagerTy RequestManager);
+EventTy dataLock(MPIRequestManagerTy RequestManager, void *Ptr, int64_t Size,
+ void **LockedPtr);
+EventTy dataUnlock(MPIRequestManagerTy RequestManager, void *Ptr);
+EventTy dataNotifyMapped(MPIRequestManagerTy RequestManager, void *HstPtr,
+ int64_t Size);
+EventTy dataNotifyUnmapped(MPIRequestManagerTy RequestManager, void *HstPtr);
+EventTy exit(MPIRequestManagerTy RequestManager);
+
+} // namespace OriginEvents
+
+/// Event Queue
+///
+/// Event queue for received events.
+class EventQueue {
+private:
+ /// Base internal queue.
+ std::queue<EventTy> Queue;
+ /// Base queue sync mutex.
+ std::mutex QueueMtx;
+
+ /// Condition variable used to block popping on an empty queue.
+ std::condition_variable_any CanPopCv;
+
+public:
+ /// Event Queue default constructor.
+ EventQueue();
+
+ /// Gets current queue size.
+ size_t size();
+
+ /// Push an event to the queue, resizing it when needed.
+ void push(EventTy &&Event);
+
+ /// Pops an event from the queue, waiting if the queue is empty. When stopped,
+ /// returns an empty event.
+ EventTy pop(std::stop_token &Stop);
+};
+
+/// Event System
+///
+/// MPI tags used in control messages.
+///
+/// Special tags values used to send control messages between event systems of
+/// different processes. When adding new tags, please summarize the tag usage
+/// with a side comment as done below.
+enum class ControlTagsTy : int {
+ EVENT_REQUEST = 0, // Used by event handlers to receive new event requests.
+ FIRST_EVENT // Tag used by the first event. Must always be placed last.
+};
+
+/// Event system execution state.
+///
+/// Describes the event system state through the program.
+enum class EventSystemStateTy {
+ CREATED, // ES was created but it is not ready to send or receive new
+ // events.
+ INITIALIZED, // ES was initialized alongside internal MPI states. It is ready
+ // to send new events, but not receive them.
+ RUNNING, // ES is running and ready to receive new events.
+ EXITED, // ES was stopped.
+ FINALIZED // ES was finalized and cannot run anything else.
+};
+
+/// The distributed event system.
+class EventSystemTy {
+ /// MPI definitions.
+ /// The largest MPI tag allowed by its implementation.
+ int32_t MPITagMaxValue = 0;
+
+ /// Communicator used by the gate thread and base communicator for the event
+ /// system.
+ MPI_Comm GateThreadComm = MPI_COMM_NULL;
+
+ /// Communicator pool distributed over the events. Many MPI implementations
+ /// allow for better network hardware parallelism when unrelated MPI messages
+ /// are exchanged over distinct communicators. Thus this pool will be given in
+ /// a round-robin fashion to each newly created event to better utilize the
+ /// hardware capabilities.
+ llvm::SmallVector<MPI_Comm> EventCommPool{};
+
+ /// Number of processes used by the event system.
+ int WorldSize = -1;
+
+ /// The local rank of the current instance.
+ int LocalRank = -1;
+
+ /// Number of events created by the current instance so far. This is used to
+ /// generate unique MPI tags for each event.
+ std::atomic<int> EventCounter{0};
+
+ /// Event queue between the local gate thread and the event handlers. The exec
+ /// queue is responsible for only running the execution events, while the data
+ /// queue executes all the other ones. This allows for long running execution
+ /// events to not block any data transfers (which are all done in a
+ /// non-blocking fashion).
+ EventQueue ExecEventQueue{};
+ EventQueue DataEventQueue{};
+
+ /// Event System execution state.
+ std::atomic<EventSystemStateTy> EventSystemState{};
+
+ /// Number of communicators to be spawned and distributed for the events.
+ /// Allows for parallel use of network resources.
+ Int64Envar NumMPIComms;
+
+private:
+ /// Creates a new unique event tag for a new event.
+ int createNewEventTag();
+
+ /// Gets a comm for a new event from the comm pool.
+ MPI_Comm &getNewEventComm(int MPITag);
+
+ /// Creates a local MPI context containing an exclusive comm for the gate
+ /// thread, and a comm pool to be used internally by the events. It also
+ /// acquires the local MPI process description.
+ bool createLocalMPIContext();
+
+ /// Destroy the local MPI context and all of its comms.
+ bool destroyLocalMPIContext();
+
+public:
+ EventSystemTy();
+ ~EventSystemTy();
+
+ bool initialize();
+ bool is_initialized();
+ bool deinitialize();
+
+ /// Creates a new event.
+ ///
+ /// Creates a new event of type 'EventType' that runs 'EventFunc' targeting
+ /// the device 'DstDeviceID'. The 'Args' parameters are additional arguments
+ /// forwarded to 'EventFunc' after the request manager.
+ ///
+ /// \note Since this is a template function, it must be defined in
+ /// this header.
+ template <class EventFuncTy, typename... ArgsTy>
+ requires std::invocable<EventFuncTy, MPIRequestManagerTy, ArgsTy...>
+ EventTy createEvent(EventFuncTy EventFunc, EventTypeTy EventType,
+ int DstDeviceID, ArgsTy... Args);
+
+ /// Create a new Exchange event.
+ ///
+ /// This function notifies \p SrcDevice and \p DstDevice about the
+ /// transfer and creates a host event that waits until the transfer is
+ /// completed.
+ EventTy createExchangeEvent(int SrcDevice, const void *SrcBuffer,
+ int DstDevice, void *DstBuffer, int64_t Size,
+ __tgt_async_info *AsyncInfo);
+
+ /// Get the number of workers available.
+ ///
+ /// \return the number of available MPI workers.
+ int getNumWorkers() const;
+
+ /// Check if we are at the host MPI process.
+ ///
+ /// \return true if the current MPI process is the host (rank WorldSize-1),
+ /// false otherwise.
+ int isHost() const;
+
+ RemoteDeviceId mapDeviceId(int32_t DeviceId);
+
+ llvm::SmallVector<int> DevicesPerRemote{};
+
+ friend struct ProxyDevice;
+};
+
+template <class EventFuncTy, typename... ArgsTy>
+ requires std::invocable<EventFuncTy, MPIRequestManagerTy, ArgsTy...>
+EventTy EventSystemTy::createEvent(EventFuncTy EventFunc, EventTypeTy EventType,
+ int DstDeviceID, ArgsTy... Args) {
+ // Create event MPI request manager.
+ const int EventTag = createNewEventTag();
+ auto &EventComm = getNewEventComm(EventTag);
+
+ int32_t RemoteRank = DstDeviceID, RemoteDeviceId = -1;
+
+ if (EventType != EventTypeTy::IS_PLUGIN_COMPATIBLE &&
+ EventType != EventTypeTy::RETRIEVE_NUM_DEVICES &&
+ EventType != EventTypeTy::EXIT)
+ std::tie(RemoteRank, RemoteDeviceId) = mapDeviceId(DstDeviceID);
+
+ // Send new event notification.
+ int EventNotificationInfo[] = {static_cast<int>(EventType), EventTag,
+ RemoteDeviceId};
+ MPI_Request NotificationRequest = MPI_REQUEST_NULL;
+ int MPIError = MPI_Isend(EventNotificationInfo, 3, MPI_INT, RemoteRank,
+ static_cast<int>(ControlTagsTy::EVENT_REQUEST),
+ GateThreadComm, &NotificationRequest);
+
+ if (MPIError != MPI_SUCCESS)
+ co_return createError("MPI failed during event notification with error %d",
+ MPIError);
+
+ MPIRequestManagerTy RequestManager(EventComm, EventTag, RemoteRank,
+ RemoteDeviceId, {NotificationRequest});
+
+ co_return (co_await EventFunc(std::move(RequestManager), Args...));
+}
+
+#endif // _MPI_PROXY_EVENT_SYSTEM_H_
diff --git a/offload/plugins-nextgen/mpi/src/ProxyDevice.cpp b/offload/plugins-nextgen/mpi/src/ProxyDevice.cpp
new file mode 100644
index 00000000000000..219970856a0ea1
--- /dev/null
+++ b/offload/plugins-nextgen/mpi/src/ProxyDevice.cpp
@@ -0,0 +1,1071 @@
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <tuple>
+
+#include "EventSystem.h"
+#include "RemotePluginManager.h"
+#include "Shared/APITypes.h"
+#include "mpi.h"
+#include "omptarget.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+
+#ifdef OMPT_SUPPORT
+#include "OpenMP/OMPT/Callback.h"
+#include "omp-tools.h"
+extern void llvm::omp::target::ompt::connectLibrary();
+#endif
+
+/// RAII class that holds a staging buffer for data transfers between the host
+/// and a remote device.
+struct PluginDataHandle {
+ void *HstPtr;
+ uint32_t Plugin;
+ uint32_t Device;
+ RemotePluginManager *PM;
+ PluginDataHandle(RemotePluginManager *PluginManager, uint32_t PluginId,
+ uint32_t DeviceId, uint64_t Size) {
+ Device = DeviceId;
+ Plugin = PluginId;
+ PM = PluginManager;
+ HstPtr = PM->Plugins[Plugin]->data_alloc(Device, Size, nullptr,
+ TARGET_ALLOC_HOST);
+ }
+ ~PluginDataHandle() {
+ PM->Plugins[Plugin]->data_delete(Device, HstPtr, TARGET_ALLOC_HOST);
+ }
+};
+
+/// Event Implementations on Device side.
+struct ProxyDevice {
+ ProxyDevice()
+ : NumExecEventHandlers("OMPTARGET_NUM_EXEC_EVENT_HANDLERS", 1),
+ NumDataEventHandlers("OMPTARGET_NUM_DATA_EVENT_HANDLERS", 1),
+ EventPollingRate("OMPTARGET_EVENT_POLLING_RATE", 1) {
+#ifdef OMPT_SUPPORT
+ // Initialize OMPT first
+ llvm::omp::target::ompt::connectLibrary();
+#endif
+
+ EventSystem.initialize();
+ PluginManager.init();
+ }
+
+ ~ProxyDevice() {
+ EventSystem.deinitialize();
+ PluginManager.deinit();
+ }
+
+ void mapDevicesPerRemote() {
+ EventSystem.DevicesPerRemote = {};
+ for (int PluginId = 0; PluginId < PluginManager.getNumUsedPlugins();
+ PluginId++) {
+ EventSystem.DevicesPerRemote.emplace_back(
+ PluginManager.getNumDevices(PluginId));
+ }
+ }
+
+ __tgt_async_info *MapAsyncInfo(void *HostAsyncInfoPtr) {
+ const std::lock_guard<std::mutex> Lock(TableMutex);
+ __tgt_async_info *TgtAsyncInfoPtr = nullptr;
+ if (AsyncInfoTable[HostAsyncInfoPtr])
+ TgtAsyncInfoPtr =
+ static_cast<__tgt_async_info *>(AsyncInfoTable[HostAsyncInfoPtr]);
+ else {
+ std::unique_ptr<__tgt_async_info> newEntry =
+ std::make_unique<__tgt_async_info>();
+ TgtAsyncInfoPtr = AsyncInfoList.emplace_back(std::move(newEntry)).get();
+ AsyncInfoTable[HostAsyncInfoPtr] = static_cast<void *>(TgtAsyncInfoPtr);
+ }
+
+ return TgtAsyncInfoPtr;
+ }
+
+ EventTy waitAsyncOpEnd(int32_t PluginId, int32_t DeviceId,
+ void *AsyncInfoPtr) {
+ auto *TgtAsyncInfo = MapAsyncInfo(AsyncInfoPtr);
+ auto *RPCServer =
+ PluginManager.Plugins[PluginId]->getDevice(DeviceId).getRPCServer();
+
+ while (TgtAsyncInfo->Queue != nullptr) {
+ if (PluginManager.Plugins[PluginId]->query_async(
+ DeviceId, TgtAsyncInfo) == OFFLOAD_FAIL)
+ co_return createError("Failure to wait AsyncOp\n");
+
+ if (RPCServer)
+ if (auto Err = RPCServer->runServer(
+ PluginManager.Plugins[PluginId]->getDevice(DeviceId)))
+ co_return Err;
+ co_await std::suspend_always{};
+ }
+
+ co_return llvm::Error::success();
+ }
+
+ EventTy retrieveNumDevices(MPIRequestManagerTy RequestManager) {
+ int32_t NumDevices = PluginManager.getNumDevices();
+ RequestManager.send(&NumDevices, 1, MPI_INT);
+
+ co_return (co_await RequestManager);
+ }
+
+ EventTy isPluginCompatible(MPIRequestManagerTy RequestManager) {
+ __tgt_device_image Image;
+ bool QueryResult = false;
+
+ uint64_t Size = 0;
+
+ RequestManager.receive(&Size, 1, MPI_UINT64_T);
+
+ if (auto Err = co_await RequestManager; Err)
+ co_return Err;
+
+ Image.ImageStart = memAllocHost(Size);
+ RequestManager.receive(Image.ImageStart, Size, MPI_BYTE);
+
+ if (auto Err = co_await RequestManager; Err)
+ co_return Err;
+
+ Image.ImageEnd = (void *)((ptrdiff_t)(Image.ImageStart) + Size);
+
+ llvm::SmallVector<std::unique_ptr<GenericPluginTy>> UsedPlugins;
+
+ for (auto &Plugin : PluginManager.Plugins) {
+ QueryResult = Plugin->is_plugin_compatible(&Image);
+ if (QueryResult) {
+ UsedPlugins.emplace_back(std::move(Plugin));
+ break;
+ }
+ }
+
+ for (auto &Plugin : PluginManager.Plugins) {
+ if (Plugin)
+ UsedPlugins.emplace_back(std::move(Plugin));
+ }
+
+ PluginManager.Plugins = std::move(UsedPlugins);
+ mapDevicesPerRemote();
+
+ memFreeHost(Image.ImageStart);
+ RequestManager.send(&QueryResult, sizeof(bool), MPI_BYTE);
+ co_return (co_await RequestManager);
+ }
+
+ EventTy isDeviceCompatible(MPIRequestManagerTy RequestManager) {
+ __tgt_device_image Image;
+ bool QueryResult = false;
+
+ uint64_t Size = 0;
+
+ RequestManager.receive(&Size, 1, MPI_UINT64_T);
+
+ if (auto Err = co_await RequestManager; Err)
+ co_return Err;
+
+ Image.ImageStart = memAllocHost(Size);
+ RequestManager.receive(Image.ImageStart, Size, MPI_BYTE);
+
+ if (auto Err = co_await RequestManager; Err)
+ co_return Err;
+
+ Image.ImageEnd = (void *)((ptrdiff_t)(Image.ImageStart) + Size);
+
+ int32_t DeviceId, PluginId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ QueryResult =
+ PluginManager.Plugins[PluginId]->is_device_compatible(DeviceId, &Image);
+
+ memFreeHost(Image.ImageStart);
+ RequestManager.send(&QueryResult, sizeof(bool), MPI_BYTE);
+ co_return (co_await RequestManager);
+ }
+
+ EventTy initDevice(MPIRequestManagerTy RequestManager) {
+ int32_t DeviceId, PluginId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ PluginManager.Plugins[PluginId]->init_device(DeviceId);
+
+ auto *DevicePtr = &PluginManager.Plugins[PluginId]->getDevice(DeviceId);
+
+ // Event completion notification
+ RequestManager.send(&DevicePtr, sizeof(void *), MPI_BYTE);
+
+ co_return (co_await RequestManager);
+ }
+
+ EventTy initRecordReplay(MPIRequestManagerTy RequestManager) {
+ int64_t MemorySize = 0;
+ void *VAddr = nullptr;
+ bool IsRecord = false, SaveOutput = false;
+ uint64_t ReqPtrArgOffset = 0;
+
+ RequestManager.receive(&MemorySize, 1, MPI_INT64_T);
+ RequestManager.receive(&VAddr, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&IsRecord, sizeof(bool), MPI_BYTE);
+ RequestManager.receive(&SaveOutput, sizeof(bool), MPI_BYTE);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ int32_t DeviceId, PluginId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ PluginManager.Plugins[PluginId]->initialize_record_replay(
+ DeviceId, MemorySize, VAddr, IsRecord, SaveOutput, ReqPtrArgOffset);
+
+ RequestManager.send(&ReqPtrArgOffset, 1, MPI_UINT64_T);
+ co_return (co_await RequestManager);
+ }
+
+ EventTy isDataExchangable(MPIRequestManagerTy RequestManager) {
+ int32_t DstDeviceId = 0;
+ bool QueryResult = false;
+ RequestManager.receive(&DstDeviceId, 1, MPI_INT32_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ int32_t SrcDeviceId, PluginId;
+
+ std::tie(PluginId, SrcDeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ QueryResult = PluginManager.Plugins[PluginId]->isDataExchangable(
+ SrcDeviceId, DstDeviceId);
+
+ RequestManager.send(&QueryResult, sizeof(bool), MPI_BYTE);
+ co_return (co_await RequestManager);
+ }
+
+ EventTy allocateBuffer(MPIRequestManagerTy RequestManager) {
+ int64_t Size = 0;
+ int32_t Kind = 0;
+ RequestManager.receive(&Size, 1, MPI_INT64_T);
+ RequestManager.receive(&Kind, 1, MPI_INT32_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ int32_t PluginId, DeviceId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ void *Buffer = PluginManager.Plugins[PluginId]->data_alloc(DeviceId, Size,
+ nullptr, Kind);
+
+ RequestManager.send(&Buffer, sizeof(void *), MPI_BYTE);
+
+ co_return (co_await RequestManager);
+ }
+
+ EventTy deleteBuffer(MPIRequestManagerTy RequestManager) {
+ void *Buffer = nullptr;
+ int32_t Kind = 0;
+
+ RequestManager.receive(&Buffer, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&Kind, 1, MPI_INT32_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ int32_t PluginId, DeviceId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ PluginManager.Plugins[PluginId]->data_delete(DeviceId, Buffer, Kind);
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+ }
+
+ EventTy submit(MPIRequestManagerTy RequestManager) {
+ void *TgtPtr = nullptr, *HstAsyncInfoPtr = nullptr;
+ int64_t Size = 0;
+
+ RequestManager.receive(&HstAsyncInfoPtr, sizeof(void *), MPI_BYTE);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ auto *TgtAsyncInfo = MapAsyncInfo(HstAsyncInfoPtr);
+
+ RequestManager.receive(&TgtPtr, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&Size, 1, MPI_INT64_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ int32_t PluginId, DeviceId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ PluginDataHandle DataHandler(&PluginManager, PluginId, DeviceId, Size);
+ RequestManager.receiveInBatchs(DataHandler.HstPtr, Size);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ PluginManager.Plugins[PluginId]->data_submit_async(
+ DeviceId, TgtPtr, DataHandler.HstPtr, Size, TgtAsyncInfo);
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+ }
+
+ EventTy retrieve(MPIRequestManagerTy RequestManager) {
+ void *TgtPtr = nullptr, *HstAsyncInfoPtr = nullptr;
+ int64_t Size = 0;
+ bool DeviceOpStatus = true;
+
+ RequestManager.receive(&HstAsyncInfoPtr, sizeof(void *), MPI_BYTE);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ auto *TgtAsyncInfo = MapAsyncInfo(HstAsyncInfoPtr);
+ RequestManager.receive(&TgtPtr, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&Size, 1, MPI_INT64_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ int32_t PluginId, DeviceId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ PluginDataHandle DataHandler(&PluginManager, PluginId, DeviceId, Size);
+
+ PluginManager.Plugins[PluginId]->data_retrieve_async(
+ DeviceId, DataHandler.HstPtr, TgtPtr, Size, TgtAsyncInfo);
+
+ if (auto Error =
+ co_await waitAsyncOpEnd(PluginId, DeviceId, HstAsyncInfoPtr);
+ Error)
+ REPORT("Retrieve event failed with msg: %s\n",
+ toString(std::move(Error)).data());
+
+ RequestManager.send(&DeviceOpStatus, sizeof(bool), MPI_BYTE);
+
+ if (!DeviceOpStatus)
+ co_return (co_await RequestManager);
+
+ RequestManager.sendInBatchs(DataHandler.HstPtr, Size);
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+ }
+
+ EventTy exchange(MPIRequestManagerTy RequestManager) {
+ void *SrcPtr = nullptr, *DstPtr = nullptr;
+ int DstDeviceId = 0;
+ int64_t Size = 0;
+ void *HstAsyncInfoPtr = nullptr;
+
+ RequestManager.receive(&SrcPtr, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&DstDeviceId, 1, MPI_INT);
+ RequestManager.receive(&DstPtr, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&Size, 1, MPI_INT64_T);
+ RequestManager.receive(&HstAsyncInfoPtr, sizeof(void *), MPI_BYTE);
+
+ if (auto Err = co_await RequestManager; Err)
+ co_return Err;
+
+ auto *TgtAsyncInfo = MapAsyncInfo(HstAsyncInfoPtr);
+
+ int32_t PluginId, SrcDeviceId;
+
+ std::tie(PluginId, SrcDeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ PluginManager.Plugins[PluginId]->data_exchange_async(
+ SrcDeviceId, SrcPtr, DstDeviceId, DstPtr, Size, TgtAsyncInfo);
+
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+ }
+
+ EventTy exchangeSrc(MPIRequestManagerTy RequestManager) {
+ void *SrcBuffer, *HstAsyncInfoPtr = nullptr;
+ int64_t Size;
+ int DstRank;
+
+ // Save head node rank
+ int HeadNodeRank = RequestManager.OtherRank;
+
+ RequestManager.receive(&SrcBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&Size, 1, MPI_INT64_T);
+ RequestManager.receive(&DstRank, 1, MPI_INT);
+ RequestManager.receive(&HstAsyncInfoPtr, sizeof(void *), MPI_BYTE);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ auto *TgtAsyncInfo = MapAsyncInfo(HstAsyncInfoPtr);
+
+ int32_t PluginId, DeviceId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ PluginDataHandle DataHandler(&PluginManager, PluginId, DeviceId, Size);
+
+ PluginManager.Plugins[PluginId]->data_retrieve_async(
+ DeviceId, DataHandler.HstPtr, SrcBuffer, Size, TgtAsyncInfo);
+
+ if (auto Error =
+ co_await waitAsyncOpEnd(PluginId, DeviceId, HstAsyncInfoPtr);
+ Error)
+ co_return Error;
+
+ // Set the Destination Rank in RequestManager
+ RequestManager.OtherRank = DstRank;
+
+ // Send buffer to target device
+ RequestManager.sendInBatchs(DataHandler.HstPtr, Size);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ // Set the HeadNode Rank to send the final notification
+ RequestManager.OtherRank = HeadNodeRank;
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+ }
+
+ EventTy exchangeDst(MPIRequestManagerTy RequestManager) {
+ void *DstBuffer, *HstAsyncInfoPtr = nullptr;
+ int64_t Size;
+ // Save head node rank
+ int SrcRank, HeadNodeRank = RequestManager.OtherRank;
+
+ RequestManager.receive(&DstBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&Size, 1, MPI_INT64_T);
+ RequestManager.receive(&SrcRank, 1, MPI_INT);
+ RequestManager.receive(&HstAsyncInfoPtr, sizeof(void *), MPI_BYTE);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ auto *TgtAsyncInfo = MapAsyncInfo(HstAsyncInfoPtr);
+
+ int32_t PluginId, DeviceId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ PluginDataHandle DataHandler(&PluginManager, PluginId, DeviceId, Size);
+
+ // Set the Source Rank in RequestManager
+ RequestManager.OtherRank = SrcRank;
+
+ // Receive buffer from the Source device
+ RequestManager.receiveInBatchs(DataHandler.HstPtr, Size);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ PluginManager.Plugins[PluginId]->data_submit_async(
+ DeviceId, DstBuffer, DataHandler.HstPtr, Size, TgtAsyncInfo);
+
+ if (auto Error =
+ co_await waitAsyncOpEnd(PluginId, DeviceId, HstAsyncInfoPtr);
+ Error)
+ co_return Error;
+
+ // Set the HeadNode Rank to send the final notification
+ RequestManager.OtherRank = HeadNodeRank;
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+ }
+
+ EventTy launchKernel(MPIRequestManagerTy RequestManager) {
+ void *TgtEntryPtr = nullptr, *HostAsyncInfoPtr = nullptr;
+ KernelArgsTy KernelArgs;
+
+ llvm::SmallVector<void *> TgtArgs;
+ llvm::SmallVector<ptrdiff_t> TgtOffsets;
+
+ uint32_t NumArgs = 0;
+
+ RequestManager.receive(&NumArgs, 1, MPI_UINT32_T);
+ RequestManager.receive(&HostAsyncInfoPtr, sizeof(void *), MPI_BYTE);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ auto *TgtAsyncInfo = MapAsyncInfo(HostAsyncInfoPtr);
+
+ TgtArgs.resize(NumArgs);
+ TgtOffsets.resize(NumArgs);
+
+ RequestManager.receive(&TgtEntryPtr, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(TgtArgs.data(), NumArgs * sizeof(void *), MPI_BYTE);
+ RequestManager.receive(TgtOffsets.data(), NumArgs * sizeof(ptrdiff_t),
+ MPI_BYTE);
+
+ RequestManager.receive(&KernelArgs, sizeof(KernelArgsTy), MPI_BYTE);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ int32_t PluginId, DeviceId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ PluginManager.Plugins[PluginId]->launch_kernel(
+ DeviceId, TgtEntryPtr, TgtArgs.data(), TgtOffsets.data(), &KernelArgs,
+ TgtAsyncInfo);
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+ }
+
+ EventTy loadBinary(MPIRequestManagerTy RequestManager) {
+ // Receive the target table sizes.
+ size_t ImageSize = 0;
+ size_t EntryCount = 0;
+ RequestManager.receive(&ImageSize, 1, MPI_UINT64_T);
+ RequestManager.receive(&EntryCount, 1, MPI_UINT64_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ llvm::SmallVector<size_t> EntryNameSizes(EntryCount);
+
+ RequestManager.receive(EntryNameSizes.begin(), EntryCount, MPI_UINT64_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ // Create the device image with the appropriate sizes and receive its
+ // content.
+ DeviceImage *Image = &RemoteImages.emplace_back(ImageSize, EntryCount);
+
+ Image->setImageEntries(EntryNameSizes);
+
+ // Receive the image bytes and the table entries.
+ RequestManager.receive(Image->ImageStart, ImageSize, MPI_BYTE);
+
+ for (size_t I = 0; I < EntryCount; I++) {
+ RequestManager.receive(&Image->Entries[I].addr, 1, MPI_UINT64_T);
+ RequestManager.receive(Image->Entries[I].name, EntryNameSizes[I],
+ MPI_CHAR);
+ RequestManager.receive(&Image->Entries[I].size, 1, MPI_UINT64_T);
+ RequestManager.receive(&Image->Entries[I].flags, 1, MPI_INT32_T);
+ RequestManager.receive(&Image->Entries[I].data, 1, MPI_INT32_T);
+ }
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ int32_t PluginId, DeviceId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ __tgt_device_binary Binary;
+
+ PluginManager.Plugins[PluginId]->load_binary(DeviceId, Image, &Binary);
+
+ RequestManager.send(&Binary.handle, sizeof(void *), MPI_BYTE);
+
+ co_return (co_await RequestManager);
+ }
+
+ EventTy getGlobal(MPIRequestManagerTy RequestManager) {
+ __tgt_device_binary Binary;
+ uint64_t Size = 0;
+ llvm::SmallVector<char> Name;
+ void *DevicePtr = nullptr;
+ uint32_t NameSize = 0;
+
+ RequestManager.receive(&Binary.handle, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&Size, 1, MPI_UINT64_T);
+ RequestManager.receive(&NameSize, 1, MPI_UINT32_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ Name.resize(NameSize);
+ RequestManager.receive(Name.data(), NameSize, MPI_CHAR);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ int32_t PluginId, DeviceId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ PluginManager.Plugins[PluginId]->get_global(Binary, Size, Name.data(),
+ &DevicePtr);
+
+ RequestManager.send(&DevicePtr, sizeof(void *), MPI_BYTE);
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
+ }
+
+ EventTy getFunction(MPIRequestManagerTy RequestManager) {
+ __tgt_device_binary Binary;
+ uint32_t Size = 0;
+ llvm::SmallVector<char> Name;
+ void *KernelPtr = nullptr;
+
+ RequestManager.receive(&Binary.handle, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&Size, 1, MPI_UINT32_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ Name.resize(Size);
+ RequestManager.receive(Name.data(), Size, MPI_CHAR);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ int32_t PluginId, DeviceId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ PluginManager.Plugins[PluginId]->get_function(Binary, Name.data(),
+ &KernelPtr);
+
+ RequestManager.send(&KernelPtr, sizeof(void *), MPI_BYTE);
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
+ }
+
+ EventTy synchronize(MPIRequestManagerTy RequestManager) {
+ void *HstAsyncInfoPtr = nullptr;
+ bool DeviceOpStatus = true;
+
+ RequestManager.receive(&HstAsyncInfoPtr, sizeof(void *), MPI_BYTE);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ int32_t PluginId, DeviceId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ if (auto Error =
+ co_await waitAsyncOpEnd(PluginId, DeviceId, HstAsyncInfoPtr);
+ Error)
+ REPORT("Synchronize event failed with msg: %s\n",
+ toString(std::move(Error)).data());
+
+ RequestManager.send(&DeviceOpStatus, sizeof(bool), MPI_BYTE);
+
+ if (!DeviceOpStatus)
+ co_return (co_await RequestManager);
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+ }
+
+ EventTy initAsyncInfo(MPIRequestManagerTy RequestManager) {
+ __tgt_async_info *TgtAsyncInfoPtr = nullptr;
+
+ int32_t PluginId, DeviceId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ PluginManager.Plugins[PluginId]->init_async_info(DeviceId,
+ &TgtAsyncInfoPtr);
+
+ RequestManager.send(&TgtAsyncInfoPtr, sizeof(void *), MPI_BYTE);
+
+ co_return (co_await RequestManager);
+ }
+
+ EventTy initDeviceInfo(MPIRequestManagerTy RequestManager) {
+ __tgt_device_info DeviceInfo;
+ const char *ErrStr = nullptr;
+
+ RequestManager.receive(&DeviceInfo, sizeof(__tgt_device_info), MPI_BYTE);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ int32_t PluginId, DeviceId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ PluginManager.Plugins[PluginId]->init_device_info(DeviceId, &DeviceInfo,
+ &ErrStr);
+
+ RequestManager.send(&DeviceInfo, sizeof(__tgt_device_info), MPI_BYTE);
+
+ co_return (co_await RequestManager);
+ }
+
+ EventTy queryAsync(MPIRequestManagerTy RequestManager) {
+ void *HstAsyncInfoPtr = nullptr;
+
+ RequestManager.receive(&HstAsyncInfoPtr, sizeof(void *), MPI_BYTE);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ auto *TgtAsyncInfo = MapAsyncInfo(HstAsyncInfoPtr);
+
+ int32_t PluginId, DeviceId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ PluginManager.Plugins[PluginId]->query_async(DeviceId, TgtAsyncInfo);
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+ }
+
+ EventTy printDeviceInfo(MPIRequestManagerTy RequestManager) {
+ int32_t PluginId, DeviceId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ PluginManager.Plugins[PluginId]->print_device_info(DeviceId);
+
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
+ }
+
+ EventTy dataLock(MPIRequestManagerTy RequestManager) {
+ void *Ptr = nullptr;
+ int64_t Size = 0;
+ void *LockedPtr = nullptr;
+
+ RequestManager.receive(&Ptr, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&Size, 1, MPI_INT64_T);
+
+ if (auto Err = co_await RequestManager; Err)
+ co_return Err;
+
+ int32_t PluginId, DeviceId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ PluginManager.Plugins[PluginId]->data_lock(DeviceId, Ptr, Size, &LockedPtr);
+
+ RequestManager.send(&LockedPtr, sizeof(void *), MPI_BYTE);
+ co_return (co_await RequestManager);
+ }
+
+ EventTy dataUnlock(MPIRequestManagerTy RequestManager) {
+ void *Ptr = nullptr;
+ RequestManager.receive(&Ptr, sizeof(void *), MPI_BYTE);
+
+ if (auto Err = co_await RequestManager; Err)
+ co_return Err;
+
+ int32_t PluginId, DeviceId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ PluginManager.Plugins[PluginId]->data_unlock(DeviceId, Ptr);
+
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
+ }
+
+ EventTy dataNotifyMapped(MPIRequestManagerTy RequestManager) {
+ void *HstPtr = nullptr;
+ int64_t Size = 0;
+ RequestManager.receive(&HstPtr, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&Size, 1, MPI_INT64_T);
+
+ if (auto Err = co_await RequestManager; Err)
+ co_return Err;
+
+ int32_t PluginId, DeviceId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ PluginManager.Plugins[PluginId]->data_notify_mapped(DeviceId, HstPtr, Size);
+
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
+ }
+
+ EventTy dataNotifyUnmapped(MPIRequestManagerTy RequestManager) {
+ void *HstPtr = nullptr;
+ RequestManager.receive(&HstPtr, sizeof(void *), MPI_BYTE);
+
+ if (auto Err = co_await RequestManager; Err)
+ co_return Err;
+
+ int32_t PluginId, DeviceId;
+
+ std::tie(PluginId, DeviceId) =
+ EventSystem.mapDeviceId(RequestManager.DeviceId);
+
+ PluginManager.Plugins[PluginId]->data_notify_unmapped(DeviceId, HstPtr);
+
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
+ }
+
+ EventTy exit(MPIRequestManagerTy RequestManager,
+ std::atomic<EventSystemStateTy> &EventSystemState) {
+ EventSystemStateTy OldState =
+ EventSystemState.exchange(EventSystemStateTy::EXITED);
+ assert(OldState != EventSystemStateTy::EXITED &&
+ "Exit event received multiple times");
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+ }
+
+ /// Function executed by the event handler threads.
+ void runEventHandler(std::stop_token Stop, EventQueue &Queue) {
+ while (EventSystem.EventSystemState == EventSystemStateTy::RUNNING ||
+ Queue.size() > 0) {
+ EventTy Event = Queue.pop(Stop);
+
+ // Re-checks the stop condition when no event was found.
+ if (Event.empty()) {
+ continue;
+ }
+
+ Event.resume();
+
+ if (!Event.done()) {
+ Queue.push(std::move(Event));
+ continue;
+ }
+
+ auto Error = Event.getError();
+ if (Error)
+ REPORT("Internal event failed with msg: %s\n",
+ toString(std::move(Error)).data());
+ }
+ }
+
+ /// Gate thread procedure.
+ ///
+ /// Caller thread will spawn the event handlers, execute the gate logic and
+ /// wait until the event system receives an Exit event.
+ void runGateThread() {
+ // Device image to be used by this gate thread.
+ // DeviceImage Image;
+
+ // Updates the event system state.
+ EventSystem.EventSystemState = EventSystemStateTy::RUNNING;
+
+ // Spawns the event handlers.
+ llvm::SmallVector<std::jthread> EventHandlers;
+ EventHandlers.resize(NumExecEventHandlers.get() +
+ NumDataEventHandlers.get());
+ int EventHandlersSize = EventHandlers.size();
+ auto HandlerFunction = std::bind_front(&ProxyDevice::runEventHandler, this);
+ for (int Idx = 0; Idx < EventHandlersSize; Idx++) {
+ EventHandlers[Idx] = std::jthread(
+ HandlerFunction, std::ref(Idx < NumExecEventHandlers.get()
+ ? EventSystem.ExecEventQueue
+ : EventSystem.DataEventQueue));
+ }
+
+ // Executes the gate thread logic
+ while (EventSystem.EventSystemState == EventSystemStateTy::RUNNING) {
+ // Checks for new incoming event requests.
+ MPI_Message EventReqMsg;
+ MPI_Status EventStatus;
+ int HasReceived = false;
+ MPI_Improbe(MPI_ANY_SOURCE,
+ static_cast<int>(ControlTagsTy::EVENT_REQUEST),
+ EventSystem.GateThreadComm, &HasReceived, &EventReqMsg,
+ MPI_STATUS_IGNORE);
+
+ // If none was received, wait for `EVENT_POLLING_RATE`us for the next
+ // check.
+ if (!HasReceived) {
+ std::this_thread::sleep_for(
+ std::chrono::microseconds(EventPollingRate.get()));
+ continue;
+ }
+
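+ // Acquires the event information from the received request:
+ // - EventInfo[0]: event type
+ // - EventInfo[1]: event tag, also used to select the event communicator
+ // - EventInfo[2]: extra value forwarded to the request manager
+ // The event source rank is taken from the MPI status.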
+ int EventInfo[3];
+ MPI_Mrecv(EventInfo, 3, MPI_INT, &EventReqMsg, &EventStatus);
+ const auto NewEventType = static_cast<EventTypeTy>(EventInfo[0]);
+ MPIRequestManagerTy RequestManager(
+ EventSystem.getNewEventComm(EventInfo[1]), EventInfo[1],
+ EventStatus.MPI_SOURCE, EventInfo[2]);
+
+ // Creates a new receive event of type `NewEventType`.
+ using enum EventTypeTy;
+ EventTy NewEvent;
+ switch (NewEventType) {
+ case RETRIEVE_NUM_DEVICES:
+ NewEvent = retrieveNumDevices(std::move(RequestManager));
+ break;
+ case IS_PLUGIN_COMPATIBLE:
+ NewEvent = isPluginCompatible(std::move(RequestManager));
+ break;
+ case IS_DEVICE_COMPATIBLE:
+ NewEvent = isDeviceCompatible(std::move(RequestManager));
+ break;
+ case INIT_DEVICE:
+ NewEvent = initDevice(std::move(RequestManager));
+ break;
+ case INIT_RECORD_REPLAY:
+ NewEvent = initRecordReplay(std::move(RequestManager));
+ break;
+ case IS_DATA_EXCHANGABLE:
+ NewEvent = isDataExchangable(std::move(RequestManager));
+ break;
+ case ALLOC:
+ NewEvent = allocateBuffer(std::move(RequestManager));
+ break;
+ case DELETE:
+ NewEvent = deleteBuffer(std::move(RequestManager));
+ break;
+ case SUBMIT:
+ NewEvent = submit(std::move(RequestManager));
+ break;
+ case RETRIEVE:
+ NewEvent = retrieve(std::move(RequestManager));
+ break;
+ case LOCAL_EXCHANGE:
+ NewEvent = exchange(std::move(RequestManager));
+ break;
+ case EXCHANGE_SRC:
+ NewEvent = exchangeSrc(std::move(RequestManager));
+ break;
+ case EXCHANGE_DST:
+ NewEvent = exchangeDst(std::move(RequestManager));
+ break;
+ case EXIT:
+ NewEvent =
+ exit(std::move(RequestManager), EventSystem.EventSystemState);
+ break;
+ case LOAD_BINARY:
+ NewEvent = loadBinary(std::move(RequestManager));
+ break;
+ case GET_GLOBAL:
+ NewEvent = getGlobal(std::move(RequestManager));
+ break;
+ case GET_FUNCTION:
+ NewEvent = getFunction(std::move(RequestManager));
+ break;
+ case LAUNCH_KERNEL:
+ NewEvent = launchKernel(std::move(RequestManager));
+ break;
+ case SYNCHRONIZE:
+ NewEvent = synchronize(std::move(RequestManager));
+ break;
+ case INIT_ASYNC_INFO:
+ NewEvent = initAsyncInfo(std::move(RequestManager));
+ break;
+ case INIT_DEVICE_INFO:
+ NewEvent = initDeviceInfo(std::move(RequestManager));
+ break;
+ case QUERY_ASYNC:
+ NewEvent = queryAsync(std::move(RequestManager));
+ break;
+ case PRINT_DEVICE_INFO:
+ NewEvent = printDeviceInfo(std::move(RequestManager));
+ break;
+ case DATA_LOCK:
+ NewEvent = dataLock(std::move(RequestManager));
+ break;
+ case DATA_UNLOCK:
+ NewEvent = dataUnlock(std::move(RequestManager));
+ break;
+ case DATA_NOTIFY_MAPPED:
+ NewEvent = dataNotifyMapped(std::move(RequestManager));
+ break;
+ case DATA_NOTIFY_UNMAPPED:
+ NewEvent = dataNotifyUnmapped(std::move(RequestManager));
+ break;
+ case SYNC:
+ assert(false && "Trying to create a local event on a remote node");
+ }
+
+ if (NewEventType == LAUNCH_KERNEL) {
+ EventSystem.ExecEventQueue.push(std::move(NewEvent));
+ } else {
+ EventSystem.DataEventQueue.push(std::move(NewEvent));
+ }
+ }
+
+ assert(EventSystem.EventSystemState == EventSystemStateTy::EXITED &&
+ "Event State should be EXITED after receiving an Exit event");
+ }
+
+private:
+ llvm::SmallVector<std::unique_ptr<__tgt_async_info>, 16> AsyncInfoList;
+ llvm::SmallVector<DeviceImage, 1> RemoteImages;
+ llvm::DenseMap<void *, void *> AsyncInfoTable;
+ RemotePluginManager PluginManager;
+ EventSystemTy EventSystem;
+ /// Number of execute event handlers to spawn.
+ IntEnvar NumExecEventHandlers;
+ /// Number of data event handlers to spawn.
+ IntEnvar NumDataEventHandlers;
+ /// Polling rate period (us) used by event handlers.
+ IntEnvar EventPollingRate;
+
+ // Mutex for AsyncInfoTable
+ std::mutex TableMutex;
+};
+
+int main(int argc, char **argv) {
+ ProxyDevice PD;
+ PD.runGateThread();
+ return 0;
+}
\ No newline at end of file
diff --git a/offload/plugins-nextgen/mpi/src/RemotePluginManager.cpp b/offload/plugins-nextgen/mpi/src/RemotePluginManager.cpp
new file mode 100644
index 00000000000000..0c65697b082f73
--- /dev/null
+++ b/offload/plugins-nextgen/mpi/src/RemotePluginManager.cpp
@@ -0,0 +1,104 @@
+//===-- RemotePluginManager.cpp - Plugin loading and communication API ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Functionality for handling plugins.
+//
+//===----------------------------------------------------------------------===//
+
+#include "RemotePluginManager.h"
+#include "Shared/Debug.h"
+#include "Shared/Profile.h"
+
+#include "llvm/Support/Error.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <cassert>
+#include <cstdint>
+#include <cstdio>
+#include <memory>
+
+using namespace llvm;
+using namespace llvm::sys;
+
+// Every plugin exports this method to create an instance of the plugin type.
+#define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name();
+#include "Shared/RemoteTargets.def"
+
+void RemotePluginManager::init() {
+ TIMESCOPE();
+ DP("Loading RTLs...\n");
+
+ // Attempt to create an instance of each supported plugin.
+#define PLUGIN_TARGET(Name) \
+ do { \
+ Plugins.emplace_back( \
+ std::unique_ptr<GenericPluginTy>(createPlugin_##Name())); \
+ } while (false);
+#include "Shared/RemoteTargets.def"
+
+ DP("RTLs loaded!\n");
+}
+
+void RemotePluginManager::deinit() {
+ TIMESCOPE();
+ DP("Unloading RTLs...\n");
+
+ for (auto &Plugin : Plugins) {
+ if (auto Err = Plugin->deinit()) {
+ [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
+ DP("Failed to deinit plugin: %s\n", InfoMsg.c_str());
+ }
+ Plugin.release();
+ }
+
+ DP("RTLs unloaded!\n");
+}
+
+void RemotePluginManager::initDevices(GenericPluginTy &RTL) {
+ int32_t NumDevices = RTL.getNumDevices();
+ int32_t Ret;
+ for (int32_t DeviceID = 0; DeviceID < NumDevices; DeviceID++) {
+ Ret = RTL.init_device(DeviceID);
+ if (Ret != OFFLOAD_SUCCESS)
+ DP("Failed to initialize device %d\n", DeviceID);
+ }
+}
+
+void RemotePluginManager::initAllPlugins() {
+ for (auto &R : Plugins)
+ initDevices(*R);
+}
+
+/// Return the number of usable devices.
+int RemotePluginManager::getNumDevices() {
+ int32_t NumDevices = 0;
+ for (auto &Plugin : Plugins) {
+ if (!Plugin->is_initialized()) {
+ if (auto Err = Plugin->init()) {
+ [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
+ DP("Failed to init plugin: %s\n", InfoMsg.c_str());
+ continue;
+ }
+ DP("Registered plugin %s with %d visible device(s)\n", Plugin->getName(),
+ Plugin->number_of_devices());
+ }
+ NumDevices += Plugin->number_of_devices();
+ }
+ return NumDevices;
+}
+
+int RemotePluginManager::getNumDevices(int32_t PluginId) {
+ int32_t NumPlugins = getNumUsedPlugins();
+ assert(PluginId < NumPlugins && "Invalid PluginId");
+ if (!Plugins[PluginId]->is_initialized()) {
+ if (auto Err = Plugins[PluginId]->init()) {
+ [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
+ DP("Failed to init plugin: %s\n", InfoMsg.c_str());
+ }
+ }
+ return Plugins[PluginId]->number_of_devices();
+}
diff --git a/offload/plugins-nextgen/mpi/src/RemotePluginManager.h b/offload/plugins-nextgen/mpi/src/RemotePluginManager.h
new file mode 100644
index 00000000000000..fd6466489f981d
--- /dev/null
+++ b/offload/plugins-nextgen/mpi/src/RemotePluginManager.h
@@ -0,0 +1,122 @@
+//===-- RemotePluginManager.h - Remote Plugin Manager -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declarations for managing remote devices that are handled by MPI Plugin.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef REMOTE_PLUGIN_MANAGER_H
+#define REMOTE_PLUGIN_MANAGER_H
+
+#include "PluginInterface.h"
+#include "Shared/APITypes.h"
+#include "Shared/Utils.h"
+
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Support/DynamicLibrary.h"
+#include "llvm/Support/Error.h"
+
+#include <cstdint>
+#include <list>
+#include <memory>
+#include <mutex>
+#include <string>
+
+using llvm::sys::DynamicLibrary;
+
+using GenericPluginTy = llvm::omp::target::plugin::GenericPluginTy;
+using GenericDeviceTy = llvm::omp::target::plugin::GenericDeviceTy;
+
+/// Device Image Storage. This class is used to store Device Image data
+/// in the remote device process.
+struct DeviceImage : __tgt_device_image {
+ llvm::SmallVector<unsigned char, 1> ImageBuffer;
+ llvm::SmallVector<__tgt_offload_entry, 16> Entries;
+ llvm::SmallVector<char> FlattenedEntryNames;
+
+ DeviceImage() {
+ ImageStart = nullptr;
+ ImageEnd = nullptr;
+ EntriesBegin = nullptr;
+ EntriesEnd = nullptr;
+ }
+
+ DeviceImage(size_t ImageSize, size_t EntryCount)
+ : ImageBuffer(ImageSize + alignof(void *)), Entries(EntryCount) {
+ // Align the image buffer to alignof(void *).
+ ImageStart = ImageBuffer.begin();
+ std::align(alignof(void *), ImageSize, ImageStart, ImageSize);
+ ImageEnd = (void *)((size_t)ImageStart + ImageSize);
+ }
+
+ void setImageEntries(llvm::SmallVector<size_t> EntryNameSizes) {
+ // Adjust the entry names to use the flattened name buffer.
+ size_t EntryCount = Entries.size();
+ size_t TotalNameSize = 0;
+ for (size_t I = 0; I < EntryCount; I++) {
+ TotalNameSize += EntryNameSizes[I];
+ }
+ FlattenedEntryNames.resize(TotalNameSize);
+
+ for (size_t I = EntryCount; I > 0; I--) {
+ TotalNameSize -= EntryNameSizes[I - 1];
+ Entries[I - 1].name = &FlattenedEntryNames[TotalNameSize];
+ }
+
+ // Set the entries pointers.
+ EntriesBegin = Entries.begin();
+ EntriesEnd = Entries.end();
+ }
+
+ /// Get the image size.
+ size_t getSize() const { return utils::getPtrDiff(ImageEnd, ImageStart); }
+
+ /// Getter and setter for the dynamic library.
+ DynamicLibrary &getDynamicLibrary() { return DynLib; }
+ void setDynamicLibrary(const DynamicLibrary &Lib) { DynLib = Lib; }
+
+private:
+ DynamicLibrary DynLib;
+};
+
+/// Struct for the data required to handle plugins
+struct RemotePluginManager {
+
+ RemotePluginManager() {}
+
+ void init();
+
+ void deinit();
+
+ /// Initialize as many devices as possible for this plugin. Devices that fail
+ /// to initialize are ignored.
+ void initDevices(GenericPluginTy &RTL);
+
+ /// Return the number of usable devices.
+ int getNumDevices();
+
+ int getNumDevices(int32_t PluginId);
+
+ int getNumUsedPlugins() const { return Plugins.size(); }
+
+ // Initialize all plugins.
+ void initAllPlugins();
+
+ /// Iterator range for all plugins (in use or not, but always valid).
+ auto plugins() { return llvm::make_pointee_range(Plugins); }
+
+ auto getPlugin(int32_t PluginId) { return &Plugins[PluginId]; }
+
+ // List of all plugins, in use or not.
+ llvm::SmallVector<std::unique_ptr<GenericPluginTy>> Plugins;
+};
+
+#endif // REMOTE_PLUGIN_MANAGER_H
diff --git a/offload/plugins-nextgen/mpi/src/RemoteTargets.def.in b/offload/plugins-nextgen/mpi/src/RemoteTargets.def.in
new file mode 100644
index 00000000000000..b6774d225854d1
--- /dev/null
+++ b/offload/plugins-nextgen/mpi/src/RemoteTargets.def.in
@@ -0,0 +1,20 @@
+//===-- Shared/RemoteTargets.def - Target plugin enumerator -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Enumerates over all of the supported target plugins that are available to
+// the offloading library.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PLUGIN_TARGET
+# error Please define the macro PLUGIN_TARGET(TargetName)
+#endif
+
+@REMOTE_MPI_ENUM_PLUGIN_TARGETS@
+
+#undef PLUGIN_TARGET
diff --git a/offload/plugins-nextgen/mpi/src/rtl.cpp b/offload/plugins-nextgen/mpi/src/rtl.cpp
new file mode 100644
index 00000000000000..5c744bbaa1f8a4
--- /dev/null
+++ b/offload/plugins-nextgen/mpi/src/rtl.cpp
@@ -0,0 +1,1309 @@
+//===-- RTLs/mpi/src/rtl.cpp - Target RTLs Implementation -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// RTL NextGen for MPI applications
+//
+//===----------------------------------------------------------------------===//
+
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <list>
+#include <optional>
+#include <string>
+#include <thread>
+#include <tuple>
+
+#include "Shared/APITypes.h"
+#include "Shared/Debug.h"
+#include "Utils/ELF.h"
+
+#include "EventSystem.h"
+#include "GlobalHandler.h"
+#include "OpenMP/OMPT/Callback.h"
+#include "PluginInterface.h"
+#include "omptarget.h"
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/BinaryFormat/ELF.h"
+#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
+#include "llvm/Support/Error.h"
+
+#if !defined(__BYTE_ORDER__) || !defined(__ORDER_LITTLE_ENDIAN__) || \
+ !defined(__ORDER_BIG_ENDIAN__)
+#error "Missing preprocessor definitions for endianness detection."
+#endif
+
+#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
+#define LITTLEENDIAN_CPU
+#elif defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
+#define BIGENDIAN_CPU
+#endif
+
+namespace llvm::omp::target::plugin {
+
+/// Forward declarations for all specialized data structures.
+struct MPIPluginTy;
+struct MPIDeviceTy;
+struct MPIDeviceImageTy;
+struct MPIKernelTy;
+class MPIGlobalHandlerTy;
+
+// TODO: Should this be defined inside the EventSystem?
+using MPIEventQueue = std::list<EventTy>;
+using MPIEventQueuePtr = MPIEventQueue *;
+
+/// Class implementing the MPI device images properties.
+struct MPIDeviceImageTy : public DeviceImageTy {
+ /// Create the MPI image with the id and the target image pointer.
+ MPIDeviceImageTy(int32_t ImageId, GenericDeviceTy &Device,
+ const __tgt_device_image *TgtImage)
+ : DeviceImageTy(ImageId, Device, TgtImage), DeviceImageAddrs(getSize()) {}
+
+ llvm::SmallVector<void *> DeviceImageAddrs;
+};
+
+class MPIGlobalHandlerTy final : public GenericGlobalHandlerTy {
+public:
+ Error getGlobalMetadataFromDevice(GenericDeviceTy &GenericDevice,
+ DeviceImageTy &Image,
+ GlobalTy &DeviceGlobal) override {
+ const char *GlobalName = DeviceGlobal.getName().data();
+ MPIDeviceImageTy &MPIImage = static_cast<MPIDeviceImageTy &>(Image);
+
+ if (GlobalName == nullptr) {
+ return Plugin::error("Failed to get name for global %p", &DeviceGlobal);
+ }
+
+ void *EntryAddress = nullptr;
+
+ __tgt_offload_entry *Begin = MPIImage.getTgtImage()->EntriesBegin;
+ __tgt_offload_entry *End = MPIImage.getTgtImage()->EntriesEnd;
+
+ int I = 0;
+ for (auto &Entry = Begin; Entry < End; ++Entry) {
+ if (!strcmp(Entry->name, GlobalName)) {
+ EntryAddress = MPIImage.DeviceImageAddrs[I];
+ break;
+ }
+ I++;
+ }
+
+ if (EntryAddress == nullptr) {
+ return Plugin::error("Failed to find global %s", GlobalName);
+ }
+
+ // Save the pointer to the symbol.
+ DeviceGlobal.setPtr(EntryAddress);
+
+ return Plugin::success();
+ }
+};
+
+struct MPIKernelTy : public GenericKernelTy {
+ /// Construct the kernel with a name.
+ MPIKernelTy(const char *Name) : GenericKernelTy(Name), Func(nullptr) {}
+
+ /// Initialize the kernel.
+ Error initImpl(GenericDeviceTy &Device, DeviceImageTy &Image) override {
+ // Functions have zero size.
+ GlobalTy Global(getName(), 0);
+
+ // Get the metadata (address) of the kernel function.
+ GenericGlobalHandlerTy &GHandler = Device.Plugin.getGlobalHandler();
+ if (auto Err = GHandler.getGlobalMetadataFromDevice(Device, Image, Global))
+ return Err;
+
+ // Check that the function pointer is valid.
+ if (!Global.getPtr())
+ return Plugin::error("Invalid function for kernel %s", getName());
+
+ // Save the function pointer.
+ Func = (void (*)())Global.getPtr();
+
+ // TODO: Check which settings are appropriate for the MPI plugin. For now,
+ // we use the Elf64 plugin configuration.
+ KernelEnvironment.Configuration.ExecMode = OMP_TGT_EXEC_MODE_GENERIC;
+ KernelEnvironment.Configuration.MayUseNestedParallelism = /* Unknown */ 2;
+ KernelEnvironment.Configuration.UseGenericStateMachine = /* Unknown */ 2;
+
+ // Limit the kernel to a single thread.
+ MaxNumThreads = 1;
+ return Plugin::success();
+ }
+
+ /// Launch the kernel.
+ Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
+ uint64_t NumBlocks, KernelArgsTy &KernelArgs,
+ KernelLaunchParamsTy LaunchParams,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) const override;
+
+private:
+ /// The kernel function to execute.
+ void (*Func)(void);
+};
+
+/// MPI resource reference and queue. These are the objects handled by the
+/// MPIQueue Manager for the MPI plugin.
+template <typename ResourceTy>
+struct MPIResourceRef final : public GenericDeviceResourceRef {
+
+ /// The underlying handler type for the resource.
+ using HandleTy = ResourceTy *;
+
+ /// Create an empty reference to an invalid resource.
+ MPIResourceRef() : Resource(nullptr) {}
+
+ /// Create a reference to an existing resource.
+ MPIResourceRef(HandleTy Queue) : Resource(Queue) {}
+
+ /// Create a new resource and save the reference.
+ Error create(GenericDeviceTy &Device) override {
+ if (Resource)
+ return Plugin::error("Recreating an existing resource");
+
+ Resource = new ResourceTy;
+ if (!Resource)
+ return Plugin::error("Failed to allocate a new resource");
+
+ return Plugin::success();
+ }
+
+ /// Destroy the resource and invalidate the reference.
+ Error destroy(GenericDeviceTy &Device) override {
+ if (!Resource)
+ return Plugin::error("Destroying an invalid resource");
+
+ delete Resource;
+ Resource = nullptr;
+
+ return Plugin::success();
+ }
+
+ operator HandleTy() const { return Resource; }
+
+private:
+ HandleTy Resource;
+};
+
+/// Class implementing the device functionalities for remote x86_64 processes.
+struct MPIDeviceTy : public GenericDeviceTy {
+ /// Create an MPI device with a device id and the default MPI grid values.
+ MPIDeviceTy(GenericPluginTy &Plugin, int32_t DeviceId, int32_t NumDevices)
+ : GenericDeviceTy(Plugin, DeviceId, NumDevices, MPIGridValues),
+ MPIEventQueueManager(*this), MPIEventManager(*this) {}
+
+ /// Initialize the device, its resources and get its properties.
+ Error initImpl(GenericPluginTy &Plugin) override {
+ if (auto Err = MPIEventQueueManager.init(OMPX_InitialNumStreams))
+ return Err;
+
+ if (auto Err = MPIEventManager.init(OMPX_InitialNumEvents))
+ return Err;
+
+ return Plugin::success();
+ }
+
+ /// Deinitialize the device and release its resources.
+ Error deinitImpl() override {
+ if (auto Err = MPIEventQueueManager.deinit())
+ return Err;
+
+ if (auto Err = MPIEventManager.deinit())
+ return Err;
+
+ return Plugin::success();
+ }
+
+ Error setContext() override { return Plugin::success(); }
+
+ /// Load the binary image into the device and allocate an image object.
+ Expected<DeviceImageTy *> loadBinaryImpl(const __tgt_device_image *TgtImage,
+ int32_t ImageId) override {
+ // Allocate and initialize the image object.
+ MPIDeviceImageTy *Image = Plugin.allocate<MPIDeviceImageTy>();
+ new (Image) MPIDeviceImageTy(ImageId, *this, TgtImage);
+ return Image;
+ }
+
+ /// Allocate memory on the device or related to the device.
+ void *allocate(size_t Size, void *, TargetAllocTy Kind) override {
+ return nullptr;
+ }
+
+ /// Deallocate memory on the device or related to the device.
+ int free(void *TgtPtr, TargetAllocTy Kind) override {
+ return OFFLOAD_SUCCESS;
+ }
+
+ /// Submit data to the device (host to device transfer).
+ Error dataSubmitImpl(void *TgtPtr, const void *HstPtr, int64_t Size,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+ return Plugin::success();
+ }
+
+ /// Retrieve data from the device (device to host transfer).
+ Error dataRetrieveImpl(void *HstPtr, const void *TgtPtr, int64_t Size,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+ return Plugin::success();
+ }
+
+ /// Exchange data between two devices directly. In the MPI plugin, this
+ /// function will create an event for the host to tell the devices about the
+ /// exchange. Then, the devices will do the transfer themselves and let the
+ /// host know when it's done.
+ Error dataExchangeImpl(const void *SrcPtr, GenericDeviceTy &DstDev,
+ void *DstPtr, int64_t Size,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+ return Plugin::success();
+ }
+
+ /// Allocate and construct an MPI kernel.
+ Expected<GenericKernelTy &> constructKernel(const char *Name) override {
+ // Allocate and construct the kernel.
+ MPIKernelTy *MPIKernel = Plugin.allocate<MPIKernelTy>();
+
+ if (!MPIKernel)
+ return Plugin::error("Failed to allocate memory for MPI kernel");
+
+ new (MPIKernel) MPIKernelTy(Name);
+
+ return *MPIKernel;
+ }
+
+ /// Create an event.
+ Error createEventImpl(void **EventStoragePtr) override {
+ return Plugin::success();
+ }
+
+ /// Destroy a previously created event.
+ Error destroyEventImpl(void *Event) override {
+ return MPIEventManager.returnResource(reinterpret_cast<EventTy *>(Event));
+ }
+
+ /// Record the event.
+ Error recordEventImpl(void *Event,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+ return Plugin::success();
+ }
+
+ /// Make the queue wait on the event.
+ Error waitEventImpl(void *Event,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+ return Plugin::success();
+ }
+
+ /// Synchronize the current thread with the event.
+ Error syncEventImpl(void *Event) override { return Plugin::success(); }
+
+ /// Synchronize current thread with the pending operations on the async info.
+ Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
+ return Plugin::success();
+ }
+
+ /// Query for the completion of the pending operations on the async info.
+ Error queryAsyncImpl(__tgt_async_info &AsyncInfo) override {
+ return Plugin::success();
+ }
+
+ Expected<void *> dataLockImpl(void *HstPtr, int64_t Size) override {
+ return HstPtr;
+ }
+
+ /// Indicate that the buffer is not pinned.
+ Expected<bool> isPinnedPtrImpl(void *HstPtr, void *&BaseHstPtr,
+ void *&BaseDevAccessiblePtr,
+ size_t &BaseSize) const override {
+ return false;
+ }
+
+ Error dataUnlockImpl(void *HstPtr) override { return Plugin::success(); }
+
+ /// This plugin should not set up the device environment or memory pool.
+ virtual bool shouldSetupDeviceEnvironment() const override { return false; }
+ virtual bool shouldSetupDeviceMemoryPool() const override { return false; }
+
+ /// Device memory limits are currently not applicable to the MPI plugin.
+ Error getDeviceStackSize(uint64_t &Value) override {
+ Value = 0;
+ return Plugin::success();
+ }
+
+ Error setDeviceStackSize(uint64_t Value) override {
+ return Plugin::success();
+ }
+
+ Error getDeviceHeapSize(uint64_t &Value) override {
+ Value = 0;
+ return Plugin::success();
+ }
+
+ Error setDeviceHeapSize(uint64_t Value) override { return Plugin::success(); }
+
+ /// Device interoperability. Not supported by MPI right now.
+ Error initAsyncInfoImpl(AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+ return Plugin::error("initAsyncInfoImpl not supported");
+ }
+
+ /// This plugin does not support interoperability.
+ Error initDeviceInfoImpl(__tgt_device_info *DeviceInfo) override {
+ return Plugin::error("initDeviceInfoImpl not supported");
+ }
+
+ /// Print information about the device.
+ Error obtainInfoImpl(InfoQueueTy &Info) override {
+ // TODO: Add more information about the device.
+ Info.add("MPI plugin");
+ Info.add("MPI OpenMP Device Number", DeviceId);
+
+ return Plugin::success();
+ }
+
+ Error getQueue(AsyncInfoWrapperTy &AsyncInfoWrapper,
+ MPIEventQueuePtr &Queue) {
+ return Plugin::success();
+ }
+
+private:
+ using MPIEventQueueManagerTy =
+ GenericDeviceResourceManagerTy<MPIResourceRef<MPIEventQueue>>;
+ using MPIEventManagerTy =
+ GenericDeviceResourceManagerTy<MPIResourceRef<EventTy>>;
+
+ MPIEventQueueManagerTy MPIEventQueueManager;
+ MPIEventManagerTy MPIEventManager;
+
+ /// Grid values for the MPI plugin.
+ static constexpr GV MPIGridValues = {
+ 1, // GV_Slot_Size
+ 1, // GV_Warp_Size
+ 1, // GV_Max_Teams
+ 1, // GV_Default_Num_Teams
+ 1, // GV_SimpleBufferSize
+ 1, // GV_Max_WG_Size
+ 1, // GV_Default_WG_Size
+ };
+};
+
+Error MPIKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
+ uint32_t NumThreads, uint64_t NumBlocks,
+ KernelArgsTy &KernelArgs,
+ KernelLaunchParamsTy LaunchParams,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) const {
+ return Plugin::success();
+}
+
+/// Class implementing the MPI plugin.
+struct MPIPluginTy : public GenericPluginTy {
+ MPIPluginTy() : GenericPluginTy(getTripleArch()) {}
+
+ /// This class should not be copied.
+ MPIPluginTy(const MPIPluginTy &) = delete;
+ MPIPluginTy(MPIPluginTy &&) = delete;
+
+ /// Initialize the plugin and return the number of devices.
+ Expected<int32_t> initImpl() override {
+ if (!EventSystem.is_initialized())
+ EventSystem.initialize();
+ int32_t NumRemoteDevices = getNumRemoteDevices();
+ assert(RemoteDevices.size() == 0 && "MPI Plugin already initialized");
+ RemoteDevices.resize(NumRemoteDevices, nullptr);
+ return NumRemoteDevices;
+ }
+
+ /// Deinitialize the plugin.
+ Error deinitImpl() override {
+ EventSystem.deinitialize();
+ return Plugin::success();
+ }
+
+ /// Creates an MPI device.
+ GenericDeviceTy *createDevice(GenericPluginTy &Plugin, int32_t DeviceId,
+ int32_t NumDevices) override {
+ return new MPIDeviceTy(Plugin, DeviceId, NumDevices);
+ }
+
+ /// Creates an MPI global handler.
+ GenericGlobalHandlerTy *createGlobalHandler() override {
+ return new MPIGlobalHandlerTy();
+ }
+
+ /// Get the ELF code to recognize the compatible binary images.
+ uint16_t getMagicElfBits() const override {
+ return utils::elf::getTargetMachine();
+ }
+
+ /// All images (ELF-compatible) should be compatible with this plugin.
+ Expected<bool> isELFCompatible(uint32_t DeviceID,
+ StringRef Image) const override {
+ return true;
+ }
+
+ Triple::ArchType getTripleArch() const override {
+#if defined(__x86_64__)
+ return llvm::Triple::x86_64;
+#elif defined(__s390x__)
+ return llvm::Triple::systemz;
+#elif defined(__aarch64__)
+#ifdef LITTLEENDIAN_CPU
+ return llvm::Triple::aarch64;
+#else
+ return llvm::Triple::aarch64_be;
+#endif
+#elif defined(__powerpc64__)
+#ifdef LITTLEENDIAN_CPU
+ return llvm::Triple::ppc64le;
+#else
+ return llvm::Triple::ppc64;
+#endif
+#else
+ return llvm::Triple::UnknownArch;
+#endif
+ }
+
+ Error getQueue(__tgt_async_info *AsyncInfoPtr, MPIEventQueuePtr &Queue) {
+ const std::lock_guard<std::mutex> Lock(MPIQueueMutex);
+ Queue = static_cast<MPIEventQueuePtr>(AsyncInfoPtr->Queue);
+ if (!Queue) {
+ Queue = new MPIEventQueue;
+ if (Queue == nullptr)
+ return Plugin::error("Failed to get Queue from AsyncInfoPtr %p\n",
+ AsyncInfoPtr);
+ // Store the new queue in the async info.
+ AsyncInfoPtr->Queue = Queue;
+ }
+ return Plugin::success();
+ }
+
+ Error returnQueue(MPIEventQueuePtr &Queue) {
+ const std::lock_guard<std::mutex> Lock(MPIQueueMutex);
+ if (Queue == nullptr)
+ return Plugin::error("Failed to return Queue: invalid Queue ptr");
+
+ delete Queue;
+
+ return Plugin::success();
+ }
+
+ const char *getName() const override { return GETNAME(TARGET_NAME); }
+
+ /// Check whether data can be exchanged directly between two devices.
+ bool isDataExchangable(int32_t SrcDeviceId, int32_t DstDeviceId) override {
+ bool QueryResult = false;
+
+ int32_t SrcRank = -1, SrcDevId, DstRank = -1, DstDevId;
+
+ std::tie(SrcRank, SrcDevId) = EventSystem.mapDeviceId(SrcDeviceId);
+ std::tie(DstRank, DstDevId) = EventSystem.mapDeviceId(DstDeviceId);
+
+ // If the exchange is between different MPI processes, the operation can be
+ // performed without consulting the devices.
+ if ((SrcRank != -1) && (DstRank != -1) && (SrcRank != DstRank))
+ return true;
+
+ EventTy Event = EventSystem.createEvent(
+ OriginEvents::isDataExchangable, EventTypeTy::IS_DATA_EXCHANGABLE,
+ SrcDeviceId, DstDeviceId, &QueryResult);
+
+ if (Event.empty()) {
+ DP("Failed to create isDataExchangable event for source device %d\n",
+ SrcDeviceId);
+ return false;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ DP("Failed to query isDataExchangable from source device %d: %s\n",
+ SrcDeviceId, toString(std::move(Error)).c_str());
+ return false;
+ }
+
+ return QueryResult;
+ }
+
+ /// Get the total number of devices across all remote ranks.
+ int32_t getNumRemoteDevices() {
+ int32_t NumRemoteDevices = 0;
+ int32_t NumRanks = EventSystem.getNumWorkers();
+
+ for (int32_t RemoteRank = 0; RemoteRank < NumRanks; RemoteRank++) {
+ auto Event = EventSystem.createEvent(
+ OriginEvents::retrieveNumDevices, EventTypeTy::RETRIEVE_NUM_DEVICES,
+ RemoteRank, &EventSystem.DevicesPerRemote.emplace_back(0));
+
+ if (Event.empty()) {
+ DP("Error retrieving Num Devices from rank %d\n", RemoteRank);
+ return 0;
+ }
+
+ Event.wait();
+ if (auto Err = Event.getError())
+ DP("Error retrieving Num Devices from rank %d: %s\n", RemoteRank,
+ toString(std::move(Err)).c_str());
+
+ NumRemoteDevices += EventSystem.DevicesPerRemote[RemoteRank];
+ }
+
+ return NumRemoteDevices;
+ }
+
+ int32_t is_plugin_compatible(__tgt_device_image *Image) override {
+ if (!EventSystem.is_initialized())
+ EventSystem.initialize();
+
+ int NumRanks = EventSystem.getNumWorkers();
+ llvm::SmallVector<bool> QueryResults{};
+ bool QueryResult = true;
+ for (int RemoteRank = 0; RemoteRank < NumRanks; RemoteRank++) {
+ EventTy Event = EventSystem.createEvent(
+ OriginEvents::isPluginCompatible, EventTypeTy::IS_PLUGIN_COMPATIBLE,
+ RemoteRank, Image, &QueryResults.emplace_back(false));
+
+ if (Event.empty()) {
+ DP("Failed to create isPluginCompatible on Rank %d\n", RemoteRank);
+ QueryResults[RemoteRank] = false;
+ }
+
+ Event.wait();
+ if (auto Err = Event.getError()) {
+ DP("Error querying the binary compatibility on Rank %d\n", RemoteRank);
+ QueryResults[RemoteRank] = false;
+ }
+
+ QueryResult &= QueryResults[RemoteRank];
+ }
+
+ return QueryResult;
+ }
+
+ int32_t is_device_compatible(int32_t DeviceId,
+ __tgt_device_image *Image) override {
+ bool QueryResult = true;
+
+ EventTy Event = EventSystem.createEvent(OriginEvents::isDeviceCompatible,
+ EventTypeTy::IS_DEVICE_COMPATIBLE,
+ DeviceId, Image, &QueryResult);
+
+ if (Event.empty()) {
+ DP("Failed to create isDeviceCompatible on Device %d\n", DeviceId);
+ }
+
+ Event.wait();
+ if (auto Err = Event.getError()) {
+ DP("Error querying the binary compatibility on Device %d\n", DeviceId);
+ }
+
+ return QueryResult;
+ }
+
+ int32_t is_device_initialized(int32_t DeviceId) const override {
+ return isValidDeviceId(DeviceId) && RemoteDevices[DeviceId] != nullptr;
+ }
+
+ int32_t init_device(int32_t DeviceId) override {
+ void *DevicePtr = nullptr;
+
+ EventTy Event =
+ EventSystem.createEvent(OriginEvents::initDevice,
+ EventTypeTy::INIT_DEVICE, DeviceId, &DevicePtr);
+
+ if (Event.empty()) {
+ REPORT("Failed to create initDevice event for device %d\n", DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ REPORT("Failure to initialize device %d: %s\n", DeviceId,
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ RemoteDevices[DeviceId] = DevicePtr;
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t initialize_record_replay(int32_t DeviceId, int64_t MemorySize,
+ void *VAddr, bool isRecord, bool SaveOutput,
+ uint64_t &ReqPtrArgOffset) override {
+ EventTy Event = EventSystem.createEvent(
+ OriginEvents::initRecordReplay, EventTypeTy::INIT_RECORD_REPLAY,
+ DeviceId, MemorySize, VAddr, isRecord, SaveOutput, &ReqPtrArgOffset);
+
+ if (Event.empty()) {
+ REPORT("Failed to create initRecordReplay event for device %d\n",
+ DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ REPORT("WARNING RR did not initialize RR-properly with %lu bytes "
+ "(Error: %s)\n",
+ MemorySize, toString(std::move(Error)).data());
+ if (!isRecord) {
+ return OFFLOAD_FAIL;
+ }
+ }
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t load_binary(int32_t DeviceId, __tgt_device_image *TgtImage,
+ __tgt_device_binary *Binary) override {
+ EventTy Event = EventSystem.createEvent(OriginEvents::loadBinary,
+ EventTypeTy::LOAD_BINARY, DeviceId,
+ TgtImage, Binary);
+
+ if (Event.empty()) {
+ REPORT("Failed to create loadBinary event for image %p\n", TgtImage);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError(); Error) {
+ REPORT("Event failed during loadBinary. %s\n",
+ toString(std::move(Error)).c_str());
+ return OFFLOAD_FAIL;
+ }
+
+ DeviceImgPtrToDeviceId[Binary->handle] = DeviceId;
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ void *data_alloc(int32_t DeviceId, int64_t Size, void *HostPtr,
+ int32_t Kind) override {
+ if (Size == 0)
+ return nullptr;
+
+ void *TgtPtr = nullptr;
+ std::optional<Error> Err = std::nullopt;
+ EventTy Event;
+
+ switch (Kind) {
+ case TARGET_ALLOC_DEFAULT:
+ case TARGET_ALLOC_DEVICE:
+ case TARGET_ALLOC_DEVICE_NON_BLOCKING:
+ Event = EventSystem.createEvent(OriginEvents::allocateBuffer,
+ EventTypeTy::ALLOC, DeviceId, Size, Kind,
+ &TgtPtr);
+
+ if (Event.empty()) {
+ Err = Plugin::error("Failed to create alloc event with size %lld",
+ static_cast<long long>(Size));
+ break;
+ }
+
+ Event.wait();
+ Err = Event.getError();
+ break;
+ case TARGET_ALLOC_HOST:
+ TgtPtr = memAllocHost(Size);
+ Err = Plugin::check(TgtPtr == nullptr, "Failed to allocate host memory");
+ break;
+ case TARGET_ALLOC_SHARED:
+ Err = Plugin::error("Incompatible memory type %d", Kind);
+ break;
+ }
+
+ if (Err && *Err) {
+ REPORT("Failed to allocate data for HostPtr %p: %s\n", HostPtr,
+ toString(std::move(*Err)).c_str());
+ return nullptr;
+ }
+
+ return TgtPtr;
+ }
+
+ int32_t data_delete(int32_t DeviceId, void *TgtPtr, int32_t Kind) override {
+ if (TgtPtr == nullptr)
+ return OFFLOAD_SUCCESS;
+
+ std::optional<Error> Err = std::nullopt;
+ EventTy Event;
+
+ switch (Kind) {
+ case TARGET_ALLOC_DEFAULT:
+ case TARGET_ALLOC_DEVICE:
+ case TARGET_ALLOC_DEVICE_NON_BLOCKING:
+ Event =
+ EventSystem.createEvent(OriginEvents::deleteBuffer,
+ EventTypeTy::DELETE, DeviceId, TgtPtr, Kind);
+
+ if (Event.empty()) {
+ Err = Plugin::error("Failed to create data delete event for %p TgtPtr",
+ TgtPtr);
+ break;
+ }
+
+ Event.wait();
+ Err = Event.getError();
+ break;
+ case TARGET_ALLOC_HOST:
+ Err = Plugin::check(memFreeHost(TgtPtr), "Failed to free host memory");
+ break;
+ case TARGET_ALLOC_SHARED:
+ Err = createStringError(inconvertibleErrorCode(),
+ "Incompatible memory type %d", Kind);
+ break;
+ }
+
+ if (Err && *Err) {
+ REPORT("Failed to delete data at %p TgtPtr: %s\n", TgtPtr,
+ toString(std::move(*Err)).c_str());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t data_lock(int32_t DeviceId, void *Ptr, int64_t Size,
+ void **LockedPtr) override {
+ EventTy Event =
+ EventSystem.createEvent(OriginEvents::dataLock, EventTypeTy::DATA_LOCK,
+ DeviceId, Ptr, Size, LockedPtr);
+
+ if (Event.empty()) {
+ REPORT("Failed to create data lock event on device %d\n", DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ REPORT("Failure to lock memory %p: %s\n", Ptr,
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ if (!(*LockedPtr)) {
+ REPORT("Failure to lock memory %p: obtained a null locked pointer\n",
+ Ptr);
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t data_unlock(int32_t DeviceId, void *Ptr) override {
+ EventTy Event = EventSystem.createEvent(
+ OriginEvents::dataUnlock, EventTypeTy::DATA_UNLOCK, DeviceId, Ptr);
+
+ if (Event.empty()) {
+ REPORT("Failed to create data unlock event on device %d\n", DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ REPORT("Failure to unlock memory %p: %s\n", Ptr,
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t data_notify_mapped(int32_t DeviceId, void *HstPtr,
+ int64_t Size) override {
+ EventTy Event = EventSystem.createEvent(OriginEvents::dataNotifyMapped,
+ EventTypeTy::DATA_NOTIFY_MAPPED,
+ DeviceId, HstPtr, Size);
+
+ if (Event.empty()) {
+ REPORT("Failed to create data notify mapped event on device %d\n",
+ DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ REPORT("Failure to notify data mapped %p: %s\n", HstPtr,
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t data_notify_unmapped(int32_t DeviceId, void *HstPtr) override {
+ EventTy Event = EventSystem.createEvent(OriginEvents::dataNotifyUnmapped,
+ EventTypeTy::DATA_NOTIFY_UNMAPPED,
+ DeviceId, HstPtr);
+
+ if (Event.empty()) {
+ REPORT("Failed to create data notify unmapped event on device %d\n",
+ DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ REPORT("Failure to notify data unmapped %p: %s\n", HstPtr,
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t data_submit_async(int32_t DeviceId, void *TgtPtr, void *HstPtr,
+ int64_t Size,
+ __tgt_async_info *AsyncInfoPtr) override {
+ MPIEventQueuePtr Queue = nullptr;
+ if (auto Error = getQueue(AsyncInfoPtr, Queue)) {
+ REPORT("Failed to get async Queue: %s\n",
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ // Copy the host data to a buffer with event-managed lifetime.
+ void *SubmitBuffer = memAllocHost(Size);
+ std::memcpy(SubmitBuffer, HstPtr, Size);
+ EventDataHandleTy DataHandle(SubmitBuffer, &memFreeHost);
+
+ EventTy Event = EventSystem.createEvent(
+ OriginEvents::submit, EventTypeTy::SUBMIT, DeviceId, TgtPtr, DataHandle,
+ Size, AsyncInfoPtr);
+
+ if (Event.empty()) {
+ REPORT("Failed to create dataSubmit event from %p HstPtr to %p TgtPtr\n",
+ HstPtr, TgtPtr);
+ return OFFLOAD_FAIL;
+ }
+
+ Queue->push_back(Event);
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t data_retrieve_async(int32_t DeviceId, void *HstPtr, void *TgtPtr,
+ int64_t Size,
+ __tgt_async_info *AsyncInfoPtr) override {
+ MPIEventQueuePtr Queue = nullptr;
+ if (auto Error = getQueue(AsyncInfoPtr, Queue)) {
+ REPORT("Failed to get async Queue: %s\n",
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ EventTy Event =
+ EventSystem.createEvent(OriginEvents::retrieve, EventTypeTy::RETRIEVE,
+ DeviceId, Size, HstPtr, TgtPtr, AsyncInfoPtr);
+
+ if (Event.empty()) {
+ REPORT(
+ "Failed to create dataRetrieve event from %p TgtPtr to %p HstPtr\n",
+ TgtPtr, HstPtr);
+ return OFFLOAD_FAIL;
+ }
+
+ Queue->push_back(Event);
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t data_exchange_async(int32_t SrcDeviceId, void *SrcPtr,
+ int DstDeviceId, void *DstPtr, int64_t Size,
+ __tgt_async_info *AsyncInfo) override {
+ MPIEventQueuePtr Queue = nullptr;
+ if (auto Error = getQueue(AsyncInfo, Queue)) {
+ REPORT("Failed to get async Queue: %s\n",
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ int32_t SrcRank, SrcDevId, DstRank, DstDevId;
+ EventTy Event;
+
+ std::tie(SrcRank, SrcDevId) = EventSystem.mapDeviceId(SrcDeviceId);
+ std::tie(DstRank, DstDevId) = EventSystem.mapDeviceId(DstDeviceId);
+
+ if (SrcRank == DstRank) {
+ Event = EventSystem.createEvent(
+ OriginEvents::localExchange, EventTypeTy::LOCAL_EXCHANGE, SrcDeviceId,
+ SrcPtr, DstDeviceId, DstPtr, Size, AsyncInfo);
+ }
+
+ else {
+ Event = EventSystem.createExchangeEvent(SrcDeviceId, SrcPtr, DstDeviceId,
+ DstPtr, Size, AsyncInfo);
+ }
+
+ if (Event.empty()) {
+ REPORT("Failed to create data exchange event from %d SrcDeviceId to %d "
+ "DstDeviceId\n",
+ SrcDeviceId, DstDeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Queue->push_back(Event);
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t launch_kernel(int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs,
+ ptrdiff_t *TgtOffsets, KernelArgsTy *KernelArgs,
+ __tgt_async_info *AsyncInfoPtr) override {
+ MPIEventQueuePtr Queue = nullptr;
+ if (auto Error = getQueue(AsyncInfoPtr, Queue)) {
+ REPORT("Failed to get async Queue: %s\n",
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ uint32_t NumArgs = KernelArgs->NumArgs;
+
+ void *Args = memAllocHost(sizeof(void *) * NumArgs);
+ std::memcpy(Args, TgtArgs, sizeof(void *) * NumArgs);
+ EventDataHandleTy ArgsHandle(Args, &memFreeHost);
+
+ void *Offsets = memAllocHost(sizeof(ptrdiff_t) * NumArgs);
+ std::memcpy(Offsets, TgtOffsets, sizeof(ptrdiff_t) * NumArgs);
+ EventDataHandleTy OffsetsHandle(Offsets, &memFreeHost);
+
+ void *KernelArgsPtr = memAllocHost(sizeof(KernelArgsTy));
+ std::memcpy(KernelArgsPtr, KernelArgs, sizeof(KernelArgsTy));
+ EventDataHandleTy KernelArgsHandle(KernelArgsPtr, &memFreeHost);
+
+ EventTy Event = EventSystem.createEvent(
+ OriginEvents::launchKernel, EventTypeTy::LAUNCH_KERNEL, DeviceId,
+ TgtEntryPtr, ArgsHandle, OffsetsHandle, KernelArgsHandle, AsyncInfoPtr);
+
+ if (Event.empty()) {
+ REPORT("Failed to create launchKernel event on device %d\n", DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Queue->push_back(Event);
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t synchronize(int32_t DeviceId,
+ __tgt_async_info *AsyncInfoPtr) override {
+ MPIEventQueuePtr Queue =
+ reinterpret_cast<MPIEventQueuePtr>(AsyncInfoPtr->Queue);
+
+ EventTy Event = EventSystem.createEvent(OriginEvents::synchronize,
+ EventTypeTy::SYNCHRONIZE, DeviceId,
+ AsyncInfoPtr);
+
+ if (Event.empty()) {
+ REPORT("Failed to create synchronize event on device %d\n", DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Queue->push_back(Event);
+
+ for (auto &Event : *Queue) {
+ Event.wait();
+
+ if (auto Error = Event.getError(); Error) {
+ REPORT("Event failed during synchronization. %s\n",
+ toString(std::move(Error)).c_str());
+ return OFFLOAD_FAIL;
+ }
+ }
+
+ // Once the queue is synchronized, return it to the pool and reset the
+ // AsyncInfo. This is to make sure that the synchronization only works
+ // for its own tasks.
+ AsyncInfoPtr->Queue = nullptr;
+ if (auto Error = returnQueue(Queue)) {
+ REPORT("Failed to return async Queue: %s\n",
+ toString(std::move(Error)).c_str());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t query_async(int32_t DeviceId,
+ __tgt_async_info *AsyncInfoPtr) override {
+ auto *Queue = reinterpret_cast<MPIEventQueue *>(AsyncInfoPtr->Queue);
+
+ // Non-blocking query: resume the event at the front of the queue and return
+ // success immediately if it is still pending. Completed events are popped
+ // until the queue is empty.
+ while (!Queue->empty()) {
+ auto &Event = Queue->front();
+
+ Event.resume();
+
+ if (!Event.done())
+ return OFFLOAD_SUCCESS;
+
+ if (auto Error = Event.getError(); Error) {
+ REPORT("Event failed during query. %s\n",
+ toString(std::move(Error)).c_str());
+ return OFFLOAD_FAIL;
+ }
+ Queue->pop_front();
+ }
+
+ // Once the queue is synchronized, return it to the pool and reset the
+ // AsyncInfo. This is to make sure that the synchronization only works
+ // for its own tasks.
+ AsyncInfoPtr->Queue = nullptr;
+ if (auto Error = returnQueue(Queue)) {
+ REPORT("Failed to return async Queue: %s\n",
+ toString(std::move(Error)).c_str());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ void print_device_info(int32_t DeviceId) override {
+ EventTy Event =
+ EventSystem.createEvent(OriginEvents::printDeviceInfo,
+ EventTypeTy::PRINT_DEVICE_INFO, DeviceId);
+
+ if (Event.empty()) {
+ REPORT("Failed to create printDeviceInfo event on device %d\n", DeviceId);
+ return;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ REPORT("Failure to print device %d info: %s\n", DeviceId,
+ toString(std::move(Error)).data());
+ }
+ }
+
+ int32_t create_event(int32_t DeviceId, void **EventPtr) override {
+ if (!EventPtr) {
+ REPORT("Failure to create event: Received invalid event pointer\n");
+ return OFFLOAD_FAIL;
+ }
+
+ EventTy *NewEvent = new EventTy;
+
+ if (NewEvent == nullptr) {
+ REPORT("Failed to createEvent\n");
+ return OFFLOAD_FAIL;
+ }
+
+ *EventPtr = reinterpret_cast<void *>(NewEvent);
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t record_event(int32_t DeviceId, void *EventPtr,
+ __tgt_async_info *AsyncInfoPtr) override {
+ if (!EventPtr) {
+ REPORT("Failure to record event: Received invalid event pointer\n");
+ return OFFLOAD_FAIL;
+ }
+
+ MPIEventQueuePtr Queue = nullptr;
+ if (auto Error = getQueue(AsyncInfoPtr, Queue)) {
+ REPORT("Failed to get async Queue: %s\n",
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ if (Queue->empty())
+ return OFFLOAD_SUCCESS;
+
+ auto &RecordedEvent = *reinterpret_cast<EventTy *>(EventPtr);
+ RecordedEvent = Queue->back();
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t wait_event(int32_t DeviceId, void *EventPtr,
+ __tgt_async_info *AsyncInfoPtr) override {
+ if (!EventPtr) {
+ REPORT("Failure to wait for event: Received invalid event pointer\n");
+ return OFFLOAD_FAIL;
+ }
+
+ auto &RecordedEvent = *reinterpret_cast<EventTy *>(EventPtr);
+ auto SyncEvent = OriginEvents::sync(RecordedEvent);
+
+ MPIEventQueuePtr Queue = nullptr;
+ if (auto Error = getQueue(AsyncInfoPtr, Queue)) {
+ REPORT("Failed to get async Queue: %s\n",
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ Queue->push_back(SyncEvent);
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t sync_event(int32_t DeviceId, void *EventPtr) override {
+ if (!EventPtr) {
+ REPORT("Failure to sync event: Received invalid event pointer\n");
+ return OFFLOAD_FAIL;
+ }
+
+ auto &RecordedEvent = *reinterpret_cast<EventTy *>(EventPtr);
+ auto SyncEvent = OriginEvents::sync(RecordedEvent);
+
+ SyncEvent.wait();
+
+ if (auto Err = SyncEvent.getError()) {
+ REPORT("Failure to synchronize event %p: %s\n", EventPtr,
+ toString(std::move(Err)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t destroy_event(int32_t DeviceId, void *EventPtr) override {
+
+ if (!EventPtr) {
+ REPORT("Failure to destroy event: Received invalid event pointer\n");
+ return OFFLOAD_FAIL;
+ }
+
+ EventTy *MPIEventPtr = reinterpret_cast<EventTy *>(EventPtr);
+
+ delete MPIEventPtr;
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t init_async_info(int32_t DeviceId,
+ __tgt_async_info **AsyncInfoPtr) override {
+ assert(AsyncInfoPtr && "Invalid async info");
+
+ EventTy Event = EventSystem.createEvent(OriginEvents::initAsyncInfo,
+ EventTypeTy::INIT_ASYNC_INFO,
+ DeviceId, AsyncInfoPtr);
+
+ if (Event.empty()) {
+ REPORT("Failed to create initAsyncInfo on device %d\n", DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Err = Event.getError()) {
+ REPORT("Failure to initialize async info at " DPxMOD
+ " on device %d: %s\n",
+ DPxPTR(*AsyncInfoPtr), DeviceId, toString(std::move(Err)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t init_device_info(int32_t DeviceId, __tgt_device_info *DeviceInfo,
+ const char **ErrStr) override {
+ *ErrStr = "";
+
+ EventTy Event = EventSystem.createEvent(OriginEvents::initDeviceInfo,
+ EventTypeTy::INIT_DEVICE_INFO,
+ DeviceId, DeviceInfo);
+
+ if (Event.empty()) {
+ REPORT("Failed to create initDeviceInfo on device %d\n", DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Err = Event.getError()) {
+ REPORT("Failure to initialize device info at " DPxMOD
+ " on device %d: %s\n",
+ DPxPTR(DeviceInfo), DeviceId, toString(std::move(Err)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t use_auto_zero_copy(int32_t DeviceId) override { return false; }
+
+ int32_t get_global(__tgt_device_binary Binary, uint64_t Size,
+ const char *Name, void **DevicePtr) override {
+ int32_t DeviceId = DeviceImgPtrToDeviceId[Binary.handle];
+
+ EventTy Event = EventSystem.createEvent(OriginEvents::getGlobal,
+ EventTypeTy::GET_GLOBAL, DeviceId,
+ Binary, Size, Name, DevicePtr);
+ if (Event.empty()) {
+ REPORT("Failed to create getGlobal event on device %d\n", DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ REPORT("Failed to get global on device %d: %s\n", DeviceId,
+ toString(std::move(Error)).c_str());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t get_function(__tgt_device_binary Binary, const char *Name,
+ void **KernelPtr) override {
+
+ int32_t DeviceId = DeviceImgPtrToDeviceId[Binary.handle];
+
+ EventTy Event = EventSystem.createEvent(OriginEvents::getFunction,
+ EventTypeTy::GET_FUNCTION, DeviceId,
+ Binary, Name, KernelPtr);
+ if (Event.empty()) {
+ REPORT("Failed to create getFunction event on device %d\n", DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ REPORT("Failed to get function on device %d: %s\n", DeviceId,
+ toString(std::move(Error)).c_str());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+private:
+ std::mutex MPIQueueMutex;
+ llvm::DenseMap<uintptr_t, int32_t> DeviceImgPtrToDeviceId;
+ llvm::SmallVector<void *> RemoteDevices;
+ EventSystemTy EventSystem;
+};
+
+template <typename... ArgsTy>
+static Error Plugin::check(int32_t ErrorCode, const char *ErrFmt,
+ ArgsTy... Args) {
+ if (ErrorCode == OFFLOAD_SUCCESS)
+ return Error::success();
+
+ return createStringError<ArgsTy..., const char *>(
+ inconvertibleErrorCode(), ErrFmt, Args...,
+ std::to_string(ErrorCode).data());
+}
+
+} // namespace llvm::omp::target::plugin
+
+extern "C" {
+llvm::omp::target::plugin::GenericPluginTy *createPlugin_mpi() {
+ return new llvm::omp::target::plugin::MPIPluginTy();
+}
+}
\ No newline at end of file
diff --git a/offload/src/PluginManager.cpp b/offload/src/PluginManager.cpp
index 315b953f9b31ac..f3a4c153b2580a 100644
--- a/offload/src/PluginManager.cpp
+++ b/offload/src/PluginManager.cpp
@@ -12,6 +12,7 @@
#include "PluginManager.h"
#include "Shared/Debug.h"
+#include "Shared/EnvironmentVar.h"
#include "Shared/Profile.h"
#include "device.h"
@@ -65,6 +66,12 @@ bool PluginManager::initializePlugin(GenericPluginTy &Plugin) {
if (Plugin.is_initialized())
return true;
+ // Skip the host plugin when OMPTARGET_DISABLE_HOST_PLUGIN is set.
+ IntEnvar DisableHostPlugin("OMPTARGET_DISABLE_HOST_PLUGIN", 0);
+ if (DisableHostPlugin.get() && !strcmp(Plugin.getName(), "x86_64")) {
+ return false;
+ }
+
if (auto Err = Plugin.init()) {
[[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
DP("Failed to init plugin: %s\n", InfoMsg.c_str());
diff --git a/offload/test/api/omp_device_managed_memory.c b/offload/test/api/omp_device_managed_memory.c
index 2a9fe09a8334c9..63812c74595e13 100644
--- a/offload/test/api/omp_device_managed_memory.c
+++ b/offload/test/api/omp_device_managed_memory.c
@@ -1,5 +1,7 @@
// RUN: %libomptarget-compile-run-and-check-generic
+// UNSUPPORTED: nvptx64-nvidia-cuda-mpi
+
#include <omp.h>
#include <stdio.h>
diff --git a/offload/test/api/omp_device_managed_memory_alloc.c b/offload/test/api/omp_device_managed_memory_alloc.c
index c48866922debaf..07f3fbbf26e895 100644
--- a/offload/test/api/omp_device_managed_memory_alloc.c
+++ b/offload/test/api/omp_device_managed_memory_alloc.c
@@ -1,6 +1,8 @@
// RUN: %libomptarget-compile-run-and-check-generic
// RUN: %libomptarget-compileopt-run-and-check-generic
+// UNSUPPORTED: nvptx64-nvidia-cuda-mpi
+
#include <omp.h>
#include <stdio.h>
diff --git a/offload/test/libc/host_call.c b/offload/test/libc/host_call.c
index 61c4e14d5b3881..17941e96a21d08 100644
--- a/offload/test/libc/host_call.c
+++ b/offload/test/libc/host_call.c
@@ -2,6 +2,8 @@
// REQUIRES: libc
+// UNSUPPORTED: nvptx64-nvidia-cuda-mpi
+
#include <assert.h>
#include <omp.h>
#include <stdio.h>
diff --git a/offload/test/lit.cfg b/offload/test/lit.cfg
index 2f1ef3e98d8172..0f3f121466bf1b 100644
--- a/offload/test/lit.cfg
+++ b/offload/test/lit.cfg
@@ -140,6 +140,8 @@ elif config.libomptarget_current_target.startswith('amdgcn'):
(config.amdgpu_test_arch.startswith("gfx942") and
evaluate_bool_env(config.environment['IS_APU']))):
supports_apu = True
+if config.libomptarget_current_target.endswith('-mpi'):
+ supports_unified_shared_memory = False
if supports_unified_shared_memory:
config.available_features.add('unified_shared_memory')
if supports_apu:
@@ -156,7 +158,7 @@ elif config.operating_system == 'Darwin':
config.test_flags += " -Wl,-rpath," + config.library_dir
config.test_flags += " -Wl,-rpath," + config.omp_host_rtl_directory
else: # Unices
- if config.libomptarget_current_target != "nvptx64-nvidia-cuda":
+ if config.libomptarget_current_target not in ("nvptx64-nvidia-cuda", "nvptx64-nvidia-cuda-mpi"):
config.test_flags += " -nogpulib"
config.test_flags += " -Wl,-rpath," + config.library_dir
config.test_flags += " -Wl,-rpath," + config.omp_host_rtl_directory
@@ -178,6 +180,8 @@ def remove_suffix_if_present(name):
return name[:-4]
elif name.endswith('-JIT-LTO'):
return name[:-8]
+ elif name.endswith('-mpi'):
+ return name[:-4]
else:
return name
@@ -321,7 +325,7 @@ for libomptarget_target in config.libomptarget_all_targets:
"%clang-" + libomptarget_target + add_libraries(" -O3 %s -o %t")))
config.substitutions.append(("%libomptarget-run-" + \
libomptarget_target, \
- "%t"))
+ "%pre_bin %t"))
config.substitutions.append(("%libomptarget-run-fail-" + \
libomptarget_target, \
"%not --crash %t"))
@@ -414,6 +418,10 @@ else:
config.substitutions.append(("%cuda_flags", ""))
config.substitutions.append(("%flags_clang", config.test_flags_clang))
config.substitutions.append(("%flags_flang", config.test_flags_flang))
+if config.libomptarget_current_target.endswith('-mpi'):
+ config.substitutions.append(("%pre_bin", "mpirun -np 1 llvm-offload-mpi-proxy-device : -np 1"))
+else:
+ config.substitutions.append(("%pre_bin", ""))
config.substitutions.append(("%flags", config.test_flags))
config.substitutions.append(("%not", config.libomptarget_not))
config.substitutions.append(("%offload-device-info",
diff --git a/offload/test/mapping/target_derefence_array_pointrs.cpp b/offload/test/mapping/target_derefence_array_pointrs.cpp
index a6dd4069a8f588..9ac6218816b608 100644
--- a/offload/test/mapping/target_derefence_array_pointrs.cpp
+++ b/offload/test/mapping/target_derefence_array_pointrs.cpp
@@ -6,6 +6,7 @@
// UNSUPPORTED: amdgcn-amd-amdhsa
// UNSUPPORTED: nvptx64-nvidia-cuda
// UNSUPPORTED: nvptx64-nvidia-cuda-LTO
+// UNSUPPORTED: nvptx64-nvidia-cuda-mpi
#include <stdio.h>
#include <stdlib.h>
diff --git a/offload/test/mapping/target_has_device_addr.c b/offload/test/mapping/target_has_device_addr.c
index e8bfff868c7ed7..3421b6ddda7604 100644
--- a/offload/test/mapping/target_has_device_addr.c
+++ b/offload/test/mapping/target_has_device_addr.c
@@ -3,6 +3,7 @@
// RUN: | %fcheck-generic
// UNSUPPORTED: amdgcn-amd-amdhsa
+// UNSUPPORTED: nvptx64-nvidia-cuda-mpi
#include <omp.h>
#include <stdio.h>
diff --git a/offload/test/mapping/target_uses_allocator.c b/offload/test/mapping/target_uses_allocator.c
index eb20e965c30bc9..b37eacc30b65c7 100755
--- a/offload/test/mapping/target_uses_allocator.c
+++ b/offload/test/mapping/target_uses_allocator.c
@@ -4,6 +4,7 @@
// UNSUPPORTED: amdgcn-amd-amdhsa
// UNSUPPORTED: nvptx64-nvidia-cuda
// UNSUPPORTED: nvptx64-nvidia-cuda-LTO
+// UNSUPPORTED: nvptx64-nvidia-cuda-mpi
#include <omp.h>
#include <stdio.h>
diff --git a/offload/test/offloading/bug49334.cpp b/offload/test/offloading/bug49334.cpp
index 4ca01d4bbd9717..f3436d559e4710 100644
--- a/offload/test/offloading/bug49334.cpp
+++ b/offload/test/offloading/bug49334.cpp
@@ -87,7 +87,7 @@ int BlockMatMul_TargetNowait(BlockMatrix &A, BlockMatrix &B, BlockMatrix &C) {
for (int k = 0; k < N / BS; ++k) {
float *BlockA = A.GetBlock(i, k);
float *BlockB = B.GetBlock(k, j);
-// clang-format off
+ // clang-format off
#pragma omp target depend(in: BlockA[0], BlockB[0]) depend(inout: BlockC[0]) \
map(to: BlockA[:BS * BS], BlockB[:BS * BS]) \
map(tofrom: BlockC[:BS * BS]) nowait
diff --git a/offload/test/offloading/bug64959.c b/offload/test/offloading/bug64959.c
index eddc55325ffe90..b057a889061c7e 100644
--- a/offload/test/offloading/bug64959.c
+++ b/offload/test/offloading/bug64959.c
@@ -6,6 +6,7 @@
// UNSUPPORTED: amdgcn-amd-amdhsa
// UNSUPPORTED: nvptx64-nvidia-cuda
// UNSUPPORTED: nvptx64-nvidia-cuda-LTO
+// UNSUPPORTED: nvptx64-nvidia-cuda-mpi
#include <omp.h>
#include <stdio.h>
diff --git a/offload/test/offloading/struct_mapping_with_pointers.cpp b/offload/test/offloading/struct_mapping_with_pointers.cpp
index f0fde50889dace..d04e5f1f656f62 100644
--- a/offload/test/offloading/struct_mapping_with_pointers.cpp
+++ b/offload/test/offloading/struct_mapping_with_pointers.cpp
@@ -6,6 +6,7 @@
// UNSUPPORTED: nvptx64-nvidia-cuda
// UNSUPPORTED: nvptx64-nvidia-cuda-LTO
+// UNSUPPORTED: nvptx64-nvidia-cuda-mpi
#include <stdio.h>
#include <stdlib.h>
diff --git a/offload/test/offloading/target_critical_region.cpp b/offload/test/offloading/target_critical_region.cpp
index 0b97823e266b29..216ace1e73a88b 100644
--- a/offload/test/offloading/target_critical_region.cpp
+++ b/offload/test/offloading/target_critical_region.cpp
@@ -3,6 +3,7 @@
// REQUIRES: gpu
// UNSUPPORTED: nvptx64-nvidia-cuda
// UNSUPPORTED: nvptx64-nvidia-cuda-LTO
+// UNSUPPORTED: nvptx64-nvidia-cuda-mpi
// UNSUPPORTED: amdgcn-amd-amdhsa
#include <omp.h>
diff --git a/offload/test/offloading/thread_limit.c b/offload/test/offloading/thread_limit.c
index 72fa0b218a3c5b..ea813cd9e88b85 100644
--- a/offload/test/offloading/thread_limit.c
+++ b/offload/test/offloading/thread_limit.c
@@ -5,6 +5,7 @@
// UNSUPPORTED: nvptx64-nvidia-cuda
// UNSUPPORTED: nvptx64-nvidia-cuda-LTO
+// UNSUPPORTED: nvptx64-nvidia-cuda-mpi
// REQUIRES: gpu
int main() {
diff --git a/offload/test/sanitizer/kernel_crash.c b/offload/test/sanitizer/kernel_crash.c
index 1406af47c7ba92..af3bb88984f84f 100644
--- a/offload/test/sanitizer/kernel_crash.c
+++ b/offload/test/sanitizer/kernel_crash.c
@@ -9,6 +9,7 @@
// UNSUPPORTED: nvptx64-nvidia-cuda
// UNSUPPORTED: nvptx64-nvidia-cuda-LTO
+// UNSUPPORTED: nvptx64-nvidia-cuda-mpi
// UNSUPPORTED: aarch64-unknown-linux-gnu
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: x86_64-unknown-linux-gnu
diff --git a/offload/test/sanitizer/kernel_crash_async.c b/offload/test/sanitizer/kernel_crash_async.c
index ee22ba504018be..e3a27445a59dd1 100644
--- a/offload/test/sanitizer/kernel_crash_async.c
+++ b/offload/test/sanitizer/kernel_crash_async.c
@@ -9,6 +9,7 @@
// UNSUPPORTED: nvptx64-nvidia-cuda
// UNSUPPORTED: nvptx64-nvidia-cuda-LTO
+// UNSUPPORTED: nvptx64-nvidia-cuda-mpi
// UNSUPPORTED: aarch64-unknown-linux-gnu
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: x86_64-unknown-linux-gnu
diff --git a/offload/test/sanitizer/kernel_crash_many.c b/offload/test/sanitizer/kernel_crash_many.c
index f1d17ca2b76e23..49b464d763ec3c 100644
--- a/offload/test/sanitizer/kernel_crash_many.c
+++ b/offload/test/sanitizer/kernel_crash_many.c
@@ -7,6 +7,7 @@
// UNSUPPORTED: nvptx64-nvidia-cuda
// UNSUPPORTED: nvptx64-nvidia-cuda-LTO
+// UNSUPPORTED: nvptx64-nvidia-cuda-mpi
// UNSUPPORTED: aarch64-unknown-linux-gnu
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: x86_64-unknown-linux-gnu
diff --git a/offload/test/sanitizer/kernel_crash_single.c b/offload/test/sanitizer/kernel_crash_single.c
index 8baa0a850bf150..ac2ca877c0e53e 100644
--- a/offload/test/sanitizer/kernel_crash_single.c
+++ b/offload/test/sanitizer/kernel_crash_single.c
@@ -9,6 +9,7 @@
// UNSUPPORTED: nvptx64-nvidia-cuda
// UNSUPPORTED: nvptx64-nvidia-cuda-LTO
+// UNSUPPORTED: nvptx64-nvidia-cuda-mpi
// UNSUPPORTED: aarch64-unknown-linux-gnu
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: x86_64-unknown-linux-gnu
diff --git a/offload/test/sanitizer/kernel_trap.c b/offload/test/sanitizer/kernel_trap.c
index 91c4c7229159bc..e7a27ee5ed1a22 100644
--- a/offload/test/sanitizer/kernel_trap.c
+++ b/offload/test/sanitizer/kernel_trap.c
@@ -10,6 +10,7 @@
// UNSUPPORTED: nvptx64-nvidia-cuda
// UNSUPPORTED: nvptx64-nvidia-cuda-LTO
+// UNSUPPORTED: nvptx64-nvidia-cuda-mpi
// UNSUPPORTED: aarch64-unknown-linux-gnu
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: x86_64-unknown-linux-gnu
diff --git a/offload/test/sanitizer/kernel_trap.cpp b/offload/test/sanitizer/kernel_trap.cpp
index c67b3857fabba1..155bebea3b2478 100644
--- a/offload/test/sanitizer/kernel_trap.cpp
+++ b/offload/test/sanitizer/kernel_trap.cpp
@@ -10,6 +10,7 @@
// UNSUPPORTED: nvptx64-nvidia-cuda
// UNSUPPORTED: nvptx64-nvidia-cuda-LTO
+// UNSUPPORTED: nvptx64-nvidia-cuda-mpi
// UNSUPPORTED: aarch64-unknown-linux-gnu
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: x86_64-unknown-linux-gnu
@@ -22,18 +23,13 @@ struct S {};
template <typename T> void cxx_function_name(int I, T *) {
#pragma omp target
- {
- }
+ {}
#pragma omp target
- {
- }
+ {}
#pragma omp target
- {
- __builtin_trap();
- }
+ { __builtin_trap(); }
#pragma omp target
- {
- }
+ {}
}
int main(void) {
diff --git a/offload/test/sanitizer/kernel_trap_async.c b/offload/test/sanitizer/kernel_trap_async.c
index 391ff0c7dcaa4e..7693d8271e817d 100644
--- a/offload/test/sanitizer/kernel_trap_async.c
+++ b/offload/test/sanitizer/kernel_trap_async.c
@@ -10,6 +10,7 @@
// UNSUPPORTED: nvptx64-nvidia-cuda
// UNSUPPORTED: nvptx64-nvidia-cuda-LTO
+// UNSUPPORTED: nvptx64-nvidia-cuda-mpi
// UNSUPPORTED: aarch64-unknown-linux-gnu
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: x86_64-unknown-linux-gnu
diff --git a/offload/test/sanitizer/kernel_trap_many.c b/offload/test/sanitizer/kernel_trap_many.c
index f2e63794168b2b..c389fe52b4ec17 100644
--- a/offload/test/sanitizer/kernel_trap_many.c
+++ b/offload/test/sanitizer/kernel_trap_many.c
@@ -7,6 +7,7 @@
// UNSUPPORTED: nvptx64-nvidia-cuda
// UNSUPPORTED: nvptx64-nvidia-cuda-LTO
+// UNSUPPORTED: nvptx64-nvidia-cuda-mpi
// UNSUPPORTED: aarch64-unknown-linux-gnu
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: x86_64-unknown-linux-gnu
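
For readers following the `query_async` hunk above, the non-blocking queue polling it performs can be sketched in isolation. This is only an illustration under stated assumptions: `FakeEvent` and `pollQueue` are hypothetical names that do not appear in the patch, `FakeEvent` merely imitates the `resume()`/`done()` surface of the patch's `EventTy`, and the sketch returns a `bool` to make "still pending" visible, whereas the patch returns `OFFLOAD_SUCCESS` in both the pending and the drained case.

```cpp
#include <deque>

// Hypothetical stand-in for the patch's EventTy (assumption, not the real
// type): every resume() call advances the event by one step.
struct FakeEvent {
  int StepsLeft;
  void resume() {
    if (StepsLeft > 0)
      --StepsLeft;
  }
  bool done() const { return StepsLeft == 0; }
};

// Resume the front event once per iteration. Stop as soon as an event is
// still pending; pop events that have completed.
// Returns true when the queue has been fully drained, false otherwise.
bool pollQueue(std::deque<FakeEvent> &Queue) {
  while (!Queue.empty()) {
    FakeEvent &Event = Queue.front();
    Event.resume();
    if (!Event.done())
      return false; // Work is still in flight; the caller may poll again.
    Queue.pop_front();
  }
  return true;
}
```

A caller would invoke `pollQueue` repeatedly between other work; each call makes progress on at most the events that are ready, which is the property that keeps the query non-blocking.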