[llvm] [Offload] Add MPI Plugin (PR #90890)
Jhonatan Cléto via llvm-commits
llvm-commits at lists.llvm.org
Thu May 9 10:40:44 PDT 2024
https://github.com/cl3to updated https://github.com/llvm/llvm-project/pull/90890
>From ef4f22e919cecac49098776129abdde120ebc1ec Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jhonatan=20Cl=C3=A9to?= <j256444 at dac.unicamp.br>
Date: Thu, 2 May 2024 14:38:22 -0300
Subject: [PATCH 1/3] [Offload] Add MPI Plugin
Co-authored-by: Guilherme Valarini <guilherme.a.valarini at gmail.com>
---
offload/CMakeLists.txt | 3 +-
offload/plugins-nextgen/mpi/CMakeLists.txt | 127 ++
.../plugins-nextgen/mpi/src/EventSystem.cpp | 1071 +++++++++++++++++
offload/plugins-nextgen/mpi/src/EventSystem.h | 493 ++++++++
.../plugins-nextgen/mpi/src/MPIDeviceMain.cpp | 11 +
offload/plugins-nextgen/mpi/src/rtl.cpp | 686 +++++++++++
offload/test/api/omp_device_managed_memory.c | 2 +
.../api/omp_device_managed_memory_alloc.c | 2 +
offload/test/api/omp_dynamic_shared_memory.c | 1 +
offload/test/api/omp_indirect_call.c | 2 +
offload/test/jit/empty_kernel_lvl1.c | 1 +
offload/test/jit/empty_kernel_lvl2.c | 1 +
offload/test/jit/type_punning.c | 1 +
offload/test/lit.cfg | 10 +-
.../target_derefence_array_pointrs.cpp | 1 +
offload/test/offloading/barrier_fence.c | 1 +
offload/test/offloading/bug49334.cpp | 1 +
.../test/offloading/default_thread_limit.c | 1 +
offload/test/offloading/ompx_bare.c | 1 +
offload/test/offloading/ompx_coords.c | 1 +
offload/test/offloading/ompx_saxpy_mixed.c | 1 +
offload/test/offloading/small_trip_count.c | 1 +
.../small_trip_count_thread_limit.cpp | 1 +
offload/test/offloading/spmdization.c | 1 +
.../offloading/target_critical_region.cpp | 1 +
offload/test/offloading/thread_limit.c | 1 +
offload/test/offloading/workshare_chunk.c | 1 +
27 files changed, 2422 insertions(+), 2 deletions(-)
create mode 100644 offload/plugins-nextgen/mpi/CMakeLists.txt
create mode 100644 offload/plugins-nextgen/mpi/src/EventSystem.cpp
create mode 100644 offload/plugins-nextgen/mpi/src/EventSystem.h
create mode 100644 offload/plugins-nextgen/mpi/src/MPIDeviceMain.cpp
create mode 100644 offload/plugins-nextgen/mpi/src/rtl.cpp
diff --git a/offload/CMakeLists.txt b/offload/CMakeLists.txt
index 3f77583ffa3b8..f6d1bbdda5e9f 100644
--- a/offload/CMakeLists.txt
+++ b/offload/CMakeLists.txt
@@ -151,7 +151,7 @@ if (NOT LIBOMPTARGET_LLVM_INCLUDE_DIRS)
message(FATAL_ERROR "Missing definition for LIBOMPTARGET_LLVM_INCLUDE_DIRS")
endif()
-set(LIBOMPTARGET_ALL_PLUGIN_TARGETS amdgpu cuda host)
+set(LIBOMPTARGET_ALL_PLUGIN_TARGETS amdgpu cuda mpi host)
set(LIBOMPTARGET_PLUGINS_TO_BUILD "all" CACHE STRING
"Semicolon-separated list of plugins to use: cuda, amdgpu, host or \"all\".")
@@ -182,6 +182,7 @@ set (LIBOMPTARGET_ALL_TARGETS "${LIBOMPTARGET_ALL_TARGETS} powerpc64-ibm-linux-g
set (LIBOMPTARGET_ALL_TARGETS "${LIBOMPTARGET_ALL_TARGETS} powerpc64-ibm-linux-gnu-LTO")
set (LIBOMPTARGET_ALL_TARGETS "${LIBOMPTARGET_ALL_TARGETS} x86_64-pc-linux-gnu")
set (LIBOMPTARGET_ALL_TARGETS "${LIBOMPTARGET_ALL_TARGETS} x86_64-pc-linux-gnu-LTO")
+set (LIBOMPTARGET_ALL_TARGETS "${LIBOMPTARGET_ALL_TARGETS} x86_64-pc-linux-gnu-mpi")
set (LIBOMPTARGET_ALL_TARGETS "${LIBOMPTARGET_ALL_TARGETS} nvptx64-nvidia-cuda")
set (LIBOMPTARGET_ALL_TARGETS "${LIBOMPTARGET_ALL_TARGETS} nvptx64-nvidia-cuda-LTO")
set (LIBOMPTARGET_ALL_TARGETS "${LIBOMPTARGET_ALL_TARGETS} nvptx64-nvidia-cuda-JIT-LTO")
diff --git a/offload/plugins-nextgen/mpi/CMakeLists.txt b/offload/plugins-nextgen/mpi/CMakeLists.txt
new file mode 100644
index 0000000000000..c3a8c9a83b45f
--- /dev/null
+++ b/offload/plugins-nextgen/mpi/CMakeLists.txt
@@ -0,0 +1,127 @@
+##===----------------------------------------------------------------------===##
+#
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#
+##===----------------------------------------------------------------------===##
+#
+# Build a plugin for an MPI machine if available.
+#
+##===----------------------------------------------------------------------===##
+
+# Looking for MPI...
+find_package(MPI QUIET)
+
+set(LIBOMPTARGET_DEP_MPI_FOUND ${MPI_CXX_FOUND})
+set(LIBOMPTARGET_DEP_MPI_LIBRARIES ${MPI_CXX_LIBRARIES})
+set(LIBOMPTARGET_DEP_MPI_INCLUDE_DIRS ${MPI_CXX_INCLUDE_DIRS})
+set(LIBOMPTARGET_DEP_MPI_COMPILE_FLAGS ${MPI_CXX_COMPILE_FLAGS})
+set(LIBOMPTARGET_DEP_MPI_LINK_FLAGS ${MPI_CXX_LINK_FLAGS})
+
+mark_as_advanced(
+ LIBOMPTARGET_DEP_MPI_FOUND
+ LIBOMPTARGET_DEP_MPI_LIBRARIES
+ LIBOMPTARGET_DEP_MPI_INCLUDE_DIRS
+ LIBOMPTARGET_DEP_MPI_COMPILE_FLAGS
+ LIBOMPTARGET_DEP_MPI_LINK_FLAGS)
+
+if(NOT(CMAKE_SYSTEM_PROCESSOR MATCHES "(x86_64)|(ppc64le)$" AND CMAKE_SYSTEM_NAME MATCHES "Linux"))
+ libomptarget_say("Not building MPI offloading plugin: only support MPI in Linux x86_64 or ppc64le hosts.")
+ return()
+elseif(NOT LIBOMPTARGET_DEP_LIBFFI_FOUND)
+ libomptarget_say("Not building MPI offloading plugin: libffi dependency not found.")
+ return()
+elseif(NOT LIBOMPTARGET_DEP_MPI_FOUND)
+ libomptarget_say("Not building MPI offloading plugin: MPI not found in system.")
+ return()
+endif()
+
+libomptarget_say("Building MPI NextGen offloading plugin.")
+
+# Create the library and add the default arguments.
+add_target_library(omptarget.rtl.mpi MPI)
+
+target_sources(omptarget.rtl.mpi PRIVATE
+ src/EventSystem.cpp
+ src/rtl.cpp
+)
+
+if(FFI_STATIC_LIBRARIES)
+ target_link_libraries(omptarget.rtl.mpi PRIVATE FFI::ffi_static)
+else()
+ target_link_libraries(omptarget.rtl.mpi PRIVATE FFI::ffi)
+endif()
+
+target_link_libraries(omptarget.rtl.mpi PRIVATE
+ ${LIBOMPTARGET_DEP_MPI_LIBRARIES}
+ ${LIBOMPTARGET_DEP_MPI_LINK_FLAGS}
+)
+
+# Add include directories
+target_include_directories(omptarget.rtl.mpi PRIVATE
+ ${LIBOMPTARGET_INCLUDE_DIR})
+
+# Install plugin under the lib destination folder.
+install(TARGETS omptarget.rtl.mpi
+ LIBRARY DESTINATION "${OFFLOAD_INSTALL_LIBDIR}")
+set_target_properties(omptarget.rtl.mpi PROPERTIES
+ INSTALL_RPATH "$ORIGIN" BUILD_RPATH "$ORIGIN:${CMAKE_CURRENT_BINARY_DIR}/..")
+
+if(LIBOMPTARGET_DEP_MPI_COMPILE_FLAGS)
+ set_target_properties(omptarget.rtl.mpi PROPERTIES
+ COMPILE_FLAGS "${LIBOMPTARGET_DEP_MPI_COMPILE_FLAGS}")
+endif()
+
+# Set C++20 as the target standard for this plugin.
+set_target_properties(omptarget.rtl.mpi
+ PROPERTIES
+ CXX_STANDARD 20
+ CXX_STANDARD_REQUIRED ON)
+
+# Configure testing for the MPI plugin.
+list(APPEND LIBOMPTARGET_TESTED_PLUGINS "omptarget.rtl.mpi")
+# Report to the parent scope that we are building a plugin for MPI.
+set(LIBOMPTARGET_TESTED_PLUGINS "${LIBOMPTARGET_TESTED_PLUGINS}" PARENT_SCOPE)
+
+# Define the target specific triples and ELF machine values.
+set(LIBOMPTARGET_SYSTEM_TARGETS
+ "${LIBOMPTARGET_SYSTEM_TARGETS} x86_64-pc-linux-gnu-mpi" PARENT_SCOPE)
+
+# MPI Device Binary
+llvm_add_tool(OPENMP llvm-offload-mpi-device src/EventSystem.cpp src/MPIDeviceMain.cpp)
+
+llvm_update_compile_flags(llvm-offload-mpi-device)
+
+target_link_libraries(llvm-offload-mpi-device PRIVATE
+ ${LIBOMPTARGET_DEP_MPI_LIBRARIES}
+ ${LIBOMPTARGET_DEP_MPI_LINK_FLAGS}
+ LLVMSupport
+ omp
+)
+
+if(FFI_STATIC_LIBRARIES)
+ target_link_libraries(llvm-offload-mpi-device PRIVATE FFI::ffi_static)
+else()
+ target_link_libraries(llvm-offload-mpi-device PRIVATE FFI::ffi)
+endif()
+
+target_include_directories(llvm-offload-mpi-device PRIVATE
+ ${LIBOMPTARGET_INCLUDE_DIR}
+ ${LIBOMPTARGET_DEP_MPI_INCLUDE_DIRS}
+)
+
+if(LIBOMPTARGET_DEP_MPI_COMPILE_FLAGS)
+ set_target_properties(llvm-offload-mpi-device PROPERTIES
+ COMPILE_FLAGS "${LIBOMPTARGET_DEP_MPI_COMPILE_FLAGS}"
+ )
+endif()
+
+set_target_properties(llvm-offload-mpi-device
+ PROPERTIES
+ CXX_STANDARD 20
+ CXX_STANDARD_REQUIRED ON
+)
+
+target_compile_definitions(llvm-offload-mpi-device PRIVATE
+ DEBUG_PREFIX="OFFLOAD MPI DEVICE")
diff --git a/offload/plugins-nextgen/mpi/src/EventSystem.cpp b/offload/plugins-nextgen/mpi/src/EventSystem.cpp
new file mode 100644
index 0000000000000..742d99f9424c0
--- /dev/null
+++ b/offload/plugins-nextgen/mpi/src/EventSystem.cpp
@@ -0,0 +1,1071 @@
+//===------ EventSystem.cpp - Concurrent MPI communication ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the implementation of the MPI Event System used by the MPI
+// target runtime for concurrent communication.
+//
+//===----------------------------------------------------------------------===//
+
+#include "EventSystem.h"
+
+#include <algorithm>
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <functional>
+#include <memory>
+
+#include <ffi.h>
+#include <mpi.h>
+#include <unistd.h>
+
+#include "Shared/Debug.h"
+#include "Shared/EnvironmentVar.h"
+#include "Shared/Utils.h"
+#include "omptarget.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Error.h"
+
+#include "llvm/Support/DynamicLibrary.h"
+
+using llvm::sys::DynamicLibrary;
+
+#define CHECK(expr, msg, ...) \
+ if (!(expr)) { \
+ REPORT(msg, ##__VA_ARGS__); \
+ return false; \
+ }
+
+/// Customizable parameters of the event system
+///
+/// Number of execute event handlers to spawn.
+static IntEnvar NumExecEventHandlers("OMPTARGET_NUM_EXEC_EVENT_HANDLERS", 1);
+/// Number of data event handlers to spawn.
+static IntEnvar NumDataEventHandlers("OMPTARGET_NUM_DATA_EVENT_HANDLERS", 1);
+/// Polling rate period (us) used by event handlers.
+static IntEnvar EventPollingRate("OMPTARGET_EVENT_POLLING_RATE", 1);
+/// Number of communicators to be spawned and distributed for the events.
+/// Allows for parallel use of network resources.
+static Int64Envar NumMPIComms("OMPTARGET_NUM_MPI_COMMS", 10);
+/// Maximum fragment size (in bytes) to use during data transfers.
+static Int64Envar MPIFragmentSize("OMPTARGET_MPI_FRAGMENT_SIZE", 100e6);
+
+/// Helper function to transform event type to string
+const char *toString(EventTypeTy Type) {
+ using enum EventTypeTy;
+
+ switch (Type) {
+ case ALLOC:
+ return "Alloc";
+ case DELETE:
+ return "Delete";
+ case RETRIEVE:
+ return "Retrieve";
+ case SUBMIT:
+ return "Submit";
+ case EXCHANGE:
+ return "Exchange";
+ case EXCHANGE_SRC:
+ return "exchangeSrc";
+ case EXCHANGE_DST:
+ return "ExchangeDst";
+ case EXECUTE:
+ return "Execute";
+ case SYNC:
+ return "Sync";
+ case LOAD_BINARY:
+ return "LoadBinary";
+ case EXIT:
+ return "Exit";
+ }
+
+ assert(false && "Every enum value must be checked on the switch above.");
+ return nullptr;
+}
+
+/// Resumes the most recent incomplete coroutine in the list.
+void EventTy::resume() {
+ // Acquire first handle not done.
+ const CoHandleTy &RootHandle = getHandle().promise().RootHandle;
+ auto &ResumableHandle = RootHandle.promise().PrevHandle;
+ while (ResumableHandle.done()) {
+ ResumableHandle = ResumableHandle.promise().PrevHandle;
+
+ if (ResumableHandle == RootHandle)
+ break;
+ }
+
+ if (!ResumableHandle.done())
+ ResumableHandle.resume();
+}
+
+/// Wait until event completes.
+void EventTy::wait() {
+ // Advance the event progress until it is completed.
+ while (!done()) {
+ resume();
+
+ std::this_thread::sleep_for(
+ std::chrono::microseconds(EventPollingRate.get()));
+ }
+}
+
+/// Check if the event has completed.
+bool EventTy::done() const { return getHandle().done(); }
+
+/// Check if it is an empty event.
+bool EventTy::empty() const { return !getHandle(); }
+
+/// Get the coroutine error from the Handle.
+llvm::Error EventTy::getError() const {
+ auto &Error = getHandle().promise().CoroutineError;
+ if (Error)
+ return std::move(*Error);
+
+ return llvm::Error::success();
+}
+
+/// MPI Request Manager Destructor. The Manager cannot be destroyed until all
+/// the requests it manages have been completed.
+MPIRequestManagerTy::~MPIRequestManagerTy() {
+ assert(Requests.empty() && "Requests must be fulfilled and emptied before "
+ "destruction. Did you co_await on it?");
+}
+
+/// Send a message to \p OtherRank asynchronously.
+void MPIRequestManagerTy::send(const void *Buffer, int Size,
+ MPI_Datatype Datatype) {
+ MPI_Isend(Buffer, Size, Datatype, OtherRank, Tag, Comm,
+ &Requests.emplace_back(MPI_REQUEST_NULL));
+}
+
+/// Divide the \p Buffer into fragments of size \p MPIFragmentSize and send them
+/// to \p OtherRank asynchronously.
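+/// For example, with the default fragment size of 100e6 bytes, a 250e6-byte
+/// buffer is sent as three fragments of 100e6, 100e6, and 50e6 bytes.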
+void MPIRequestManagerTy::sendInBatchs(void *Buffer, int Size) {
+  // Operates over many fragments of the original buffer, each of at most
+  // MPIFragmentSize bytes.
+ char *BufferByteArray = reinterpret_cast<char *>(Buffer);
+ int64_t RemainingBytes = Size;
+ while (RemainingBytes > 0) {
+ send(&BufferByteArray[Size - RemainingBytes],
+ static_cast<int>(std::min(RemainingBytes, MPIFragmentSize.get())),
+ MPI_BYTE);
+ RemainingBytes -= MPIFragmentSize.get();
+ }
+}
+
+/// Receive a message from \p OtherRank asynchronously.
+void MPIRequestManagerTy::receive(void *Buffer, int Size,
+ MPI_Datatype Datatype) {
+ MPI_Irecv(Buffer, Size, Datatype, OtherRank, Tag, Comm,
+ &Requests.emplace_back(MPI_REQUEST_NULL));
+}
+
+/// Asynchronously receive message fragments from \p OtherRank and reconstruct
+/// them into \p Buffer.
+void MPIRequestManagerTy::receiveInBatchs(void *Buffer, int Size) {
+  // Operates over many fragments of the original buffer, each of at most
+  // MPIFragmentSize bytes.
+ char *BufferByteArray = reinterpret_cast<char *>(Buffer);
+ int64_t RemainingBytes = Size;
+ while (RemainingBytes > 0) {
+ receive(&BufferByteArray[Size - RemainingBytes],
+ static_cast<int>(std::min(RemainingBytes, MPIFragmentSize.get())),
+ MPI_BYTE);
+ RemainingBytes -= MPIFragmentSize.get();
+ }
+}
+
+/// Coroutine that waits until all pending requests finish.
+EventTy MPIRequestManagerTy::wait() {
+ int RequestsCompleted = false;
+
+ while (!RequestsCompleted) {
+ int MPIError = MPI_Testall(Requests.size(), Requests.data(),
+ &RequestsCompleted, MPI_STATUSES_IGNORE);
+
+ if (MPIError != MPI_SUCCESS)
+ co_return createError("Waiting of MPI requests failed with code %d",
+ MPIError);
+
+ co_await std::suspend_always{};
+ }
+
+ Requests.clear();
+
+ co_return llvm::Error::success();
+}
+
+EventTy operator co_await(MPIRequestManagerTy &RequestManager) {
+ return RequestManager.wait();
+}
+
+void *memAllocHost(int64_t Size) {
+  void *HstPtr = nullptr;
+  int MPIError = MPI_Alloc_mem(Size, MPI_INFO_NULL, &HstPtr);
+  if (MPIError != MPI_SUCCESS)
+    return nullptr;
+  return HstPtr;
+}
+
+int memFreeHost(void *HstPtr) {
+ int MPIError = MPI_Free_mem(HstPtr);
+ if (MPIError != MPI_SUCCESS)
+ return OFFLOAD_FAIL;
+ return OFFLOAD_SUCCESS;
+}
+
+/// Device Image Storage. This class is used to store Device Image data
+/// in the remote device process.
+struct DeviceImage : __tgt_device_image {
+ llvm::SmallVector<unsigned char, 1> ImageBuffer;
+ llvm::SmallVector<__tgt_offload_entry, 16> Entries;
+ llvm::SmallVector<char> FlattenedEntryNames;
+
+ DeviceImage() {
+ ImageStart = nullptr;
+ ImageEnd = nullptr;
+ EntriesBegin = nullptr;
+ EntriesEnd = nullptr;
+ }
+
+ DeviceImage(size_t ImageSize, size_t EntryCount)
+ : ImageBuffer(ImageSize + alignof(void *)), Entries(EntryCount) {
+ // Align the image buffer to alignof(void *).
+ ImageStart = ImageBuffer.begin();
+ std::align(alignof(void *), ImageSize, ImageStart, ImageSize);
+ ImageEnd = (void *)((size_t)ImageStart + ImageSize);
+ }
+
+ void setImageEntries(llvm::SmallVector<size_t> EntryNameSizes) {
+ // Adjust the entry names to use the flattened name buffer.
+ size_t EntryCount = Entries.size();
+ size_t TotalNameSize = 0;
+ for (size_t I = 0; I < EntryCount; I++) {
+ TotalNameSize += EntryNameSizes[I];
+ }
+ FlattenedEntryNames.resize(TotalNameSize);
+
+ for (size_t I = EntryCount; I > 0; I--) {
+ TotalNameSize -= EntryNameSizes[I - 1];
+ Entries[I - 1].name = &FlattenedEntryNames[TotalNameSize];
+ }
+
+ // Set the entries pointers.
+ EntriesBegin = Entries.begin();
+ EntriesEnd = Entries.end();
+ }
+
+ /// Get the image size.
+ size_t getSize() const {
+ return llvm::omp::target::getPtrDiff(ImageEnd, ImageStart);
+ }
+
+ /// Getter and setter for the dynamic library.
+ DynamicLibrary &getDynamicLibrary() { return DynLib; }
+ void setDynamicLibrary(const DynamicLibrary &Lib) { DynLib = Lib; }
+
+private:
+ DynamicLibrary DynLib;
+};
+
+/// Event implementations on Host side.
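+///
+/// Most coroutines below mirror a counterpart in the DestinationEvents
+/// namespace: the origin sends the event arguments (pointers and sizes),
+/// optionally streams the payload in fragments, and finally awaits a
+/// completion notification from the destination (a zero-byte message, or the
+/// resulting pointer in the allocation case).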
+namespace OriginEvents {
+
+EventTy allocateBuffer(MPIRequestManagerTy RequestManager, int64_t Size,
+ void **Buffer) {
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+
+ // Event completion notification
+ RequestManager.receive(Buffer, sizeof(void *), MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy deleteBuffer(MPIRequestManagerTy RequestManager, void *Buffer) {
+ RequestManager.send(&Buffer, sizeof(void *), MPI_BYTE);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy submit(MPIRequestManagerTy RequestManager, EventDataHandleTy DataHandle,
+ void *DstBuffer, int64_t Size) {
+ RequestManager.send(&DstBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+
+ RequestManager.sendInBatchs(DataHandle.get(), Size);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy retrieve(MPIRequestManagerTy RequestManager, void *OrgBuffer,
+ const void *DstBuffer, int64_t Size) {
+ RequestManager.send(&DstBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+ RequestManager.receiveInBatchs(OrgBuffer, Size);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy exchange(MPIRequestManagerTy RequestManager, int SrcDevice,
+ const void *OrgBuffer, int DstDevice, void *DstBuffer,
+ int64_t Size) {
+ // Send data to SrcDevice
+ RequestManager.send(&OrgBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+ RequestManager.send(&DstDevice, 1, MPI_INT);
+
+ // Send data to DstDevice
+ RequestManager.OtherRank = DstDevice;
+ RequestManager.send(&DstBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+ RequestManager.send(&SrcDevice, 1, MPI_INT);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+ RequestManager.OtherRank = SrcDevice;
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy execute(MPIRequestManagerTy RequestManager, EventDataHandleTy Args,
+ uint32_t NumArgs, void *Func) {
+ RequestManager.send(&NumArgs, 1, MPI_UINT32_T);
+ RequestManager.send(Args.get(), NumArgs * sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Func, sizeof(void *), MPI_BYTE);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
+}
+
+EventTy sync(EventTy Event) {
+ while (!Event.done())
+ co_await std::suspend_always{};
+
+ co_return llvm::Error::success();
+}
+
+EventTy loadBinary(MPIRequestManagerTy RequestManager,
+ const __tgt_device_image *Image,
+ llvm::SmallVector<void *> *DeviceImageAddrs) {
+ auto &[ImageStart, ImageEnd, EntriesBegin, EntriesEnd] = *Image;
+
+ // Send the target table sizes.
+ size_t ImageSize = (size_t)ImageEnd - (size_t)ImageStart;
+ size_t EntryCount = EntriesEnd - EntriesBegin;
+ llvm::SmallVector<size_t> EntryNameSizes(EntryCount);
+
+ for (size_t I = 0; I < EntryCount; I++) {
+ // Note: +1 for the terminator.
+ EntryNameSizes[I] = std::strlen(EntriesBegin[I].name) + 1;
+ }
+
+ RequestManager.send(&ImageSize, 1, MPI_UINT64_T);
+ RequestManager.send(&EntryCount, 1, MPI_UINT64_T);
+ RequestManager.send(EntryNameSizes.begin(), EntryCount, MPI_UINT64_T);
+
+ // Send the image bytes and the table entries.
+ RequestManager.send(ImageStart, ImageSize, MPI_BYTE);
+
+ for (size_t I = 0; I < EntryCount; I++) {
+ RequestManager.send(&EntriesBegin[I].addr, 1, MPI_UINT64_T);
+ RequestManager.send(EntriesBegin[I].name, EntryNameSizes[I], MPI_CHAR);
+ RequestManager.send(&EntriesBegin[I].size, 1, MPI_UINT64_T);
+ RequestManager.send(&EntriesBegin[I].flags, 1, MPI_INT32_T);
+ RequestManager.send(&EntriesBegin[I].data, 1, MPI_INT32_T);
+ RequestManager.receive(&((*DeviceImageAddrs)[I]), 1, MPI_UINT64_T);
+ }
+
+ co_return (co_await RequestManager);
+}
+
+EventTy exit(MPIRequestManagerTy RequestManager) {
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
+}
+
+} // namespace OriginEvents
+
+/// Event Implementations on Device side.
+namespace DestinationEvents {
+
+EventTy allocateBuffer(MPIRequestManagerTy RequestManager) {
+ int64_t Size = 0;
+ RequestManager.receive(&Size, 1, MPI_INT64_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ void *Buffer = malloc(Size);
+ RequestManager.send(&Buffer, sizeof(void *), MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy deleteBuffer(MPIRequestManagerTy RequestManager) {
+ void *Buffer = nullptr;
+ RequestManager.receive(&Buffer, sizeof(void *), MPI_BYTE);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ free(Buffer);
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy submit(MPIRequestManagerTy RequestManager) {
+ void *Buffer = nullptr;
+ int64_t Size = 0;
+ RequestManager.receive(&Buffer, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&Size, 1, MPI_INT64_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ RequestManager.receiveInBatchs(Buffer, Size);
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy retrieve(MPIRequestManagerTy RequestManager) {
+ void *Buffer = nullptr;
+ int64_t Size = 0;
+ RequestManager.receive(&Buffer, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&Size, 1, MPI_INT64_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ RequestManager.sendInBatchs(Buffer, Size);
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy exchangeSrc(MPIRequestManagerTy RequestManager) {
+ void *SrcBuffer;
+ int64_t Size;
+ int DstDevice;
+ // Save head node rank
+ int HeadNodeRank = RequestManager.OtherRank;
+
+ RequestManager.receive(&SrcBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&Size, 1, MPI_INT64_T);
+ RequestManager.receive(&DstDevice, 1, MPI_INT);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ // Set the Destination Rank in RequestManager
+ RequestManager.OtherRank = DstDevice;
+
+ // Send buffer to target device
+ RequestManager.sendInBatchs(SrcBuffer, Size);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+  // Set the head node rank to send the final notification
+ RequestManager.OtherRank = HeadNodeRank;
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy exchangeDst(MPIRequestManagerTy RequestManager) {
+ void *DstBuffer;
+ int64_t Size;
+ int SrcDevice;
+ // Save head node rank
+ int HeadNodeRank = RequestManager.OtherRank;
+
+ RequestManager.receive(&DstBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&Size, 1, MPI_INT64_T);
+ RequestManager.receive(&SrcDevice, 1, MPI_INT);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ // Set the Source Rank in RequestManager
+ RequestManager.OtherRank = SrcDevice;
+
+ // Receive buffer from the Source device
+ RequestManager.receiveInBatchs(DstBuffer, Size);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+  // Set the head node rank to send the final notification
+ RequestManager.OtherRank = HeadNodeRank;
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy execute(MPIRequestManagerTy RequestManager,
+ __tgt_device_image &DeviceImage) {
+
+ uint32_t NumArgs = 0;
+ RequestManager.receive(&NumArgs, 1, MPI_UINT32_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ llvm::SmallVector<void *> Args(NumArgs);
+ llvm::SmallVector<void *> ArgPtrs(NumArgs);
+
+ RequestManager.receive(Args.data(), NumArgs * sizeof(uintptr_t), MPI_BYTE);
+ void (*TargetFunc)(void) = nullptr;
+ RequestManager.receive(&TargetFunc, sizeof(uintptr_t), MPI_BYTE);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+  // Collect the address of each argument, as required by ffi_call.
+ for (unsigned I = 0; I < NumArgs; I++) {
+ ArgPtrs[I] = &Args[I];
+ }
+
+ ffi_cif Cif{};
+ llvm::SmallVector<ffi_type *> ArgsTypes(NumArgs, &ffi_type_pointer);
+ ffi_status Status = ffi_prep_cif(&Cif, FFI_DEFAULT_ABI, NumArgs,
+ &ffi_type_void, &ArgsTypes[0]);
+
+ if (Status != FFI_OK) {
+ co_return createError("Error in ffi_prep_cif: %d", Status);
+ }
+
+ long Return;
+ ffi_call(&Cif, TargetFunc, &Return, &ArgPtrs[0]);
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy loadBinary(MPIRequestManagerTy RequestManager, DeviceImage &Image) {
+ // Receive the target table sizes.
+ size_t ImageSize = 0;
+ size_t EntryCount = 0;
+ RequestManager.receive(&ImageSize, 1, MPI_UINT64_T);
+ RequestManager.receive(&EntryCount, 1, MPI_UINT64_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ llvm::SmallVector<size_t> EntryNameSizes(EntryCount);
+
+ RequestManager.receive(EntryNameSizes.begin(), EntryCount, MPI_UINT64_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+  // Create the device image with the appropriate sizes and receive its
+  // contents.
+ Image = DeviceImage(ImageSize, EntryCount);
+ Image.setImageEntries(EntryNameSizes);
+
+  // Receive the image bytes and the table entries.
+ RequestManager.receive(Image.ImageStart, ImageSize, MPI_BYTE);
+
+ for (size_t I = 0; I < EntryCount; I++) {
+ RequestManager.receive(&Image.Entries[I].addr, 1, MPI_UINT64_T);
+ RequestManager.receive(Image.Entries[I].name, EntryNameSizes[I], MPI_CHAR);
+ RequestManager.receive(&Image.Entries[I].size, 1, MPI_UINT64_T);
+ RequestManager.receive(&Image.Entries[I].flags, 1, MPI_INT32_T);
+ RequestManager.receive(&Image.Entries[I].data, 1, MPI_INT32_T);
+ }
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+  // The code below applies to the host (CPU) device only.
+ // Create a temporary file.
+ char TmpFileName[] = "/tmp/tmpfile_XXXXXX";
+ int TmpFileFd = mkstemp(TmpFileName);
+ if (TmpFileFd == -1)
+ co_return createError("Failed to create tmpfile for loading target image");
+
+ // Open the temporary file.
+ FILE *TmpFile = fdopen(TmpFileFd, "wb");
+ if (!TmpFile)
+ co_return createError("Failed to open tmpfile %s for loading target image",
+ TmpFileName);
+
+ // Write the image into the temporary file.
+ size_t Written = fwrite(Image.ImageStart, Image.getSize(), 1, TmpFile);
+ if (Written != 1)
+ co_return createError("Failed to write target image to tmpfile %s",
+ TmpFileName);
+
+ // Close the temporary file.
+ int Ret = fclose(TmpFile);
+ if (Ret)
+ co_return createError("Failed to close tmpfile %s with the target image",
+ TmpFileName);
+
+ // Load the temporary file as a dynamic library.
+ std::string ErrMsg;
+ DynamicLibrary DynLib =
+ DynamicLibrary::getPermanentLibrary(TmpFileName, &ErrMsg);
+
+ // Check if the loaded library is valid.
+ if (!DynLib.isValid())
+ co_return createError("Failed to load target image: %s", ErrMsg.c_str());
+
+ // Save a reference of the image's dynamic library.
+ Image.setDynamicLibrary(DynLib);
+
+ // Delete TmpFile
+ unlink(TmpFileName);
+
+ for (size_t I = 0; I < EntryCount; I++) {
+ Image.Entries[I].addr = DynLib.getAddressOfSymbol(Image.Entries[I].name);
+ RequestManager.send(&Image.Entries[I].addr, 1, MPI_UINT64_T);
+ }
+
+ co_return (co_await RequestManager);
+}
+
+EventTy exit(MPIRequestManagerTy RequestManager,
+ std::atomic<EventSystemStateTy> &EventSystemState) {
+ EventSystemStateTy OldState =
+ EventSystemState.exchange(EventSystemStateTy::EXITED);
+ assert(OldState != EventSystemStateTy::EXITED &&
+ "Exit event received multiple times");
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+} // namespace DestinationEvents
+
+/// Event Queue implementation
+EventQueue::EventQueue() : Queue(), QueueMtx(), CanPopCv() {}
+
+size_t EventQueue::size() {
+ std::lock_guard<std::mutex> lock(QueueMtx);
+ return Queue.size();
+}
+
+void EventQueue::push(EventTy &&Event) {
+ {
+ std::unique_lock<std::mutex> lock(QueueMtx);
+ Queue.emplace(Event);
+ }
+
+ // Notifies a thread possibly blocked by an empty queue.
+ CanPopCv.notify_one();
+}
+
+EventTy EventQueue::pop(std::stop_token &Stop) {
+ std::unique_lock<std::mutex> Lock(QueueMtx);
+
+ // Waits for at least one item to be pushed.
+ const bool HasNewEvent =
+ CanPopCv.wait(Lock, Stop, [&] { return !Queue.empty(); });
+
+ if (!HasNewEvent) {
+ assert(Stop.stop_requested() && "Queue was empty while running.");
+ return EventTy();
+ }
+
+ EventTy TargetEvent = std::move(Queue.front());
+ Queue.pop();
+ return TargetEvent;
+}
+
+/// Event System implementation
+EventSystemTy::EventSystemTy()
+ : EventSystemState(EventSystemStateTy::CREATED) {}
+
+EventSystemTy::~EventSystemTy() {
+ if (EventSystemState == EventSystemStateTy::FINALIZED)
+ return;
+
+ REPORT("Destructing internal event system before deinitializing it.\n");
+ deinitialize();
+}
+
+bool EventSystemTy::initialize() {
+ if (EventSystemState >= EventSystemStateTy::INITIALIZED) {
+ REPORT("Trying to initialize event system twice.\n");
+ return false;
+ }
+
+ if (!createLocalMPIContext())
+ return false;
+
+ EventSystemState = EventSystemStateTy::INITIALIZED;
+
+ return true;
+}
+
+bool EventSystemTy::deinitialize() {
+ if (EventSystemState == EventSystemStateTy::FINALIZED) {
+ REPORT("Trying to deinitialize event system twice.\n");
+ return false;
+ }
+
+ if (EventSystemState == EventSystemStateTy::RUNNING) {
+ REPORT("Trying to deinitialize event system while it is running.\n");
+ return false;
+ }
+
+ // Only send exit events from the host side
+ if (isHost() && WorldSize > 1) {
+ const int NumWorkers = WorldSize - 1;
+ llvm::SmallVector<EventTy> ExitEvents(NumWorkers);
+ for (int WorkerRank = 0; WorkerRank < NumWorkers; WorkerRank++) {
+ ExitEvents[WorkerRank] = createEvent(OriginEvents::exit, WorkerRank);
+ ExitEvents[WorkerRank].resume();
+ }
+
+ bool SuccessfullyExited = true;
+ for (int WorkerRank = 0; WorkerRank < NumWorkers; WorkerRank++) {
+ ExitEvents[WorkerRank].wait();
+ SuccessfullyExited &= ExitEvents[WorkerRank].done();
+ auto Error = ExitEvents[WorkerRank].getError();
+ if (Error)
+ REPORT("Exit event failed with msg: %s\n",
+ toString(std::move(Error)).data());
+ }
+
+ if (!SuccessfullyExited) {
+ REPORT("Failed to stop worker processes.\n");
+ return false;
+ }
+ }
+
+ if (!destroyLocalMPIContext())
+ return false;
+
+ EventSystemState = EventSystemStateTy::FINALIZED;
+
+ return true;
+}
+
+EventTy EventSystemTy::createExchangeEvent(int SrcDevice, const void *SrcBuffer,
+ int DstDevice, void *DstBuffer,
+ int64_t Size) {
+ const int EventTag = createNewEventTag();
+ auto &EventComm = getNewEventComm(EventTag);
+
+ int SrcEventNotificationInfo[] = {static_cast<int>(EventTypeTy::EXCHANGE_SRC),
+ EventTag};
+
+ int DstEventNotificationInfo[] = {static_cast<int>(EventTypeTy::EXCHANGE_DST),
+ EventTag};
+
+ MPI_Request SrcRequest = MPI_REQUEST_NULL;
+ MPI_Request DstRequest = MPI_REQUEST_NULL;
+
+ int MPIError = MPI_Isend(SrcEventNotificationInfo, 2, MPI_INT, SrcDevice,
+ static_cast<int>(ControlTagsTy::EVENT_REQUEST),
+ GateThreadComm, &SrcRequest);
+
+ MPIError &= MPI_Isend(DstEventNotificationInfo, 2, MPI_INT, DstDevice,
+ static_cast<int>(ControlTagsTy::EVENT_REQUEST),
+ GateThreadComm, &DstRequest);
+
+ if (MPIError != MPI_SUCCESS)
+ co_return createError(
+ "MPI failed during exchange event notification with error %d",
+ MPIError);
+
+ MPIRequestManagerTy RequestManager(EventComm, EventTag, SrcDevice,
+ {SrcRequest, DstRequest});
+
+ co_return (co_await OriginEvents::exchange(std::move(RequestManager),
+ SrcDevice, SrcBuffer, DstDevice,
+ DstBuffer, Size));
+}
+
+void EventSystemTy::runEventHandler(std::stop_token Stop, EventQueue &Queue) {
+ while (EventSystemState == EventSystemStateTy::RUNNING || Queue.size() > 0) {
+ EventTy Event = Queue.pop(Stop);
+
+ // Re-checks the stop condition when no event was found.
+ if (Event.empty()) {
+ continue;
+ }
+
+ Event.resume();
+
+ if (!Event.done()) {
+ Queue.push(std::move(Event));
+ }
+
+ auto Error = Event.getError();
+ if (Error)
+ REPORT("Internal event failed with msg: %s\n",
+ toString(std::move(Error)).data());
+ }
+}
+
+void EventSystemTy::runGateThread() {
+ // Device image to be used by this gate thread.
+ DeviceImage Image;
+
+  // Update the event system state to RUNNING before spawning the handlers.
+ EventSystemState = EventSystemStateTy::RUNNING;
+
+ // Spawns the event handlers.
+ llvm::SmallVector<std::jthread> EventHandlers;
+ EventHandlers.resize(NumExecEventHandlers.get() + NumDataEventHandlers.get());
+ int EventHandlersSize = EventHandlers.size();
+ auto HandlerFunction = std::bind_front(&EventSystemTy::runEventHandler, this);
+ for (int Idx = 0; Idx < EventHandlersSize; Idx++) {
+ EventHandlers[Idx] =
+ std::jthread(HandlerFunction, std::ref(Idx < NumExecEventHandlers.get()
+ ? ExecEventQueue
+ : DataEventQueue));
+ }
+
+ // Executes the gate thread logic
+ while (EventSystemState == EventSystemStateTy::RUNNING) {
+ // Checks for new incoming event requests.
+ MPI_Message EventReqMsg;
+ MPI_Status EventStatus;
+ int HasReceived = false;
+ MPI_Improbe(MPI_ANY_SOURCE, static_cast<int>(ControlTagsTy::EVENT_REQUEST),
+ GateThreadComm, &HasReceived, &EventReqMsg, MPI_STATUS_IGNORE);
+
+    // If no request was received, wait EventPollingRate microseconds before
+    // checking again.
+ if (!HasReceived) {
+ std::this_thread::sleep_for(
+ std::chrono::microseconds(EventPollingRate.get()));
+ continue;
+ }
+
+    // Acquires the event information from the received request:
+    // - Event type and event tag: from the message payload.
+    // - Event source rank: from the MPI status.
+    // - Target comm: derived from the event tag.
+ int EventInfo[2];
+ MPI_Mrecv(EventInfo, 2, MPI_INT, &EventReqMsg, &EventStatus);
+ const auto NewEventType = static_cast<EventTypeTy>(EventInfo[0]);
+ MPIRequestManagerTy RequestManager(getNewEventComm(EventInfo[1]),
+ EventInfo[1], EventStatus.MPI_SOURCE);
+
+    // Creates a new destination-side event of the received type.
+ using namespace DestinationEvents;
+ using enum EventTypeTy;
+ EventTy NewEvent;
+ switch (NewEventType) {
+ case ALLOC:
+ NewEvent = allocateBuffer(std::move(RequestManager));
+ break;
+ case DELETE:
+ NewEvent = deleteBuffer(std::move(RequestManager));
+ break;
+ case SUBMIT:
+ NewEvent = submit(std::move(RequestManager));
+ break;
+ case RETRIEVE:
+ NewEvent = retrieve(std::move(RequestManager));
+ break;
+ case EXCHANGE_SRC:
+ NewEvent = exchangeSrc(std::move(RequestManager));
+ break;
+ case EXCHANGE_DST:
+ NewEvent = exchangeDst(std::move(RequestManager));
+ break;
+ case EXECUTE:
+ NewEvent = execute(std::move(RequestManager), Image);
+ break;
+ case EXIT:
+ NewEvent = exit(std::move(RequestManager), EventSystemState);
+ break;
+ case LOAD_BINARY:
+ NewEvent = loadBinary(std::move(RequestManager), Image);
+ break;
+ case SYNC:
+ case EXCHANGE:
+ assert(false && "Trying to create a local event on a remote node");
+ }
+
+ if (NewEventType == EXECUTE) {
+ ExecEventQueue.push(std::move(NewEvent));
+ } else {
+ DataEventQueue.push(std::move(NewEvent));
+ }
+ }
+
+ assert(EventSystemState == EventSystemStateTy::EXITED &&
+ "Event State should be EXITED after receiving an Exit event");
+}
+
+/// Creates a new event tag of at least 'FIRST_EVENT' value.
+/// Tag values smaller than 'FIRST_EVENT' are reserved for control
+/// messages between the event systems of different MPI processes.
+int EventSystemTy::createNewEventTag() {
+  int Tag = 0;
+
+  do {
+    Tag = EventCounter.fetch_add(1) % MPITagMaxValue;
+  } while (Tag < static_cast<int>(ControlTagsTy::FIRST_EVENT));
+
+  return Tag;
+}
+
+MPI_Comm &EventSystemTy::getNewEventComm(int MPITag) {
+ // Retrieve a comm using a round-robin strategy around the event's mpi tag.
+ return EventCommPool[MPITag % EventCommPool.size()];
+}
+
+static const char *threadLevelToString(int ThreadLevel) {
+ switch (ThreadLevel) {
+ case MPI_THREAD_SINGLE:
+ return "MPI_THREAD_SINGLE";
+ case MPI_THREAD_SERIALIZED:
+ return "MPI_THREAD_SERIALIZED";
+ case MPI_THREAD_FUNNELED:
+ return "MPI_THREAD_FUNNELED";
+ case MPI_THREAD_MULTIPLE:
+ return "MPI_THREAD_MULTIPLE";
+ default:
+ return "unkown";
+ }
+}
+
+/// Initialize the MPI context.
+bool EventSystemTy::createLocalMPIContext() {
+ int MPIError = MPI_SUCCESS;
+ int IsMPIInitialized = 0;
+ int ThreadLevel = MPI_THREAD_SINGLE;
+
+ MPI_Initialized(&IsMPIInitialized);
+
+ if (IsMPIInitialized)
+ MPI_Query_thread(&ThreadLevel);
+ else
+ MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &ThreadLevel);
+
+ CHECK(ThreadLevel == MPI_THREAD_MULTIPLE,
+ "MPI plugin requires a MPI implementation with %s thread level. "
+ "Implementation only supports up to %s.\n",
+ threadLevelToString(MPI_THREAD_MULTIPLE),
+ threadLevelToString(ThreadLevel));
+
+ if (IsMPIInitialized && ThreadLevel == MPI_THREAD_MULTIPLE) {
+ MPI_Comm_rank(MPI_COMM_WORLD, &LocalRank);
+ MPI_Comm_size(MPI_COMM_WORLD, &WorldSize);
+ return true;
+ }
+
+ // Create gate thread comm.
+ MPIError = MPI_Comm_dup(MPI_COMM_WORLD, &GateThreadComm);
+ CHECK(MPIError == MPI_SUCCESS,
+ "Failed to create gate thread MPI comm with error %d\n", MPIError);
+
+ // Create event comm pool.
+ EventCommPool.resize(NumMPIComms.get(), MPI_COMM_NULL);
+ for (auto &Comm : EventCommPool) {
+    MPIError = MPI_Comm_dup(MPI_COMM_WORLD, &Comm);
+ CHECK(MPIError == MPI_SUCCESS,
+ "Failed to create MPI comm pool with error %d\n", MPIError);
+ }
+
+ // Get local MPI process description.
+ MPIError = MPI_Comm_rank(GateThreadComm, &LocalRank);
+ CHECK(MPIError == MPI_SUCCESS,
+ "Failed to acquire the local MPI rank with error %d\n", MPIError);
+
+ MPIError = MPI_Comm_size(GateThreadComm, &WorldSize);
+ CHECK(MPIError == MPI_SUCCESS,
+ "Failed to acquire the world size with error %d\n", MPIError);
+
+ // Get max value for MPI tags.
+ MPI_Aint *Value = nullptr;
+ int Flag = 0;
+ MPIError = MPI_Comm_get_attr(GateThreadComm, MPI_TAG_UB, &Value, &Flag);
+ CHECK(Flag == 1 && MPIError == MPI_SUCCESS,
+ "Failed to acquire the MPI max tag value with error %d\n", MPIError);
+ MPITagMaxValue = *Value;
+
+ return true;
+}
+
+/// Destroy the MPI context.
+bool EventSystemTy::destroyLocalMPIContext() {
+ int MPIError = MPI_SUCCESS;
+
+ if (GateThreadComm == MPI_COMM_NULL) {
+ return true;
+ }
+
+  // Note: We don't need to assert here since the application part of the
+  // program has already finished.
+ // Free gate thread comm.
+ MPIError = MPI_Comm_free(&GateThreadComm);
+ CHECK(MPIError == MPI_SUCCESS,
+ "Failed to destroy the gate thread MPI comm with error %d\n", MPIError);
+
+ // Free event comm pool.
+ for (auto &comm : EventCommPool) {
+ MPI_Comm_free(&comm);
+ CHECK(MPIError == MPI_SUCCESS,
+ "Failed to destroy the event MPI comm with error %d\n", MPIError);
+ }
+ EventCommPool.resize(0);
+
+ // Finalize the global MPI session.
+ int IsFinalized = false;
+ MPIError = MPI_Finalized(&IsFinalized);
+
+ if (IsFinalized) {
+ DP("MPI was already finalized externally.\n");
+ } else {
+ MPIError = MPI_Finalize();
+ CHECK(MPIError == MPI_SUCCESS, "Failed to finalize MPI with error: %d\n",
+ MPIError);
+ }
+
+ return true;
+}
+
+int EventSystemTy::getNumWorkers() const {
+ if (isHost())
+ return WorldSize - 1;
+ return 0;
+}
+
+int EventSystemTy::isHost() const { return LocalRank == WorldSize - 1; }
diff --git a/offload/plugins-nextgen/mpi/src/EventSystem.h b/offload/plugins-nextgen/mpi/src/EventSystem.h
new file mode 100644
index 0000000000000..8d830b8f4f178
--- /dev/null
+++ b/offload/plugins-nextgen/mpi/src/EventSystem.h
@@ -0,0 +1,493 @@
+//===-------- EventSystem.h - Concurrent MPI communication ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the declarations of the MPI Event System used by the MPI
+// target.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _OMPTARGET_OMPCLUSTER_EVENT_SYSTEM_H_
+#define _OMPTARGET_OMPCLUSTER_EVENT_SYSTEM_H_
+
+#include <atomic>
+#include <cassert>
+#include <concepts>
+#include <condition_variable>
+#include <coroutine>
+#include <cstddef>
+#include <cstdint>
+#include <exception>
+#include <memory>
+#include <mutex>
+#include <optional>
+#include <queue>
+#include <thread>
+#include <type_traits>
+
+#define MPICH_SKIP_MPICXX
+#include <mpi.h>
+
+#include "llvm/ADT/SmallVector.h"
+
+#include "Shared/EnvironmentVar.h"
+#include "Shared/Utils.h"
+
+/// External forward declarations.
+struct __tgt_device_image;
+
+/// Template helper for generating llvm::Error instances from events.
+template <typename... ArgsTy>
+static llvm::Error createError(const char *ErrFmt, ArgsTy... Args) {
+ return llvm::createStringError(llvm::inconvertibleErrorCode(), ErrFmt,
+ Args...);
+}
+
+/// The event type (the type of action to be performed).
+///
+/// Enumerates the available events. Each enum item should be accompanied by an
+/// event coroutine implemented in the OriginEvents namespace (and, for most
+/// events, a counterpart in DestinationEvents). All the events are executed at
+/// the remote MPI process targeted by the event.
+enum class EventTypeTy : unsigned int {
+ // Memory management.
+ ALLOC, // Allocates a buffer at the remote process.
+ DELETE, // Deletes a buffer at the remote process.
+
+ // Data movement.
+  SUBMIT, // Sends buffer data to a remote process.
+  RETRIEVE, // Receives buffer data from a remote process.
+  EXCHANGE, // Waits for a data exchange between two remote processes.
+ EXCHANGE_SRC, // SRC side of the exchange event.
+ EXCHANGE_DST, // DST side of the exchange event.
+
+ // Target region execution.
+ EXECUTE, // Executes a target region at the remote process.
+
+ // Local event used to wait on other events.
+ SYNC,
+
+ // Internal event system commands.
+ LOAD_BINARY, // Transmits the binary descriptor to all workers
+ EXIT // Stops the event system execution at the remote process.
+};
+
+/// EventType to string conversion.
+///
+/// \returns the string representation of \p Type.
+const char *toString(EventTypeTy Type);
+
+/// Coroutine events
+///
+/// Return object for the event system coroutines. This class works as an
+/// external handle for the coroutine execution, allowing anyone to: query for
+/// the coroutine completion, resume the coroutine and check its state.
+/// Moreover, this class allows for coroutines to be chainable, meaning a
+/// coroutine function may wait on the completion of another one by using the
+/// co_await operator, all through a single external handle.
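+///
+/// A minimal illustrative sketch of chaining (the coroutine names below are
+/// hypothetical and not part of this patch):
+/// \code
+///   EventTy notifyDone(MPIRequestManagerTy RM) {
+///     RM.send(nullptr, 0, MPI_BYTE);
+///     co_return (co_await RM.wait());
+///   }
+///
+///   EventTy parentEvent(MPIRequestManagerTy RM) {
+///     // Suspends here until notifyDone completes, propagating its error.
+///     co_return (co_await notifyDone(std::move(RM)));
+///   }
+/// \endcode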
+struct EventTy {
+ /// Internal event handle to access C++ coroutine states.
+ struct promise_type;
+ using CoHandleTy = std::coroutine_handle<promise_type>;
+ std::shared_ptr<void> HandlePtr;
+
+ /// Internal (and required) promise type. Allows for customization of the
+ /// coroutines behavior and to store custom data inside the coroutine itself.
+ struct promise_type {
+ /// Coroutines are chained as a reverse linked-list. The most-recent
+ /// coroutine in a chain points to the previous one and so on, until the
+ /// root (and first) coroutine, which then points to the most-recent one.
+ /// The root always refers to the coroutine stored in the external handle,
+    /// the only handle an external user has access to.
+ CoHandleTy PrevHandle;
+ CoHandleTy RootHandle;
+
+ /// Indicates if the coroutine was completed successfully. Contains the
+ /// appropriate error otherwise.
+ std::optional<llvm::Error> CoroutineError;
+
+ promise_type() : CoroutineError(std::nullopt) {
+ PrevHandle = RootHandle = CoHandleTy::from_promise(*this);
+ }
+
+ /// Event coroutines should always suspend upon creation and finalization.
+ std::suspend_always initial_suspend() { return {}; }
+ std::suspend_always final_suspend() noexcept { return {}; }
+
+ /// Coroutines should return llvm::Error::success() or an appropriate error
+ /// message.
+ void return_value(llvm::Error &&GivenError) noexcept {
+ CoroutineError = std::move(GivenError);
+ }
+
+ /// Any unhandled exception should create an externally visible error.
+ void unhandled_exception() {
+ assert(std::uncaught_exceptions() > 0 &&
+ "Function should only be called if an uncaught exception is "
+ "generated inside the coroutine");
+ CoroutineError = createError("Event generated an unhandled exception");
+ }
+
+ /// Returns the external coroutine handle from the promise object.
+ EventTy get_return_object() {
+ void *HandlePtr = CoHandleTy::from_promise(*this).address();
+ return {std::shared_ptr<void>(HandlePtr, [](void *HandlePtr) {
+ assert(HandlePtr);
+ CoHandleTy::from_address(HandlePtr).destroy();
+ })};
+ }
+ };
+
+ /// Returns the external coroutine handle from the event.
+ CoHandleTy getHandle() const {
+ return CoHandleTy::from_address(HandlePtr.get());
+ }
+
+ /// Execution handling.
+ /// Resume the coroutine execution up until the next suspension point.
+ void resume();
+
+ /// Blocks the caller thread until the coroutine is completed.
+ void wait();
+
+ /// Checks if the coroutine is completed or not.
+ bool done() const;
+
+ /// Coroutine state handling.
+ /// Checks if the coroutine is valid.
+ bool empty() const;
+
+ /// Get the returned error from the coroutine.
+ llvm::Error getError() const;
+
+ /// EventTy instances are also awaitables. This means one can link multiple
+ /// EventTy together by calling the co_await operator on one another. For this
+ /// to work, EventTy must implement the following three functions.
+ ///
+ /// Called on the new coroutine before suspending the current one on co_await.
+  /// If it returns true, the new coroutine is already completed, thus it should
+ /// not be linked against the current one and the current coroutine can
+ /// continue without suspending.
+ bool await_ready() { return getHandle().done(); }
+
+ /// Called on the new coroutine when the current one is suspended. It is
+ /// responsible for chaining coroutines together.
+ void await_suspend(CoHandleTy SuspendedHandle) {
+ auto Handle = getHandle();
+ auto &CurrPromise = Handle.promise();
+ auto &SuspendedPromise = SuspendedHandle.promise();
+ auto &RootPromise = SuspendedPromise.RootHandle.promise();
+
+ CurrPromise.PrevHandle = SuspendedHandle;
+ CurrPromise.RootHandle = SuspendedPromise.RootHandle;
+
+ RootPromise.PrevHandle = Handle;
+ }
+
+ /// Called on the new coroutine when the current one is resumed. Used to
+ /// return errors when co_awaiting on other EventTy.
+ llvm::Error await_resume() {
+ auto &Error = getHandle().promise().CoroutineError;
+
+ if (Error) {
+ return std::move(*Error);
+ }
+
+ return llvm::Error::success();
+ }
+};
+
+/// Coroutine-friendly manager for many non-blocking MPI calls. Allows a
+/// coroutine to co_await the completion of the registered MPI requests.
+class MPIRequestManagerTy {
+ /// Target specification for the MPI messages.
+ const MPI_Comm Comm;
+ const int Tag;
+ /// Pending MPI requests.
+ llvm::SmallVector<MPI_Request> Requests;
+
+public:
+ /// Target peer to send and receive messages
+ int OtherRank;
+
+ MPIRequestManagerTy(MPI_Comm Comm, int Tag, int OtherRank,
+ llvm::SmallVector<MPI_Request> InitialRequests =
+ {}) // TODO: Change to initializer_list
+ : Comm(Comm), Tag(Tag), Requests(InitialRequests), OtherRank(OtherRank) {}
+
+ /// This class should not be copied.
+ MPIRequestManagerTy(const MPIRequestManagerTy &) = delete;
+ MPIRequestManagerTy &operator=(const MPIRequestManagerTy &) = delete;
+
+ MPIRequestManagerTy(MPIRequestManagerTy &&Other) noexcept
+ : Comm(Other.Comm), Tag(Other.Tag), Requests(Other.Requests),
+ OtherRank(Other.OtherRank) {
+ Other.Requests = {};
+ }
+
+ MPIRequestManagerTy &operator=(MPIRequestManagerTy &&Other) = delete;
+
+ ~MPIRequestManagerTy();
+
+  /// Sends a buffer of \p Size items of \p Datatype to the target rank.
+ void send(const void *Buffer, int Size, MPI_Datatype Datatype);
+
+  /// Sends a buffer of \p Size bytes to the target rank in batches.
+ void sendInBatchs(void *Buffer, int Size);
+
+  /// Receives a buffer of \p Size items of \p Datatype from the target rank.
+ void receive(void *Buffer, int Size, MPI_Datatype Datatype);
+
+  /// Receives a buffer of \p Size bytes from the target rank in batches.
+ void receiveInBatchs(void *Buffer, int Size);
+
+ /// Coroutine that waits on all internal pending requests.
+ EventTy wait();
+};
+
+/// Data handle for host buffers in event. It keeps the host data even if the
+/// original buffer is deallocated before the event happens.
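+///
+/// For illustration, such a handle could be built over pinned host memory
+/// (a sketch only; the construction actually used by the plugin may differ):
+/// \code
+///   EventDataHandleTy Handle(memAllocHost(Size),
+///                            [](void *Ptr) { memFreeHost(Ptr); });
+/// \endcode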
+using EventDataHandleTy = std::shared_ptr<void>;
+
+/// Routines to alloc/dealloc pinned host memory.
+///
+/// Allocates \p Size bytes of pinned host memory and returns its pointer.
+void *memAllocHost(int64_t Size);
+
+/// Deallocates the host memory pointed to by \p HstPtr.
+int memFreeHost(void *HstPtr);
+
+/// Coroutine events created at the origin rank of the event.
+namespace OriginEvents {
+
+EventTy allocateBuffer(MPIRequestManagerTy RequestManager, int64_t Size,
+ void **Buffer);
+EventTy deleteBuffer(MPIRequestManagerTy RequestManager, void *Buffer);
+EventTy submit(MPIRequestManagerTy RequestManager,
+               EventDataHandleTy DataHandle, void *DstBuffer, int64_t Size);
+EventTy retrieve(MPIRequestManagerTy RequestManager, void *OrgBuffer,
+ const void *DstBuffer, int64_t Size);
+EventTy exchange(MPIRequestManagerTy RequestManager, int SrcDevice,
+ const void *OrgBuffer, int DstDevice, void *DstBuffer,
+ int64_t Size);
+EventTy execute(MPIRequestManagerTy RequestManager, EventDataHandleTy Args,
+ uint32_t NumArgs, void *Func);
+EventTy sync(EventTy Event);
+EventTy loadBinary(MPIRequestManagerTy RequestManager,
+ const __tgt_device_image *Image,
+ llvm::SmallVector<void *> *DeviceImageAddrs);
+EventTy exit(MPIRequestManagerTy RequestManager);
+
+/// Transform a function pointer to its representing enumerator.
+template <typename FuncTy> constexpr EventTypeTy toEventType(FuncTy) {
+ using enum EventTypeTy;
+
+ if constexpr (std::is_same_v<FuncTy, decltype(&allocateBuffer)>)
+ return ALLOC;
+ else if constexpr (std::is_same_v<FuncTy, decltype(&deleteBuffer)>)
+ return DELETE;
+ else if constexpr (std::is_same_v<FuncTy, decltype(&submit)>)
+ return SUBMIT;
+ else if constexpr (std::is_same_v<FuncTy, decltype(&retrieve)>)
+ return RETRIEVE;
+ else if constexpr (std::is_same_v<FuncTy, decltype(&exchange)>)
+ return EXCHANGE;
+ else if constexpr (std::is_same_v<FuncTy, decltype(&execute)>)
+ return EXECUTE;
+ else if constexpr (std::is_same_v<FuncTy, decltype(&sync)>)
+ return SYNC;
+ else if constexpr (std::is_same_v<FuncTy, decltype(&exit)>)
+ return EXIT;
+ else if constexpr (std::is_same_v<FuncTy, decltype(&loadBinary)>)
+ return LOAD_BINARY;
+
+ assert(false && "Invalid event function");
+}
+
+} // namespace OriginEvents
+
+/// Event Queue
+///
+/// Event queue for received events.
+class EventQueue {
+private:
+ /// Base internal queue.
+ std::queue<EventTy> Queue;
+ /// Base queue sync mutex.
+ std::mutex QueueMtx;
+
+ /// Conditional variables to block popping on an empty queue.
+ std::condition_variable_any CanPopCv;
+
+public:
+ /// Event Queue default constructor.
+ EventQueue();
+
+ /// Gets current queue size.
+ size_t size();
+
+ /// Push an event to the queue, resizing it when needed.
+ void push(EventTy &&Event);
+
+  /// Pops an event from the queue, waiting if the queue is empty. When
+  /// stopped, returns an empty event.
+ EventTy pop(std::stop_token &Stop);
+};
+
+/// Event System
+///
+/// MPI tags used in control messages.
+///
+/// Special tags values used to send control messages between event systems of
+/// different processes. When adding new tags, please summarize the tag usage
+/// with a side comment as done below.
+enum class ControlTagsTy : int {
+ EVENT_REQUEST = 0, // Used by event handlers to receive new event requests.
+ FIRST_EVENT // Tag used by the first event. Must always be placed last.
+};
+
+/// Event system execution state.
+///
+/// Describes the event system state through the program.
+enum class EventSystemStateTy {
+ CREATED, // ES was created but it is not ready to send or receive new
+ // events.
+ INITIALIZED, // ES was initialized alongside internal MPI states. It is ready
+ // to send new events, but not receive them.
+ RUNNING, // ES is running and ready to receive new events.
+ EXITED, // ES was stopped.
+ FINALIZED // ES was finalized and cannot run anything else.
+};
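+
+/// Typical lifecycle: worker processes transition CREATED -> INITIALIZED ->
+/// RUNNING -> EXITED -> FINALIZED, while the host process goes straight from
+/// INITIALIZED to FINALIZED (it never runs the gate thread).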
+
+/// The distributed event system.
+class EventSystemTy {
+ /// MPI definitions.
+ /// The largest MPI tag allowed by its implementation.
+ int32_t MPITagMaxValue = 0;
+
+ /// Communicator used by the gate thread and base communicator for the event
+ /// system.
+ MPI_Comm GateThreadComm = MPI_COMM_NULL;
+
+ /// Communicator pool distributed over the events. Many MPI implementations
+ /// allow for better network hardware parallelism when unrelated MPI messages
+ /// are exchanged over distinct communicators. Thus this pool will be given in
+ /// a round-robin fashion to each newly created event to better utilize the
+ /// hardware capabilities.
+ llvm::SmallVector<MPI_Comm> EventCommPool{};
+
+  /// Number of processes used by the event system.
+ int WorldSize = -1;
+
+ /// The local rank of the current instance.
+ int LocalRank = -1;
+
+ /// Number of events created by the current instance so far. This is used to
+ /// generate unique MPI tags for each event.
+ std::atomic<int> EventCounter{0};
+
+  /// Event queues between the local gate thread and the event handlers. The exec
+ /// queue is responsible for only running the execution events, while the data
+ /// queue executes all the other ones. This allows for long running execution
+ /// events to not block any data transfers (which are all done in a
+ /// non-blocking fashion).
+ EventQueue ExecEventQueue{};
+ EventQueue DataEventQueue{};
+
+ /// Event System execution state.
+ std::atomic<EventSystemStateTy> EventSystemState{};
+
+private:
+ /// Function executed by the event handler threads.
+ void runEventHandler(std::stop_token Stop, EventQueue &Queue);
+
+ /// Creates a new unique event tag for a new event.
+ int createNewEventTag();
+
+ /// Gets a comm for a new event from the comm pool.
+ MPI_Comm &getNewEventComm(int MPITag);
+
+  /// Creates a local MPI context containing an exclusive comm for the gate
+ /// thread, and a comm pool to be used internally by the events. It also
+ /// acquires the local MPI process description.
+ bool createLocalMPIContext();
+
+ /// Destroy the local MPI context and all of its comms.
+ bool destroyLocalMPIContext();
+
+public:
+ EventSystemTy();
+ ~EventSystemTy();
+
+ bool initialize();
+ bool deinitialize();
+
+ /// Creates a new event.
+ ///
+  /// Creates a new event from the origin coroutine \p EventFunc targeting the
+  /// remote process \p DstDeviceID. The \p Args parameters are additional
+  /// arguments forwarded to \p EventFunc after its request manager.
+  ///
+  /// \note Since this is a template function, it must be defined in
+  /// this header.
+ template <class EventFuncTy, typename... ArgsTy>
+ requires std::invocable<EventFuncTy, MPIRequestManagerTy, ArgsTy...>
+ EventTy createEvent(EventFuncTy EventFunc, int DstDeviceID, ArgsTy... Args);
+
+ /// Create a new Exchange event.
+ ///
+  /// This function notifies \p SrcDevice and \p DstDevice about the
+ /// transfer and creates a host event that waits until the transfer is
+ /// completed.
+ EventTy createExchangeEvent(int SrcDevice, const void *SrcBuffer,
+ int DstDevice, void *DstBuffer, int64_t Size);
+
+ /// Gate thread procedure.
+ ///
+ /// Caller thread will spawn the event handlers, execute the gate logic and
+ /// wait until the event system receive an Exit event.
+ void runGateThread();
+
+ /// Get the number of workers available.
+ ///
+ /// \return the number of MPI available workers.
+ int getNumWorkers() const;
+
+ /// Check if we are at the host MPI process.
+ ///
+ /// \return true if the current MPI process is the host (rank WorldSize-1),
+ /// false otherwise.
+ int isHost() const;
+};
+
+template <class EventFuncTy, typename... ArgsTy>
+ requires std::invocable<EventFuncTy, MPIRequestManagerTy, ArgsTy...>
+EventTy EventSystemTy::createEvent(EventFuncTy EventFunc, int DstDeviceID,
+ ArgsTy... Args) {
+ // Create event MPI request manager.
+ const int EventTag = createNewEventTag();
+ auto &EventComm = getNewEventComm(EventTag);
+
+ // Send new event notification.
+ int EventNotificationInfo[] = {
+ static_cast<int>(OriginEvents::toEventType(EventFunc)), EventTag};
+ MPI_Request NotificationRequest = MPI_REQUEST_NULL;
+ int MPIError = MPI_Isend(EventNotificationInfo, 2, MPI_INT, DstDeviceID,
+ static_cast<int>(ControlTagsTy::EVENT_REQUEST),
+ GateThreadComm, &NotificationRequest);
+
+ if (MPIError != MPI_SUCCESS)
+ co_return createError("MPI failed during event notification with error %d",
+ MPIError);
+
+ MPIRequestManagerTy RequestManager(EventComm, EventTag, DstDeviceID,
+ {NotificationRequest});
+
+ co_return (co_await EventFunc(std::move(RequestManager), Args...));
+}
+
+#endif // _OMPTARGET_OMPCLUSTER_EVENT_SYSTEM_H_
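[Editorial note, not part of the patch] A minimal usage sketch of the event system above, mirroring how rtl.cpp below drives events; EventSystemTy, EventTy, OriginEvents::allocateBuffer and the empty/wait/getError accessors are all taken from this patch, everything else is illustrative:

    EventSystemTy EventSystem;
    EventSystem.initialize();

    // Ask rank 0 (a device process) to allocate 1024 bytes and return the
    // remote address through BufferAddress.
    void *BufferAddress = nullptr;
    size_t Size = 1024;
    EventTy Alloc = EventSystem.createEvent(OriginEvents::allocateBuffer,
                                            /*DstDeviceID=*/0, Size,
                                            &BufferAddress);
    if (!Alloc.empty()) {
      Alloc.wait();                           // Drive the coroutine to completion.
      if (llvm::Error Err = Alloc.getError()) // Any MPI failure surfaces here.
        llvm::errs() << llvm::toString(std::move(Err)) << "\n";
    }

    EventSystem.deinitialize();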
diff --git a/offload/plugins-nextgen/mpi/src/MPIDeviceMain.cpp b/offload/plugins-nextgen/mpi/src/MPIDeviceMain.cpp
new file mode 100644
index 0000000000000..462f2d778c4b2
--- /dev/null
+++ b/offload/plugins-nextgen/mpi/src/MPIDeviceMain.cpp
@@ -0,0 +1,11 @@
+#include "EventSystem.h"
+
+int main(int argc, char *argv[]) {
+ EventSystemTy EventSystem;
+
+ EventSystem.initialize();
+
+ EventSystem.runGateThread();
+
+ EventSystem.deinitialize();
+}
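[Editorial note, not part of the patch] This small main() is the device-side executable: each MPI rank running llvm-offload-mpi-device initializes the event system and then loops in the gate thread, servicing the events sent by the plugin, while the rank running the offloading application acts as the host (rank WorldSize-1, per isHost() above); the plugin reports the worker ranks as OpenMP devices via getNumWorkers(). The lit.cfg change later in this patch shows how the device processes and the host binary are launched together as a single MPMD mpirun job.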
diff --git a/offload/plugins-nextgen/mpi/src/rtl.cpp b/offload/plugins-nextgen/mpi/src/rtl.cpp
new file mode 100644
index 0000000000000..849cb9f8cd38f
--- /dev/null
+++ b/offload/plugins-nextgen/mpi/src/rtl.cpp
@@ -0,0 +1,686 @@
+//===-- RTLs/mpi/src/rtl.cpp - Target RTLs Implementation -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// RTL NextGen for MPI applications
+//
+//===----------------------------------------------------------------------===//
+
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <optional>
+#include <string>
+
+#include "GlobalHandler.h"
+#include "OpenMP/OMPT/Callback.h"
+#include "PluginInterface.h"
+#include "Shared/Debug.h"
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/BinaryFormat/ELF.h"
+#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
+#include "llvm/Support/Error.h"
+#include "llvm/TargetParser/Triple.h"
+
+#include "EventSystem.h"
+
+namespace llvm::omp::target::plugin {
+
+/// Forward declarations for all specialized data structures.
+struct MPIPluginTy;
+struct MPIDeviceTy;
+struct MPIDeviceImageTy;
+struct MPIKernelTy;
+class MPIGlobalHandlerTy;
+
+// TODO: Should this be defined inside the EventSystem?
+using MPIEventQueue = SmallVector<EventTy>;
+using MPIEventQueuePtr = MPIEventQueue *;
+
+/// Class implementing the MPI device images properties.
+struct MPIDeviceImageTy : public DeviceImageTy {
+ /// Create the MPI image with the id and the target image pointer.
+ MPIDeviceImageTy(int32_t ImageId, GenericDeviceTy &Device,
+ const __tgt_device_image *TgtImage)
+ : DeviceImageTy(ImageId, Device, TgtImage), DeviceImageAddrs(getSize()) {}
+
+ llvm::SmallVector<void *> DeviceImageAddrs;
+};
+
+class MPIGlobalHandlerTy final : public GenericGlobalHandlerTy {
+public:
+ Error getGlobalMetadataFromDevice(GenericDeviceTy &GenericDevice,
+ DeviceImageTy &Image,
+ GlobalTy &DeviceGlobal) override {
+ const char *GlobalName = DeviceGlobal.getName().data();
+ MPIDeviceImageTy &MPIImage = static_cast<MPIDeviceImageTy &>(Image);
+
+ if (GlobalName == nullptr) {
+ return Plugin::error("Failed to get name for global %p", &DeviceGlobal);
+ }
+
+ void *EntryAddress = nullptr;
+
+ __tgt_offload_entry *Begin = MPIImage.getTgtImage()->EntriesBegin;
+ __tgt_offload_entry *End = MPIImage.getTgtImage()->EntriesEnd;
+
+ int I = 0;
+ for (auto *Entry = Begin; Entry < End; ++Entry) {
+ if (!strcmp(Entry->name, GlobalName)) {
+ EntryAddress = MPIImage.DeviceImageAddrs[I];
+ break;
+ }
+ I++;
+ }
+
+ if (EntryAddress == nullptr) {
+ return Plugin::error("Failed to find global %s", GlobalName);
+ }
+
+ // Save the pointer to the symbol.
+ DeviceGlobal.setPtr(EntryAddress);
+
+ return Plugin::success();
+ }
+};
+
+struct MPIKernelTy : public GenericKernelTy {
+ /// Construct the kernel with a name and an execution mode.
+ MPIKernelTy(const char *Name, EventSystemTy &EventSystem)
+ : GenericKernelTy(Name), Func(nullptr), EventSystem(EventSystem) {}
+
+ /// Initialize the kernel.
+ Error initImpl(GenericDeviceTy &Device, DeviceImageTy &Image) override {
+ // Functions have zero size.
+ GlobalTy Global(getName(), 0);
+
+ // Get the metadata (address) of the kernel function.
+ GenericGlobalHandlerTy &GHandler = Device.Plugin.getGlobalHandler();
+ if (auto Err = GHandler.getGlobalMetadataFromDevice(Device, Image, Global))
+ return Err;
+
+ // Check that the function pointer is valid.
+ if (!Global.getPtr())
+ return Plugin::error("Invalid function for kernel %s", getName());
+
+ // Save the function pointer.
+ Func = (void (*)())Global.getPtr();
+
+ // TODO: Check which settings are appropriate for the MPI plugin.
+ // For now, we are using the ELF64 plugin configuration.
+ KernelEnvironment.Configuration.ExecMode = OMP_TGT_EXEC_MODE_GENERIC;
+ KernelEnvironment.Configuration.MayUseNestedParallelism = /* Unknown */ 2;
+ KernelEnvironment.Configuration.UseGenericStateMachine = /* Unknown */ 2;
+
+ // Set the maximum number of threads to one.
+ MaxNumThreads = 1;
+ return Plugin::success();
+ }
+
+ /// Launch the kernel.
+ Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
+ uint64_t NumBlocks, KernelArgsTy &KernelArgs, void *Args,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) const override;
+
+private:
+ /// The kernel function to execute.
+ void (*Func)(void);
+ EventSystemTy &EventSystem;
+};
+
+/// Reference to an MPI plugin resource. These are the objects handled by the
+/// queue and event resource managers of the MPI plugin.
+template <typename ResourceTy>
+struct MPIResourceRef final : public GenericDeviceResourceRef {
+
+ /// The underlying handler type for the resource.
+ using HandleTy = ResourceTy *;
+
+ /// Create an empty reference to an invalid resource.
+ MPIResourceRef() : Resource(nullptr) {}
+
+ /// Create a reference to an existing resource.
+ MPIResourceRef(HandleTy Queue) : Resource(Queue) {}
+
+ /// Create a new resource and save the reference.
+ Error create(GenericDeviceTy &Device) override {
+ if (Resource)
+ return Plugin::error("Recreating an existing resource");
+
+ Resource = new ResourceTy;
+ if (!Resource)
+ return Plugin::error("Failed to allocated a new resource");
+
+ return Plugin::success();
+ }
+
+ /// Destroy the resource and invalidate the reference.
+ Error destroy(GenericDeviceTy &Device) override {
+ if (!Resource)
+ return Plugin::error("Destroying an invalid resource");
+
+ delete Resource;
+ Resource = nullptr;
+
+ return Plugin::success();
+ }
+
+ operator HandleTy() const { return Resource; }
+
+private:
+ HandleTy Resource;
+};
+
+/// Class implementing the device functionalities for remote x86_64 processes.
+struct MPIDeviceTy : public GenericDeviceTy {
+ /// Create an MPI device with a device id and the default MPI grid values.
+ MPIDeviceTy(GenericPluginTy &Plugin, int32_t DeviceId, int32_t NumDevices,
+ EventSystemTy &EventSystem)
+ : GenericDeviceTy(Plugin, DeviceId, NumDevices, MPIGridValues),
+ MPIEventQueueManager(*this), MPIEventManager(*this),
+ EventSystem(EventSystem) {}
+
+ /// Initialize the device, its resources and get its properties.
+ Error initImpl(GenericPluginTy &Plugin) override {
+ if (auto Err = MPIEventQueueManager.init(OMPX_InitialNumStreams))
+ return Err;
+
+ if (auto Err = MPIEventManager.init(OMPX_InitialNumEvents))
+ return Err;
+
+ return Plugin::success();
+ }
+
+ /// Deinitialize the device and release its resources.
+ Error deinitImpl() override {
+ if (auto Err = MPIEventQueueManager.deinit())
+ return Err;
+
+ if (auto Err = MPIEventManager.deinit())
+ return Err;
+
+ return Plugin::success();
+ }
+
+ Error setContext() override { return Plugin::success(); }
+
+ /// Load the binary image into the device and allocate an image object.
+ Expected<DeviceImageTy *> loadBinaryImpl(const __tgt_device_image *TgtImage,
+ int32_t ImageId) override {
+
+ // Allocate and initialize the image object.
+ MPIDeviceImageTy *Image = Plugin.allocate<MPIDeviceImageTy>();
+ new (Image) MPIDeviceImageTy(ImageId, *this, TgtImage);
+
+ auto Event = EventSystem.createEvent(OriginEvents::loadBinary, DeviceId,
+ TgtImage, &(Image->DeviceImageAddrs));
+
+ if (Event.empty()) {
+ return Plugin::error("Failed to create loadBinary event for image %p",
+ TgtImage);
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError(); Error) {
+ return Plugin::error("Event failed during loadBinary. %s\n",
+ toString(std::move(Error)).c_str());
+ }
+
+ return Image;
+ }
+
+ /// Allocate memory on the device or related to the device.
+ void *allocate(size_t Size, void *, TargetAllocTy Kind) override {
+ if (Size == 0)
+ return nullptr;
+
+ void *BufferAddress = nullptr;
+ std::optional<Error> Err = std::nullopt;
+ EventTy Event{nullptr};
+
+ switch (Kind) {
+ case TARGET_ALLOC_DEFAULT:
+ case TARGET_ALLOC_DEVICE:
+ case TARGET_ALLOC_DEVICE_NON_BLOCKING:
+ Event = EventSystem.createEvent(OriginEvents::allocateBuffer, DeviceId,
+ Size, &BufferAddress);
+
+ if (Event.empty()) {
+ Err = Plugin::error("Failed to create alloc event with size %zu", Size);
+ break;
+ }
+
+ Event.wait();
+ Err = Event.getError();
+ break;
+ case TARGET_ALLOC_HOST:
+ BufferAddress = memAllocHost(Size);
+ Err = Plugin::check(BufferAddress == nullptr,
+ "Failed to allocate host memory");
+ break;
+ case TARGET_ALLOC_SHARED:
+ Err = Plugin::error("Incompatible memory type %d", Kind);
+ break;
+ }
+
+ if (*Err) {
+ REPORT("Failed to allocate memory: %s\n",
+ toString(std::move(*Err)).c_str());
+ return nullptr;
+ }
+
+ return BufferAddress;
+ }
+
+ /// Deallocate memory on the device or related to the device.
+ int free(void *TgtPtr, TargetAllocTy Kind) override {
+ if (TgtPtr == nullptr)
+ return OFFLOAD_SUCCESS;
+
+ std::optional<Error> Err = std::nullopt;
+ EventTy Event{nullptr};
+
+ switch (Kind) {
+ case TARGET_ALLOC_DEFAULT:
+ case TARGET_ALLOC_DEVICE:
+ case TARGET_ALLOC_DEVICE_NON_BLOCKING:
+ Event =
+ EventSystem.createEvent(OriginEvents::deleteBuffer, DeviceId, TgtPtr);
+
+ if (Event.empty()) {
+ Err = Plugin::error("Failed to create delete event");
+ break;
+ }
+
+ Event.wait();
+ Err = Event.getError();
+ break;
+ case TARGET_ALLOC_HOST:
+ Err = Plugin::check(memFreeHost(TgtPtr), "Failed to free host memory");
+ break;
+ case TARGET_ALLOC_SHARED:
+ Err = createStringError(inconvertibleErrorCode(),
+ "Incompatible memory type %d", Kind);
+ break;
+ }
+
+ if (*Err) {
+ REPORT("Failed to free memory: %s\n", toString(std::move(*Err)).c_str());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ /// Submit data to the device (host to device transfer).
+ Error dataSubmitImpl(void *TgtPtr, const void *HstPtr, int64_t Size,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+ MPIEventQueuePtr Queue = nullptr;
+ if (auto Err = getQueue(AsyncInfoWrapper, Queue))
+ return Err;
+
+ // Copy the data in HstPtr to a buffer with event-managed lifetime.
+ void *SubmitBuffer = std::malloc(Size);
+ std::memcpy(SubmitBuffer, HstPtr, Size);
+ EventDataHandleTy DataHandle(SubmitBuffer, &std::free);
+
+ auto Event = EventSystem.createEvent(OriginEvents::submit, DeviceId,
+ DataHandle, TgtPtr, Size);
+
+ if (Event.empty())
+ return Plugin::error("Failed to create submit event");
+
+ Queue->push_back(Event);
+
+ return Plugin::success();
+ }
+
+ /// Retrieve data from the device (device to host transfer).
+ Error dataRetrieveImpl(void *HstPtr, const void *TgtPtr, int64_t Size,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+ MPIEventQueuePtr Queue = nullptr;
+ if (auto Err = getQueue(AsyncInfoWrapper, Queue))
+ return Err;
+
+ auto Event = EventSystem.createEvent(OriginEvents::retrieve, DeviceId,
+ HstPtr, TgtPtr, Size);
+
+ if (Event.empty())
+ return Plugin::error("Failed to create retrieve event");
+
+ Queue->push_back(Event);
+
+ return Plugin::success();
+ }
+
+ /// Exchange data between two devices directly. In the MPI plugin, this
+ /// function creates an event through which the host notifies both devices
+ /// about the exchange; the devices then perform the transfer themselves and
+ /// signal the host once it is done.
+ Error dataExchangeImpl(const void *SrcPtr, GenericDeviceTy &DstDev,
+ void *DstPtr, int64_t Size,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+ MPIEventQueuePtr Queue = nullptr;
+ if (auto Err = getQueue(AsyncInfoWrapper, Queue))
+ return Err;
+
+ auto Event = EventSystem.createExchangeEvent(
+ DeviceId, SrcPtr, DstDev.getDeviceId(), DstPtr, Size);
+
+ if (Event.empty())
+ return Plugin::error("Failed to create exchange event");
+
+ Queue->push_back(Event);
+
+ return Plugin::success();
+ }
+
+ /// Allocate and construct an MPI kernel.
+ Expected<GenericKernelTy &> constructKernel(const char *Name) override {
+ // Allocate and construct the kernel.
+ MPIKernelTy *MPIKernel = Plugin.allocate<MPIKernelTy>();
+
+ if (!MPIKernel)
+ return Plugin::error("Failed to allocate memory for MPI kernel");
+
+ new (MPIKernel) MPIKernelTy(Name, EventSystem);
+
+ return *MPIKernel;
+ }
+
+ /// Create an event.
+ Error createEventImpl(void **EventStoragePtr) override {
+ if (!EventStoragePtr)
+ return Plugin::error("Received invalid event storage pointer");
+
+ EventTy **NewEvent = reinterpret_cast<EventTy **>(EventStoragePtr);
+ if (auto Err = MPIEventManager.getResource(*NewEvent))
+ return Plugin::error("Could not allocate a new synchronization event: %s",
+ toString(std::move(Err)).c_str());
+
+ return Plugin::success();
+ }
+
+ /// Destroy a previously created event.
+ Error destroyEventImpl(void *Event) override {
+ if (!Event)
+ return Plugin::error("Received invalid event pointer");
+
+ return MPIEventManager.returnResource(reinterpret_cast<EventTy *>(Event));
+ }
+
+ /// Record the event.
+ Error recordEventImpl(void *Event,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+ if (!Event)
+ return Plugin::error("Received invalid event pointer");
+
+ MPIEventQueuePtr Queue = nullptr;
+ if (auto Err = getQueue(AsyncInfoWrapper, Queue))
+ return Err;
+
+ if (Queue->empty())
+ return Plugin::success();
+
+ auto &RecordedEvent = *reinterpret_cast<EventTy *>(Event);
+ RecordedEvent = Queue->back();
+
+ return Plugin::success();
+ }
+
+ /// Make the queue wait on the event.
+ Error waitEventImpl(void *Event,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+ if (!Event)
+ return Plugin::error("Received invalid event pointer");
+
+ auto &RecordedEvent = *reinterpret_cast<EventTy *>(Event);
+ auto SyncEvent = OriginEvents::sync(RecordedEvent);
+
+ MPIEventQueuePtr Queue = nullptr;
+ if (auto Err = getQueue(AsyncInfoWrapper, Queue))
+ return Err;
+
+ Queue->push_back(SyncEvent);
+
+ return Plugin::success();
+ }
+
+ /// Synchronize the current thread with the event
+ Error syncEventImpl(void *Event) override {
+ if (!Event)
+ return Plugin::error("Received invalid event pointer");
+
+ auto &RecordedEvent = *reinterpret_cast<EventTy *>(Event);
+ auto SyncEvent = OriginEvents::sync(RecordedEvent);
+
+ SyncEvent.wait();
+
+ return SyncEvent.getError();
+ }
+
+ /// Synchronize current thread with the pending operations on the async info.
+ Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
+ auto *Queue = reinterpret_cast<MPIEventQueue *>(AsyncInfo.Queue);
+
+ for (auto &Event : *Queue) {
+ Event.wait();
+
+ if (auto Error = Event.getError(); Error)
+ return Plugin::error("Event failed during synchronization. %s\n",
+ toString(std::move(Error)).c_str());
+ }
+
+ // Once the queue is synchronized, return it to the pool and reset the
+ // AsyncInfo. This is to make sure that the synchronization only works
+ // for its own tasks.
+ AsyncInfo.Queue = nullptr;
+ return MPIEventQueueManager.returnResource(Queue);
+ }
+
+ /// Query for the completion of the pending operations on the async info.
+ Error queryAsyncImpl(__tgt_async_info &AsyncInfo) override {
+ auto *Queue = reinterpret_cast<MPIEventQueue *>(AsyncInfo.Queue);
+
+ // Returns success when there are pending operations in the AsyncInfo.
+ if (!Queue->empty() && !Queue->back().done())
+ return Plugin::success();
+
+ // Once the queue is synchronized, return it to the pool and reset the
+ // AsyncInfo. This is to make sure that the synchronization only works
+ // for its own tasks.
+ AsyncInfo.Queue = nullptr;
+ return MPIEventQueueManager.returnResource(Queue);
+ }
+
+ Expected<void *> dataLockImpl(void *HstPtr, int64_t Size) override {
+ return HstPtr;
+ }
+
+ /// Indicate that the buffer is not pinned.
+ Expected<bool> isPinnedPtrImpl(void *HstPtr, void *&BaseHstPtr,
+ void *&BaseDevAccessiblePtr,
+ size_t &BaseSize) const override {
+ return false;
+ }
+
+ Error dataUnlockImpl(void *HstPtr) override { return Plugin::success(); }
+
+ /// This plugin should not setup the device environment or memory pool.
+ virtual bool shouldSetupDeviceEnvironment() const override { return false; };
+ virtual bool shouldSetupDeviceMemoryPool() const override { return false; };
+
+ /// Device memory limits are currently not applicable to the MPI plugin.
+ Error getDeviceStackSize(uint64_t &Value) override {
+ Value = 0;
+ return Plugin::success();
+ }
+
+ Error setDeviceStackSize(uint64_t Value) override {
+ return Plugin::success();
+ }
+
+ Error getDeviceHeapSize(uint64_t &Value) override {
+ Value = 0;
+ return Plugin::success();
+ }
+
+ Error setDeviceHeapSize(uint64_t Value) override { return Plugin::success(); }
+
+ /// Device interoperability. Not supported by MPI right now.
+ Error initAsyncInfoImpl(AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+ return Plugin::error("initAsyncInfoImpl not supported");
+ }
+
+ /// This plugin does not support interoperability.
+ Error initDeviceInfoImpl(__tgt_device_info *DeviceInfo) override {
+ return Plugin::error("initDeviceInfoImpl not supported");
+ }
+
+ /// Print information about the device.
+ Error obtainInfoImpl(InfoQueueTy &Info) override {
+ // TODO: Add more information about the device.
+ Info.add("MPI plugin");
+ Info.add("MPI OpenMP Device Number", DeviceId);
+
+ return Plugin::success();
+ }
+
+ Error getQueue(AsyncInfoWrapperTy &AsyncInfoWrapper,
+ MPIEventQueuePtr &Queue) {
+ Queue = AsyncInfoWrapper.getQueueAs<MPIEventQueuePtr>();
+ if (!Queue) {
+ // There was no queue; get a new one.
+ if (auto Err = MPIEventQueueManager.getResource(Queue))
+ return Err;
+
+ // Modify the AsyncInfoWrapper to hold the new queue.
+ AsyncInfoWrapper.setQueueAs<MPIEventQueuePtr>(Queue);
+ }
+ return Plugin::success();
+ }
+
+private:
+ using MPIEventQueueManagerTy =
+ GenericDeviceResourceManagerTy<MPIResourceRef<MPIEventQueue>>;
+ using MPIEventManagerTy =
+ GenericDeviceResourceManagerTy<MPIResourceRef<EventTy>>;
+
+ MPIEventQueueManagerTy MPIEventQueueManager;
+ MPIEventManagerTy MPIEventManager;
+ EventSystemTy &EventSystem;
+
+ /// Grid values for the MPI plugin.
+ static constexpr GV MPIGridValues = {
+ 1, // GV_Slot_Size
+ 1, // GV_Warp_Size
+ 1, // GV_Max_Teams
+ 1, // GV_Default_Num_Teams
+ 1, // GV_SimpleBufferSize
+ 1, // GV_Max_WG_Size
+ 1, // GV_Default_WG_Size
+ };
+};
+
+Error MPIKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
+ uint32_t NumThreads, uint64_t NumBlocks,
+ KernelArgsTy &KernelArgs, void *Args,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) const {
+ MPIDeviceTy &MPIDevice = static_cast<MPIDeviceTy &>(GenericDevice);
+ MPIEventQueuePtr Queue = nullptr;
+ if (auto Err = MPIDevice.getQueue(AsyncInfoWrapper, Queue))
+ return Err;
+
+ uint32_t NumArgs = KernelArgs.NumArgs;
+
+ // Copy the explicit Args to a buffer with event-managed lifetime. This is
+ // necessary because host addresses are not accessible on the MPI device and
+ // the lifetime of the Args buffer is not compatible with the lifetime of the
+ // execute event.
+ void *TgtArgs = std::malloc(sizeof(void *) * NumArgs);
+ std::memcpy(TgtArgs, *static_cast<void **>(Args), sizeof(void *) * NumArgs);
+ EventDataHandleTy DataHandle(TgtArgs, &std::free);
+
+ auto Event = EventSystem.createEvent(OriginEvents::execute,
+ GenericDevice.getDeviceId(), DataHandle,
+ NumArgs, (void *)Func);
+ if (Event.empty())
+ return Plugin::error("Failed to create execute event");
+
+ Queue->push_back(Event);
+
+ return Plugin::success();
+}
+
+/// Class implementing the MPI plugin.
+struct MPIPluginTy : GenericPluginTy {
+ MPIPluginTy() : GenericPluginTy(getTripleArch()) {}
+
+ /// This class should not be copied.
+ MPIPluginTy(const MPIPluginTy &) = delete;
+ MPIPluginTy(MPIPluginTy &&) = delete;
+
+ /// Initialize the plugin and return the number of devices.
+ Expected<int32_t> initImpl() override {
+#ifdef OMPT_SUPPORT
+ ompt::connectLibrary();
+#endif
+
+ EventSystem.initialize();
+ return EventSystem.getNumWorkers();
+ }
+
+ Error deinitImpl() override {
+ EventSystem.deinitialize();
+ return Plugin::success();
+ }
+
+ /// Create an MPI device.
+ GenericDeviceTy *createDevice(GenericPluginTy &Plugin, int32_t DeviceId,
+ int32_t NumDevices) override {
+ return new MPIDeviceTy(Plugin, DeviceId, NumDevices, EventSystem);
+ }
+
+ GenericGlobalHandlerTy *createGlobalHandler() override {
+ return new MPIGlobalHandlerTy();
+ }
+
+ /// Get the ELF code to recognize the compatible binary images.
+ uint16_t getMagicElfBits() const override { return ELF::EM_X86_64; }
+
+ bool isDataExchangable(int32_t SrcDeviceId, int32_t DstDeviceId) override {
+ return isValidDeviceId(SrcDeviceId) && isValidDeviceId(DstDeviceId);
+ }
+
+ /// All images (ELF-compatible) should be compatible with this plugin.
+ Expected<bool> isELFCompatible(StringRef) const override { return true; }
+
+ Triple::ArchType getTripleArch() const override { return Triple::x86_64; }
+
+ // private:
+ // TODO: How to maintain the EventSystem private and still allow the device to
+ // access it?
+ EventSystemTy EventSystem;
+};
+
+GenericPluginTy *PluginTy::createPlugin() { return new MPIPluginTy(); }
+
+template <typename... ArgsTy>
+static Error Plugin::check(int32_t ErrorCode, const char *ErrFmt,
+ ArgsTy... Args) {
+ if (ErrorCode == 0)
+ return Error::success();
+
+ return createStringError<ArgsTy..., const char *>(
+ inconvertibleErrorCode(), ErrFmt, Args...,
+ std::to_string(ErrorCode).data());
+}
+
+} // namespace llvm::omp::target::plugin
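[Editorial note, not part of the patch] To put the pieces above together, a simple OpenMP offload such as the following would be served by this plugin roughly as: dataSubmitImpl stages the map(to) buffer as a submit event, launchImpl ships the outlined kernel pointer and its argument block as an execute event, dataRetrieveImpl posts the retrieve event for the map(from) buffer, and synchronizeImpl drains the per-async-info event queue at the end of the target region:

    #include <stdio.h>

    int main(void) {
      int a[64], b[64];
      for (int i = 0; i < 64; ++i)
        a[i] = i;

    #pragma omp target map(to : a) map(from : b)
      for (int i = 0; i < 64; ++i)
        b[i] = 2 * a[i];

      printf("%d\n", b[63]); // Expected output: 126
      return 0;
    }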
diff --git a/offload/test/api/omp_device_managed_memory.c b/offload/test/api/omp_device_managed_memory.c
index 2a9fe09a8334c..4a5fae24ac1bf 100644
--- a/offload/test/api/omp_device_managed_memory.c
+++ b/offload/test/api/omp_device_managed_memory.c
@@ -1,5 +1,7 @@
// RUN: %libomptarget-compile-run-and-check-generic
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
+
#include <omp.h>
#include <stdio.h>
diff --git a/offload/test/api/omp_device_managed_memory_alloc.c b/offload/test/api/omp_device_managed_memory_alloc.c
index c48866922deba..fde20ac9ce2f0 100644
--- a/offload/test/api/omp_device_managed_memory_alloc.c
+++ b/offload/test/api/omp_device_managed_memory_alloc.c
@@ -1,6 +1,8 @@
// RUN: %libomptarget-compile-run-and-check-generic
// RUN: %libomptarget-compileopt-run-and-check-generic
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
+
#include <omp.h>
#include <stdio.h>
diff --git a/offload/test/api/omp_dynamic_shared_memory.c b/offload/test/api/omp_dynamic_shared_memory.c
index 3fe75f24db3e6..9a563b68cbd9e 100644
--- a/offload/test/api/omp_dynamic_shared_memory.c
+++ b/offload/test/api/omp_dynamic_shared_memory.c
@@ -8,6 +8,7 @@
// UNSUPPORTED: x86_64-pc-linux-gnu
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
// UNSUPPORTED: aarch64-unknown-linux-gnu
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: s390x-ibm-linux-gnu
diff --git a/offload/test/api/omp_indirect_call.c b/offload/test/api/omp_indirect_call.c
index ac0febf7854da..f0828d3476687 100644
--- a/offload/test/api/omp_indirect_call.c
+++ b/offload/test/api/omp_indirect_call.c
@@ -1,5 +1,7 @@
// RUN: %libomptarget-compile-run-and-check-generic
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
+
#include <assert.h>
#include <stdio.h>
diff --git a/offload/test/jit/empty_kernel_lvl1.c b/offload/test/jit/empty_kernel_lvl1.c
index a0b8cd448837d..53ac243e7c918 100644
--- a/offload/test/jit/empty_kernel_lvl1.c
+++ b/offload/test/jit/empty_kernel_lvl1.c
@@ -32,6 +32,7 @@
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: x86_64-pc-linux-gnu
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
// UNSUPPORTED: s390x-ibm-linux-gnu
// UNSUPPORTED: s390x-ibm-linux-gnu-LTO
diff --git a/offload/test/jit/empty_kernel_lvl2.c b/offload/test/jit/empty_kernel_lvl2.c
index 81a04f55ce43d..816b670a1ba42 100644
--- a/offload/test/jit/empty_kernel_lvl2.c
+++ b/offload/test/jit/empty_kernel_lvl2.c
@@ -92,6 +92,7 @@
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: x86_64-pc-linux-gnu
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
// UNSUPPORTED: s390x-ibm-linux-gnu
// UNSUPPORTED: s390x-ibm-linux-gnu-LTO
diff --git a/offload/test/jit/type_punning.c b/offload/test/jit/type_punning.c
index 10e3d2cef718b..2ece4722cbbac 100644
--- a/offload/test/jit/type_punning.c
+++ b/offload/test/jit/type_punning.c
@@ -12,6 +12,7 @@
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: x86_64-pc-linux-gnu
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
// UNSUPPORTED: s390x-ibm-linux-gnu
// UNSUPPORTED: s390x-ibm-linux-gnu-LTO
diff --git a/offload/test/lit.cfg b/offload/test/lit.cfg
index 6c590603079c4..42d6b05e1afe3 100644
--- a/offload/test/lit.cfg
+++ b/offload/test/lit.cfg
@@ -137,6 +137,8 @@ elif config.libomptarget_current_target.startswith('amdgcn'):
(config.amdgpu_test_arch.startswith("gfx942") and
evaluate_bool_env(config.environment['IS_APU']))):
supports_apu = True
+if config.libomptarget_current_target.endswith('-mpi'):
+ supports_unified_shared_memory = False
if supports_unified_shared_memory:
config.available_features.add('unified_shared_memory')
if supports_apu:
@@ -175,6 +177,8 @@ def remove_suffix_if_present(name):
return name[:-4]
elif name.endswith('-JIT-LTO'):
return name[:-8]
+ elif name.endswith('-mpi'):
+ return name[:-4]
else:
return name
@@ -302,7 +306,7 @@ for libomptarget_target in config.libomptarget_all_targets:
"%clang-" + libomptarget_target + add_libraries(" -O3 %s -o %t")))
config.substitutions.append(("%libomptarget-run-" + \
libomptarget_target, \
- "%t"))
+ "%pre_bin %t"))
config.substitutions.append(("%libomptarget-run-fail-" + \
libomptarget_target, \
"%not --crash %t"))
@@ -395,5 +399,9 @@ else:
config.substitutions.append(("%cuda_flags", ""))
config.substitutions.append(("%flags_clang", config.test_flags_clang))
config.substitutions.append(("%flags_flang", config.test_flags_flang))
+if config.libomptarget_current_target.endswith('-mpi'):
+ config.substitutions.append(("%pre_bin", "mpirun -np 1 llvm-offload-mpi-device : -np 1"))
+else:
+ config.substitutions.append(("%pre_bin", ""))
config.substitutions.append(("%flags", config.test_flags))
config.substitutions.append(("%not", config.libomptarget_not))
diff --git a/offload/test/mapping/target_derefence_array_pointrs.cpp b/offload/test/mapping/target_derefence_array_pointrs.cpp
index a6dd4069a8f58..af19cca674733 100644
--- a/offload/test/mapping/target_derefence_array_pointrs.cpp
+++ b/offload/test/mapping/target_derefence_array_pointrs.cpp
@@ -6,6 +6,7 @@
// UNSUPPORTED: amdgcn-amd-amdhsa
// UNSUPPORTED: nvptx64-nvidia-cuda
// UNSUPPORTED: nvptx64-nvidia-cuda-LTO
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
#include <stdio.h>
#include <stdlib.h>
diff --git a/offload/test/offloading/barrier_fence.c b/offload/test/offloading/barrier_fence.c
index b9a8ca27965a0..850491a02f038 100644
--- a/offload/test/offloading/barrier_fence.c
+++ b/offload/test/offloading/barrier_fence.c
@@ -7,6 +7,7 @@
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: x86_64-pc-linux-gnu
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
// UNSUPPORTED: s390x-ibm-linux-gnu
// UNSUPPORTED: s390x-ibm-linux-gnu-LTO
diff --git a/offload/test/offloading/bug49334.cpp b/offload/test/offloading/bug49334.cpp
index 1f19dab378810..0f61a58da1e00 100644
--- a/offload/test/offloading/bug49334.cpp
+++ b/offload/test/offloading/bug49334.cpp
@@ -7,6 +7,7 @@
// UNSUPPORTED: x86_64-pc-linux-gnu
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
// UNSUPPORTED: aarch64-unknown-linux-gnu
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: s390x-ibm-linux-gnu
diff --git a/offload/test/offloading/default_thread_limit.c b/offload/test/offloading/default_thread_limit.c
index 4da02bbb152e6..beef5f5e187d6 100644
--- a/offload/test/offloading/default_thread_limit.c
+++ b/offload/test/offloading/default_thread_limit.c
@@ -9,6 +9,7 @@
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: x86_64-pc-linux-gnu
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
// UNSUPPORTED: s390x-ibm-linux-gnu
// UNSUPPORTED: s390x-ibm-linux-gnu-LTO
diff --git a/offload/test/offloading/ompx_bare.c b/offload/test/offloading/ompx_bare.c
index 3dabdcd15e0d8..05b4cfa3ed822 100644
--- a/offload/test/offloading/ompx_bare.c
+++ b/offload/test/offloading/ompx_bare.c
@@ -4,6 +4,7 @@
//
// UNSUPPORTED: x86_64-pc-linux-gnu
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
// UNSUPPORTED: aarch64-unknown-linux-gnu
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: s390x-ibm-linux-gnu
diff --git a/offload/test/offloading/ompx_coords.c b/offload/test/offloading/ompx_coords.c
index 5e4e14b4c6dae..a05749990d037 100644
--- a/offload/test/offloading/ompx_coords.c
+++ b/offload/test/offloading/ompx_coords.c
@@ -2,6 +2,7 @@
//
// UNSUPPORTED: x86_64-pc-linux-gnu
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
// UNSUPPORTED: aarch64-unknown-linux-gnu
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: s390x-ibm-linux-gnu
diff --git a/offload/test/offloading/ompx_saxpy_mixed.c b/offload/test/offloading/ompx_saxpy_mixed.c
index f479be8a484fc..e857277d8a9de 100644
--- a/offload/test/offloading/ompx_saxpy_mixed.c
+++ b/offload/test/offloading/ompx_saxpy_mixed.c
@@ -2,6 +2,7 @@
//
// UNSUPPORTED: x86_64-pc-linux-gnu
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
// UNSUPPORTED: aarch64-unknown-linux-gnu
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: s390x-ibm-linux-gnu
diff --git a/offload/test/offloading/small_trip_count.c b/offload/test/offloading/small_trip_count.c
index 65f094f157469..bdbacea6c33a5 100644
--- a/offload/test/offloading/small_trip_count.c
+++ b/offload/test/offloading/small_trip_count.c
@@ -9,6 +9,7 @@
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: x86_64-pc-linux-gnu
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
// UNSUPPORTED: s390x-ibm-linux-gnu
// UNSUPPORTED: s390x-ibm-linux-gnu-LTO
diff --git a/offload/test/offloading/small_trip_count_thread_limit.cpp b/offload/test/offloading/small_trip_count_thread_limit.cpp
index b7ae52a62c83b..c1ebb3d7bfb25 100644
--- a/offload/test/offloading/small_trip_count_thread_limit.cpp
+++ b/offload/test/offloading/small_trip_count_thread_limit.cpp
@@ -7,6 +7,7 @@
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: x86_64-pc-linux-gnu
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
// UNSUPPORTED: s390x-ibm-linux-gnu
// UNSUPPORTED: s390x-ibm-linux-gnu-LTO
diff --git a/offload/test/offloading/spmdization.c b/offload/test/offloading/spmdization.c
index 77913bec8342f..212a3fa4b37b0 100644
--- a/offload/test/offloading/spmdization.c
+++ b/offload/test/offloading/spmdization.c
@@ -11,6 +11,7 @@
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: x86_64-pc-linux-gnu
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
// UNSUPPORTED: s390x-ibm-linux-gnu
// UNSUPPORTED: s390x-ibm-linux-gnu-LTO
diff --git a/offload/test/offloading/target_critical_region.cpp b/offload/test/offloading/target_critical_region.cpp
index 495632bf76e17..605350e36e8a0 100644
--- a/offload/test/offloading/target_critical_region.cpp
+++ b/offload/test/offloading/target_critical_region.cpp
@@ -6,6 +6,7 @@
// UNSUPPORTED: nvptx64-nvidia-cuda-LTO
// UNSUPPORTED: x86_64-pc-linux-gnu
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
// UNSUPPORTED: s390x-ibm-linux-gnu
// UNSUPPORTED: s390x-ibm-linux-gnu-LTO
// UNSUPPORTED: amdgcn-amd-amdhsa
diff --git a/offload/test/offloading/thread_limit.c b/offload/test/offloading/thread_limit.c
index a8cc51b651dc9..81c0359e20f02 100644
--- a/offload/test/offloading/thread_limit.c
+++ b/offload/test/offloading/thread_limit.c
@@ -9,6 +9,7 @@
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: x86_64-pc-linux-gnu
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
// UNSUPPORTED: s390x-ibm-linux-gnu
// UNSUPPORTED: s390x-ibm-linux-gnu-LTO
diff --git a/offload/test/offloading/workshare_chunk.c b/offload/test/offloading/workshare_chunk.c
index a8c60c0097791..36d843417a011 100644
--- a/offload/test/offloading/workshare_chunk.c
+++ b/offload/test/offloading/workshare_chunk.c
@@ -5,6 +5,7 @@
// UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
// UNSUPPORTED: x86_64-pc-linux-gnu
// UNSUPPORTED: x86_64-pc-linux-gnu-LTO
+// UNSUPPORTED: x86_64-pc-linux-gnu-mpi
// UNSUPPORTED: s390x-ibm-linux-gnu
// UNSUPPORTED: s390x-ibm-linux-gnu-LTO
>From 2546ba43e80bf5e825ccc7fef8cda03e6a1c3577 Mon Sep 17 00:00:00 2001
From: Guilherme Valarini <guilherme.a.valarini at gmail.com>
Date: Thu, 9 May 2024 00:46:30 -0700
Subject: [PATCH 2/3] [Offload] Fix queryAsyncImpl to match MPI progress model
This commit also refactors the MPI dependency in CMakeLists
---
offload/plugins-nextgen/mpi/CMakeLists.txt | 36 ++++------------------
offload/plugins-nextgen/mpi/src/rtl.cpp | 22 ++++++++++---
2 files changed, 24 insertions(+), 34 deletions(-)
diff --git a/offload/plugins-nextgen/mpi/CMakeLists.txt b/offload/plugins-nextgen/mpi/CMakeLists.txt
index c3a8c9a83b45f..f7fc3a5b02a68 100644
--- a/offload/plugins-nextgen/mpi/CMakeLists.txt
+++ b/offload/plugins-nextgen/mpi/CMakeLists.txt
@@ -13,18 +13,6 @@
# Looking for MPI...
find_package(MPI QUIET)
-set(LIBOMPTARGET_DEP_MPI_FOUND ${MPI_CXX_FOUND})
-set(LIBOMPTARGET_DEP_MPI_LIBRARIES ${MPI_CXX_LIBRARIES})
-set(LIBOMPTARGET_DEP_MPI_INCLUDE_DIRS ${MPI_CXX_INCLUDE_DIRS})
-set(LIBOMPTARGET_DEP_MPI_COMPILE_FLAGS ${MPI_CXX_COMPILE_FLAGS})
-set(LIBOMPTARGET_DEP_MPI_LINK_FLAGS ${MPI_CXX_LINK_FLAGS})
-
-mark_as_advanced(
- LIBOMPTARGET_DEP_MPI_FOUND
- LIBOMPTARGET_DEP_MPI_LIBRARIES
- LIBOMPTARGET_DEP_MPI_INCLUDE_DIRS
- LIBOMPTARGET_DEP_MPI_COMPILE_FLAGS
- LIBOMPTARGET_DEP_MPI_LINK_FLAGS)
if(NOT(CMAKE_SYSTEM_PROCESSOR MATCHES "(x86_64)|(ppc64le)$" AND CMAKE_SYSTEM_NAME MATCHES "Linux"))
libomptarget_say("Not building MPI offloading plugin: only support MPI in Linux x86_64 or ppc64le hosts.")
@@ -32,7 +20,7 @@ if(NOT(CMAKE_SYSTEM_PROCESSOR MATCHES "(x86_64)|(ppc64le)$" AND CMAKE_SYSTEM_NAM
elseif(NOT LIBOMPTARGET_DEP_LIBFFI_FOUND)
libomptarget_say("Not building MPI offloading plugin: libffi dependency not found.")
return()
-elseif(NOT LIBOMPTARGET_DEP_MPI_FOUND)
+elseif(NOT MPI_CXX_FOUND)
libomptarget_say("Not building MPI offloading plugin: MPI not found in system.")
return()
endif()
@@ -53,9 +41,8 @@ else()
target_link_libraries(omptarget.rtl.mpi PRIVATE FFI::ffi)
endif()
-target_link_libraries(omptarget.rtl.mpi PRIVATE
- ${LIBOMPTARGET_DEP_MPI_LIBRARIES}
- ${LIBOMPTARGET_DEP_MPI_LINK_FLAGS}
+target_link_libraries(omptarget.rtl.mpi PRIVATE
+ MPI::MPI_CXX
)
# Add include directories
@@ -65,13 +52,9 @@ target_include_directories(omptarget.rtl.mpi PRIVATE
# Install plugin under the lib destination folder.
install(TARGETS omptarget.rtl.mpi
LIBRARY DESTINATION "${OFFLOAD_INSTALL_LIBDIR}")
-set_target_properties(omptarget.rtl.mpi PROPERTIES
+set_target_properties(omptarget.rtl.mpi PROPERTIES
INSTALL_RPATH "$ORIGIN" BUILD_RPATH "$ORIGIN:${CMAKE_CURRENT_BINARY_DIR}/..")
-if(LIBOMPTARGET_DEP_MPI_COMPILE_FLAGS)
- set_target_properties(omptarget.rtl.mpi PROPERTIES
- COMPILE_FLAGS "${LIBOMPTARGET_DEP_MPI_COMPILE_FLAGS}")
-endif()
# Set C++20 as the target standard for this plugin.
set_target_properties(omptarget.rtl.mpi
@@ -94,8 +77,7 @@ llvm_add_tool(OPENMP llvm-offload-mpi-device src/EventSystem.cpp src/MPIDeviceMa
llvm_update_compile_flags(llvm-offload-mpi-device)
target_link_libraries(llvm-offload-mpi-device PRIVATE
- ${LIBOMPTARGET_DEP_MPI_LIBRARIES}
- ${LIBOMPTARGET_DEP_MPI_LINK_FLAGS}
+ MPI::MPI_CXX
LLVMSupport
omp
)
@@ -108,14 +90,8 @@ endif()
target_include_directories(llvm-offload-mpi-device PRIVATE
${LIBOMPTARGET_INCLUDE_DIR}
- ${LIBOMPTARGET_DEP_MPI_INCLUDE_DIRS}
)
-if(LIBOMPTARGET_DEP_MPI_COMPILE_FLAGS)
- set_target_properties(llvm-offload-mpi-device PROPERTIES
- COMPILE_FLAGS "${LIBOMPTARGET_DEP_MPI_COMPILE_FLAGS}"
- )
-endif()
set_target_properties(llvm-offload-mpi-device
PROPERTIES
@@ -123,5 +99,5 @@ set_target_properties(llvm-offload-mpi-device
CXX_STANDARD_REQUIRED ON
)
-target_compile_definitions(llvm-offload-mpi-device PRIVATE
+target_compile_definitions(llvm-offload-mpi-device PRIVATE
DEBUG_PREFIX="OFFLOAD MPI DEVICE")
diff --git a/offload/plugins-nextgen/mpi/src/rtl.cpp b/offload/plugins-nextgen/mpi/src/rtl.cpp
index 849cb9f8cd38f..87f1cdbc4fe4f 100644
--- a/offload/plugins-nextgen/mpi/src/rtl.cpp
+++ b/offload/plugins-nextgen/mpi/src/rtl.cpp
@@ -16,6 +16,7 @@
#include <cstring>
#include <optional>
#include <string>
+#include <list>
#include "GlobalHandler.h"
#include "OpenMP/OMPT/Callback.h"
@@ -40,7 +41,7 @@ struct MPIKernelTy;
class MPIGlobalHandlerTy;
// TODO: Should this be defined inside the EventSystem?
-using MPIEventQueue = SmallVector<EventTy>;
+using MPIEventQueue = std::list<EventTy>;
using MPIEventQueuePtr = MPIEventQueue *;
/// Class implementing the MPI device images properties.
@@ -489,9 +490,22 @@ struct MPIDeviceTy : public GenericDeviceTy {
Error queryAsyncImpl(__tgt_async_info &AsyncInfo) override {
auto *Queue = reinterpret_cast<MPIEventQueue *>(AsyncInfo.Queue);
- // Returns success when there are pending operations in the AsyncInfo.
- if (!Queue->empty() && !Queue->back().done())
- return Plugin::success();
+ // Returns success when there are pending operations in AsyncInfo, moving
+ // forward through the events on the queue until it is fully completed.
+ while (!Queue->empty()) {
+ auto &Event = Queue->front();
+
+ Event.resume();
+
+ if (!Event.done())
+ return Plugin::success();
+
+ if (auto Error = Event.getError(); Error)
+ return Plugin::error("Event failed during query. %s\n",
+ toString(std::move(Error)).c_str());
+
+ Queue->pop_front();
+ }
// Once the queue is synchronized, return it to the pool and reset the
// AsyncInfo. This is to make sure that the synchronization only works
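[Editorial note, not part of the patch] The reworked queryAsyncImpl now resumes the event at the front of the queue and pops it once it completes, instead of only peeking at the last event, so each query actually makes progress; switching MPIEventQueue from SmallVector to std::list gives this polling loop the cheap pop_front it needs.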
>From 8eaedc44a9ec39e83954984d57b92e677fde4bac Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jhonatan=20Cl=C3=A9to?= <j256444 at dac.unicamp.br>
Date: Thu, 9 May 2024 14:37:57 -0300
Subject: [PATCH 3/3] [Offload] Update MPI Plugin
Update the MPI Plugin to match the recent changes in the Plugin Interface
---
offload/plugins-nextgen/mpi/CMakeLists.txt | 7 ---
offload/plugins-nextgen/mpi/src/rtl.cpp | 72 ++++++++++++++++------
2 files changed, 52 insertions(+), 27 deletions(-)
diff --git a/offload/plugins-nextgen/mpi/CMakeLists.txt b/offload/plugins-nextgen/mpi/CMakeLists.txt
index f7fc3a5b02a68..2ed611e990bec 100644
--- a/offload/plugins-nextgen/mpi/CMakeLists.txt
+++ b/offload/plugins-nextgen/mpi/CMakeLists.txt
@@ -49,13 +49,6 @@ target_link_libraries(omptarget.rtl.mpi PRIVATE
target_include_directories(omptarget.rtl.mpi PRIVATE
${LIBOMPTARGET_INCLUDE_DIR})
-# Install plugin under the lib destination folder.
-install(TARGETS omptarget.rtl.mpi
- LIBRARY DESTINATION "${OFFLOAD_INSTALL_LIBDIR}")
-set_target_properties(omptarget.rtl.mpi PROPERTIES
- INSTALL_RPATH "$ORIGIN" BUILD_RPATH "$ORIGIN:${CMAKE_CURRENT_BINARY_DIR}/..")
-
-
# Set C++20 as the target standard for this plugin.
set_target_properties(omptarget.rtl.mpi
PROPERTIES
diff --git a/offload/plugins-nextgen/mpi/src/rtl.cpp b/offload/plugins-nextgen/mpi/src/rtl.cpp
index 87f1cdbc4fe4f..db9d3d4f83a32 100644
--- a/offload/plugins-nextgen/mpi/src/rtl.cpp
+++ b/offload/plugins-nextgen/mpi/src/rtl.cpp
@@ -14,22 +14,34 @@
#include <cstdint>
#include <cstdlib>
#include <cstring>
+#include <list>
#include <optional>
#include <string>
-#include <list>
+#include "Shared/Debug.h"
+#include "Utils/ELF.h"
+
+#include "EventSystem.h"
#include "GlobalHandler.h"
#include "OpenMP/OMPT/Callback.h"
#include "PluginInterface.h"
-#include "Shared/Debug.h"
+#include "omptarget.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/BinaryFormat/ELF.h"
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
#include "llvm/Support/Error.h"
-#include "llvm/TargetParser/Triple.h"
-#include "EventSystem.h"
+#if !defined(__BYTE_ORDER__) || !defined(__ORDER_LITTLE_ENDIAN__) || \
+ !defined(__ORDER_BIG_ENDIAN__)
+#error "Missing preprocessor definitions for endianness detection."
+#endif
+
+#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
+#define LITTLEENDIAN_CPU
+#elif defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
+#define BIGENDIAN_CPU
+#endif
namespace llvm::omp::target::plugin {
@@ -643,53 +655,67 @@ struct MPIPluginTy : GenericPluginTy {
/// Initialize the plugin and return the number of devices.
Expected<int32_t> initImpl() override {
-#ifdef OMPT_SUPPORT
- ompt::connectLibrary();
-#endif
-
EventSystem.initialize();
return EventSystem.getNumWorkers();
}
+ /// Deinitialize the plugin.
Error deinitImpl() override {
EventSystem.deinitialize();
return Plugin::success();
}
- /// Create an MPI device.
+ /// Creates an MPI device.
GenericDeviceTy *createDevice(GenericPluginTy &Plugin, int32_t DeviceId,
int32_t NumDevices) override {
return new MPIDeviceTy(Plugin, DeviceId, NumDevices, EventSystem);
}
+ /// Creates an MPI global handler.
GenericGlobalHandlerTy *createGlobalHandler() override {
return new MPIGlobalHandlerTy();
}
/// Get the ELF code to recognize the compatible binary images.
- uint16_t getMagicElfBits() const override { return ELF::EM_X86_64; }
-
- bool isDataExchangable(int32_t SrcDeviceId, int32_t DstDeviceId) override {
- return isValidDeviceId(SrcDeviceId) && isValidDeviceId(DstDeviceId);
+ uint16_t getMagicElfBits() const override {
+ return utils::elf::getTargetMachine();
}
/// All images (ELF-compatible) should be compatible with this plugin.
Expected<bool> isELFCompatible(StringRef) const override { return true; }
- Triple::ArchType getTripleArch() const override { return Triple::x86_64; }
+ Triple::ArchType getTripleArch() const override {
+#if defined(__x86_64__)
+ return llvm::Triple::x86_64;
+#elif defined(__s390x__)
+ return llvm::Triple::systemz;
+#elif defined(__aarch64__)
+#ifdef LITTLEENDIAN_CPU
+ return llvm::Triple::aarch64;
+#else
+ return llvm::Triple::aarch64_be;
+#endif
+#elif defined(__powerpc64__)
+#ifdef LITTLEENDIAN_CPU
+ return llvm::Triple::ppc64le;
+#else
+ return llvm::Triple::ppc64;
+#endif
+#else
+ return llvm::Triple::UnknownArch;
+#endif
+ }
- // private:
- // TODO: How to maintain the EventSystem private and still allow the device to
- // access it?
+ const char *getName() const override { return GETNAME(TARGET_NAME); }
+
+private:
EventSystemTy EventSystem;
};
-GenericPluginTy *PluginTy::createPlugin() { return new MPIPluginTy(); }
-
template <typename... ArgsTy>
static Error Plugin::check(int32_t ErrorCode, const char *ErrFmt,
ArgsTy... Args) {
- if (ErrorCode == 0)
+ if (ErrorCode == OFFLOAD_SUCCESS)
return Error::success();
return createStringError<ArgsTy..., const char *>(
@@ -698,3 +724,9 @@ static Error Plugin::check(int32_t ErrorCode, const char *ErrFmt,
}
} // namespace llvm::omp::target::plugin
+
+extern "C" {
+llvm::omp::target::plugin::GenericPluginTy *createPlugin_mpi() {
+ return new llvm::omp::target::plugin::MPIPluginTy();
+}
+}
\ No newline at end of file