[llvm] [Offload] Add MPI Plugin (PR #90890)
Jhonatan Cléto via llvm-commits
llvm-commits at lists.llvm.org
Fri May 3 10:55:02 PDT 2024
================
@@ -0,0 +1,1049 @@
+//===------ event_system.cpp - Concurrent MPI communication -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the implementation of the MPI Event System used by the MPI
+// target runtime for concurrent communication.
+//
+//===----------------------------------------------------------------------===//
+
+#include "EventSystem.h"
+
+#include <algorithm>
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <functional>
+#include <memory>
+
+#include <ffi.h>
+#include <mpi.h>
+
+#include "Shared/Debug.h"
+#include "Shared/EnvironmentVar.h"
+#include "Shared/Utils.h"
+#include "omptarget.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Error.h"
+
+#include "llvm/Support/DynamicLibrary.h"
+
+using llvm::sys::DynamicLibrary;
+
+#define CHECK(expr, msg, ...) \
+ if (!(expr)) { \
+ REPORT(msg, ##__VA_ARGS__); \
+ return false; \
+ }
+
+// Customizable parameters of the event system
+// =============================================================================
+// Number of execute event handlers to spawn.
+static IntEnvar NumExecEventHandlers("OMPTARGET_NUM_EXEC_EVENT_HANDLERS", 1);
+// Number of data event handlers to spawn.
+static IntEnvar NumDataEventHandlers("OMPTARGET_NUM_DATA_EVENT_HANDLERS", 1);
+// Polling interval (in microseconds) used by the event handlers.
+static IntEnvar EventPollingRate("OMPTARGET_EVENT_POLLING_RATE", 1);
+// Number of communicators to be spawned and distributed for the events.
+// Allows for parallel use of network resources.
+static Int64Envar NumMPIComms("OMPTARGET_NUM_MPI_COMMS", 10);
+// Maximum fragment size (in bytes) used during data transfers.
+static Int64Envar MPIFragmentSize("OMPTARGET_MPI_FRAGMENT_SIZE", 100e6);
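+// Illustrative launch line showing how these knobs might be tuned at run time
+// (the launcher and binary names below are placeholders, not part of this
+// patch):
+//   OMPTARGET_NUM_MPI_COMMS=16 OMPTARGET_MPI_FRAGMENT_SIZE=50000000 \
+//   mpirun -np 3 ./offload-app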
+
+// Helper functions
+// =============================================================================
+const char *toString(EventTypeTy Type) {
+ using enum EventTypeTy;
+
+ switch (Type) {
+ case ALLOC:
+ return "Alloc";
+ case DELETE:
+ return "Delete";
+ case RETRIEVE:
+ return "Retrieve";
+ case SUBMIT:
+ return "Submit";
+ case EXCHANGE:
+ return "Exchange";
+ case EXCHANGE_SRC:
+ return "ExchangeSrc";
+ case EXCHANGE_DST:
+ return "ExchangeDst";
+ case EXECUTE:
+ return "Execute";
+ case SYNC:
+ return "Sync";
+ case LOAD_BINARY:
+ return "LoadBinary";
+ case EXIT:
+ return "Exit";
+ }
+
+ assert(false && "Every enum value must be handled in the switch above.");
+ return nullptr;
+}
+
+// Coroutine events implementation
+// =============================================================================
+void EventTy::resume() {
+ // Acquire the first handle that is not yet done.
+ const CoHandleTy &RootHandle = getHandle().promise().RootHandle;
+ auto &ResumableHandle = RootHandle.promise().PrevHandle;
+ while (ResumableHandle.done()) {
+ ResumableHandle = ResumableHandle.promise().PrevHandle;
+
+ if (ResumableHandle == RootHandle)
+ break;
+ }
+
+ if (!ResumableHandle.done())
+ ResumableHandle.resume();
+}
+
+void EventTy::wait() {
+ // Advance the event progress until it is completed.
+ while (!done()) {
+ resume();
+
+ std::this_thread::sleep_for(
+ std::chrono::microseconds(EventPollingRate.get()));
+ }
+}
+
+bool EventTy::done() const { return getHandle().done(); }
+
+bool EventTy::empty() const { return !getHandle(); }
+
+llvm::Error EventTy::getError() const {
+ auto &Error = getHandle().promise().CoroutineError;
+ if (Error)
+ return std::move(*Error);
+
+ return llvm::Error::success();
+}
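+
+// Usage sketch (not part of this file): an event coroutine from the namespaces
+// below is created and then driven to completion by its owner, for example:
+//   EventTy Event = OriginEvents::allocateBuffer(std::move(RM), Size, &Ptr);
+//   Event.wait(); // or repeated Event.resume() calls from an event handler
+//   if (llvm::Error Err = Event.getError())
+//     /* handle the error */;
+// Here RM, Size, and Ptr are placeholders for an MPIRequestManagerTy, a buffer
+// size, and a void * destination.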
+
+// Helpers
+// =============================================================================
+MPIRequestManagerTy::~MPIRequestManagerTy() {
+ assert(Requests.empty() && "Requests must be fulfilled and emptied before "
+ "destruction. Did you co_await on it?");
+}
+
+void MPIRequestManagerTy::send(const void *Buffer, int Size,
+ MPI_Datatype Datatype) {
+ MPI_Isend(Buffer, Size, Datatype, OtherRank, Tag, Comm,
+ &Requests.emplace_back(MPI_REQUEST_NULL));
+}
+
+void MPIRequestManagerTy::sendInBatchs(void *Buffer, int Size) {
+ // Operates over the original buffer in fragments of at most MPIFragmentSize
+ // bytes each.
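+ // Example: with MPIFragmentSize = 100e6 and Size = 250e6, three fragments of
+ // 100e6, 100e6, and 50e6 bytes are sent, starting at offsets 0, 100e6, and
+ // 200e6 of the buffer.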
+ char *BufferByteArray = reinterpret_cast<char *>(Buffer);
+ int64_t RemainingBytes = Size;
+ while (RemainingBytes > 0) {
+ send(&BufferByteArray[Size - RemainingBytes],
+ static_cast<int>(std::min(RemainingBytes, MPIFragmentSize.get())),
+ MPI_BYTE);
+ RemainingBytes -= MPIFragmentSize.get();
+ }
+}
+
+void MPIRequestManagerTy::receive(void *Buffer, int Size,
+ MPI_Datatype Datatype) {
+ MPI_Irecv(Buffer, Size, Datatype, OtherRank, Tag, Comm,
+ &Requests.emplace_back(MPI_REQUEST_NULL));
+}
+
+void MPIRequestManagerTy::receiveInBatchs(void *Buffer, int Size) {
+ // Operates over the original buffer in fragments of at most MPIFragmentSize
+ // bytes each.
+ char *BufferByteArray = reinterpret_cast<char *>(Buffer);
+ int64_t RemainingBytes = Size;
+ while (RemainingBytes > 0) {
+ receive(&BufferByteArray[Size - RemainingBytes],
+ static_cast<int>(std::min(RemainingBytes, MPIFragmentSize.get())),
+ MPI_BYTE);
+ RemainingBytes -= MPIFragmentSize.get();
+ }
+}
+
+EventTy MPIRequestManagerTy::wait() {
+ int RequestsCompleted = 0;
+
+ while (!RequestsCompleted) {
+ int MPIError = MPI_Testall(Requests.size(), Requests.data(),
+ &RequestsCompleted, MPI_STATUSES_IGNORE);
+
+ if (MPIError != MPI_SUCCESS)
+ co_return createError("Waiting for MPI requests failed with code %d",
+ MPIError);
+
+ co_await std::suspend_always{};
+ }
+
+ Requests.clear();
+
+ co_return llvm::Error::success();
+}
+
+EventTy operator co_await(MPIRequestManagerTy &RequestManager) {
+ return RequestManager.wait();
+}
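+
+// Note on usage: inside the event coroutines below, `co_await RequestManager`
+// suspends the coroutine until MPI_Testall reports that every queued request
+// has completed, and then evaluates to the resulting llvm::Error.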
+
+// Device Image Storage
+// =============================================================================
+
+struct DeviceImage : __tgt_device_image {
+ llvm::SmallVector<unsigned char, 1> ImageBuffer;
+ llvm::SmallVector<__tgt_offload_entry, 16> Entries;
+ llvm::SmallVector<char> FlattenedEntryNames;
+
+ DeviceImage() {
+ ImageStart = nullptr;
+ ImageEnd = nullptr;
+ EntriesBegin = nullptr;
+ EntriesEnd = nullptr;
+ }
+
+ DeviceImage(size_t ImageSize, size_t EntryCount)
+ : ImageBuffer(ImageSize + alignof(void *)), Entries(EntryCount) {
+ // Align the image start to alignof(void *). Track the remaining space in a
+ // separate variable so the alignment slack in ImageBuffer is accounted for
+ // and ImageSize itself is left untouched.
+ ImageStart = ImageBuffer.begin();
+ size_t Space = ImageSize + alignof(void *);
+ std::align(alignof(void *), ImageSize, ImageStart, Space);
+ ImageEnd = (void *)((size_t)ImageStart + ImageSize);
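+ // Example: for ImageSize = 1000 on a target where alignof(void *) == 8, the
+ // buffer holds 1008 bytes and ImageStart is advanced to the first 8-byte
+ // aligned address inside it, leaving at least 1000 bytes up to ImageEnd.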
+ }
+
+ void setImageEntries(llvm::SmallVector<size_t> EntryNameSizes) {
+ // Adjust the entry names to use the flattened name buffer.
+ size_t EntryCount = Entries.size();
+ size_t TotalNameSize = 0;
+ for (size_t I = 0; I < EntryCount; I++) {
+ TotalNameSize += EntryNameSizes[I];
+ }
+ FlattenedEntryNames.resize(TotalNameSize);
+
+ for (size_t I = EntryCount; I > 0; I--) {
+ TotalNameSize -= EntryNameSizes[I - 1];
+ Entries[I - 1].name = &FlattenedEntryNames[TotalNameSize];
+ }
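+ // Example: with EntryNameSizes = {4, 7}, the loop above leaves
+ // Entries[0].name pointing at offset 0 and Entries[1].name at offset 4 of
+ // FlattenedEntryNames.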
+
+ // Set the entries pointers.
+ EntriesBegin = Entries.begin();
+ EntriesEnd = Entries.end();
+ }
+
+ /// Get the image size.
+ size_t getSize() const {
+ return llvm::omp::target::getPtrDiff(ImageEnd, ImageStart);
+ }
+
+ /// Getter and setter for the dynamic library.
+ DynamicLibrary &getDynamicLibrary() { return DynLib; }
+ void setDynamicLibrary(const DynamicLibrary &Lib) { DynLib = Lib; }
+
+private:
+ DynamicLibrary DynLib;
+};
+
+// Event Implementations
+// =============================================================================
+
+namespace OriginEvents {
+
+EventTy allocateBuffer(MPIRequestManagerTy RequestManager, int64_t Size,
+ void **Buffer) {
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+
+ RequestManager.receive(Buffer, sizeof(void *), MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy deleteBuffer(MPIRequestManagerTy RequestManager, void *Buffer) {
+ RequestManager.send(&Buffer, sizeof(void *), MPI_BYTE);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy submit(MPIRequestManagerTy RequestManager, EventDataHandleTy DataHandle,
+ void *DstBuffer, int64_t Size) {
+ RequestManager.send(&DstBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+
+ RequestManager.sendInBatchs(DataHandle.get(), Size);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
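+
+// Wire format of submit, in send order: destination pointer, payload size,
+// then the payload itself in MPIFragmentSize-byte fragments; the destination
+// side (DestinationEvents::submit) replies with an empty completion message.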
+
+EventTy retrieve(MPIRequestManagerTy RequestManager, void *OrgBuffer,
+ const void *DstBuffer, int64_t Size) {
+ RequestManager.send(&DstBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+ RequestManager.receiveInBatchs(OrgBuffer, Size);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy exchange(MPIRequestManagerTy RequestManager, int SrcDevice,
+ const void *OrgBuffer, int DstDevice, void *DstBuffer,
+ int64_t Size) {
+ // Send the exchange description to the source device.
+ RequestManager.send(&OrgBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+ RequestManager.send(&DstDevice, 1, MPI_INT);
+
+ // Send the exchange description to the destination device.
+ RequestManager.OtherRank = DstDevice;
+ RequestManager.send(&DstBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+ RequestManager.send(&SrcDevice, 1, MPI_INT);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+ RequestManager.OtherRank = SrcDevice;
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
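+
+// Protocol sketch for the exchange above, which involves three ranks:
+//   Origin -> SrcDevice: OrgBuffer, Size, DstDevice (exchangeSrc below)
+//   Origin -> DstDevice: DstBuffer, Size, SrcDevice (exchangeDst below)
+//   SrcDevice -> DstDevice: Size bytes of payload, sent in fragments
+//   SrcDevice and DstDevice -> Origin: empty completion notifications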
+
+EventTy execute(MPIRequestManagerTy RequestManager, EventDataHandleTy Args,
+ uint32_t NumArgs, void *Func) {
+ RequestManager.send(&NumArgs, 1, MPI_UINT32_T);
+ RequestManager.send(Args.get(), NumArgs * sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Func, sizeof(void *), MPI_BYTE);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
+}
+
+EventTy sync(EventTy Event) {
+ while (!Event.done())
+ co_await std::suspend_always{};
+
+ co_return llvm::Error::success();
+}
+
+EventTy loadBinary(MPIRequestManagerTy RequestManager,
+ const __tgt_device_image *Image,
+ llvm::SmallVector<void *> *DeviceImageAddrs) {
+ auto &[ImageStart, ImageEnd, EntriesBegin, EntriesEnd] = *Image;
+
+ // Send the target table sizes.
+ size_t ImageSize = (size_t)ImageEnd - (size_t)ImageStart;
+ size_t EntryCount = EntriesEnd - EntriesBegin;
+ llvm::SmallVector<size_t> EntryNameSizes(EntryCount);
+
+ for (size_t I = 0; I < EntryCount; I++) {
+ // Note: +1 for the terminator.
+ EntryNameSizes[I] = std::strlen(EntriesBegin[I].name) + 1;
+ }
+
+ RequestManager.send(&ImageSize, 1, MPI_UINT64_T);
+ RequestManager.send(&EntryCount, 1, MPI_UINT64_T);
+ RequestManager.send(EntryNameSizes.begin(), EntryCount, MPI_UINT64_T);
+
+ // Send the image bytes and the table entries.
+ RequestManager.send(ImageStart, ImageSize, MPI_BYTE);
+
+ for (size_t I = 0; I < EntryCount; I++) {
+ RequestManager.send(&EntriesBegin[I].addr, 1, MPI_UINT64_T);
+ RequestManager.send(EntriesBegin[I].name, EntryNameSizes[I], MPI_CHAR);
+ RequestManager.send(&EntriesBegin[I].size, 1, MPI_UINT64_T);
+ RequestManager.send(&EntriesBegin[I].flags, 1, MPI_INT32_T);
+ RequestManager.send(&EntriesBegin[I].data, 1, MPI_INT32_T);
+ }
+
+ for (size_t I = 0; I < EntryCount; I++) {
+ RequestManager.receive(&((*DeviceImageAddrs)[I]), 1, MPI_UINT64_T);
+ }
+
+ co_return (co_await RequestManager);
+}
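+
+// Wire format of loadBinary, in send order: ImageSize, EntryCount,
+// EntryNameSizes[EntryCount], the raw image bytes, then the addr, name, size,
+// flags, and data fields of each entry; the destination replies with one
+// device entry address per entry (see DestinationEvents::loadBinary).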
+
+EventTy exit(MPIRequestManagerTy RequestManager) {
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
+}
+
+} // namespace OriginEvents
+
+namespace DestinationEvents {
+
+EventTy allocateBuffer(MPIRequestManagerTy RequestManager) {
+ int64_t Size = 0;
+ RequestManager.receive(&Size, 1, MPI_INT64_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ void *Buffer = malloc(Size);
+ RequestManager.send(&Buffer, sizeof(void *), MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy deleteBuffer(MPIRequestManagerTy RequestManager) {
+ void *Buffer = nullptr;
+ RequestManager.receive(&Buffer, sizeof(void *), MPI_BYTE);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ free(Buffer);
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy submit(MPIRequestManagerTy RequestManager) {
+ void *Buffer = nullptr;
+ int64_t Size = 0;
+ RequestManager.receive(&Buffer, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&Size, 1, MPI_INT64_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ RequestManager.receiveInBatchs(Buffer, Size);
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy retrieve(MPIRequestManagerTy RequestManager) {
+ void *Buffer = nullptr;
+ int64_t Size = 0;
+ RequestManager.receive(&Buffer, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&Size, 1, MPI_INT64_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ RequestManager.sendInBatchs(Buffer, Size);
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy exchangeSrc(MPIRequestManagerTy RequestManager) {
+ void *SrcBuffer;
+ int64_t Size;
+ int DstDevice;
+ // Save head node rank
+ int HeadNodeRank = RequestManager.OtherRank;
+
+ RequestManager.receive(&SrcBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&Size, 1, MPI_INT64_T);
+ RequestManager.receive(&DstDevice, 1, MPI_INT);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ // Set the Destination Rank in RequestManager
+ RequestManager.OtherRank = DstDevice;
+
+ // Send buffer to target device
+ RequestManager.sendInBatchs(SrcBuffer, Size);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ // Set the head node rank to send the final notification.
+ RequestManager.OtherRank = HeadNodeRank;
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy exchangeDst(MPIRequestManagerTy RequestManager) {
+ void *DstBuffer;
+ int64_t Size;
+ int SrcDevice;
+ // Save head node rank
+ int HeadNodeRank = RequestManager.OtherRank;
+
+ RequestManager.receive(&DstBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.receive(&Size, 1, MPI_INT64_T);
+ RequestManager.receive(&SrcDevice, 1, MPI_INT);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ // Set the Source Rank in RequestManager
+ RequestManager.OtherRank = SrcDevice;
+
+ // Receive buffer from the Source device
+ RequestManager.receiveInBatchs(DstBuffer, Size);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ // Set the head node rank to send the final notification.
+ RequestManager.OtherRank = HeadNodeRank;
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy execute(MPIRequestManagerTy RequestManager,
+ __tgt_device_image &DeviceImage) {
+
+ uint32_t NumArgs = 0;
+ RequestManager.receive(&NumArgs, 1, MPI_UINT32_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ llvm::SmallVector<void *> Args(NumArgs);
+ llvm::SmallVector<void *> ArgPtrs(NumArgs);
+
+ RequestManager.receive(Args.data(), NumArgs * sizeof(uintptr_t), MPI_BYTE);
+ void (*TargetFunc)(void) = nullptr;
+ RequestManager.receive(&TargetFunc, sizeof(uintptr_t), MPI_BYTE);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ // Collect the address of each argument for the libffi call.
+ for (unsigned I = 0; I < NumArgs; I++) {
+ ArgPtrs[I] = &Args[I];
+ }
+
+ ffi_cif Cif{};
+ llvm::SmallVector<ffi_type *> ArgsTypes(NumArgs, &ffi_type_pointer);
+ ffi_status Status = ffi_prep_cif(&Cif, FFI_DEFAULT_ABI, NumArgs,
+ &ffi_type_void, &ArgsTypes[0]);
+
+ if (Status != FFI_OK) {
+ co_return createError("Error in ffi_prep_cif: %d", Status);
+ }
+
+ long Return;
+ ffi_call(&Cif, TargetFunc, &Return, &ArgPtrs[0]);
+
+ // Event completion notification
+ RequestManager.send(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
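+
+// Note: the libffi call above assumes the target region has a void return and
+// takes NumArgs pointer arguments, matching the pointer array packed by
+// OriginEvents::execute.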
+
+EventTy loadBinary(MPIRequestManagerTy RequestManager, DeviceImage &Image) {
+ // Receive the target table sizes.
+ size_t ImageSize = 0;
+ size_t EntryCount = 0;
+ RequestManager.receive(&ImageSize, 1, MPI_UINT64_T);
+ RequestManager.receive(&EntryCount, 1, MPI_UINT64_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ llvm::SmallVector<size_t> EntryNameSizes(EntryCount);
+
+ RequestManager.receive(EntryNameSizes.begin(), EntryCount, MPI_UINT64_T);
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ // Create the device image with the appropriate sizes and receive its content.
+ Image = DeviceImage(ImageSize, EntryCount);
+ Image.setImageEntries(EntryNameSizes);
+
+ // Receive the image bytes and the table entries.
+ RequestManager.receive(Image.ImageStart, ImageSize, MPI_BYTE);
+
+ for (size_t I = 0; I < EntryCount; I++) {
+ RequestManager.receive(&Image.Entries[I].addr, 1, MPI_UINT64_T);
+ RequestManager.receive(Image.Entries[I].name, EntryNameSizes[I], MPI_CHAR);
+ RequestManager.receive(&Image.Entries[I].size, 1, MPI_UINT64_T);
+ RequestManager.receive(&Image.Entries[I].flags, 1, MPI_INT32_T);
+ RequestManager.receive(&Image.Entries[I].data, 1, MPI_INT32_T);
+ }
+
+ if (auto Error = co_await RequestManager; Error)
+ co_return Error;
+
+ // The code below is for the CPU plugin only.
+ // Create a temporary file.
+ char TmpFileName[] = "/tmp/tmpfile_XXXXXX";
----------------
cl3to wrote:
They are only removed on system restart if configured as such. I'll add file deletion at the end of the dynlib loading process.
https://github.com/llvm/llvm-project/pull/90890
More information about the llvm-commits mailing list