[llvm] [Offload] Add MPI Plugin (PR #90890)
Hervé Yviquel via llvm-commits
llvm-commits at lists.llvm.org
Thu May 2 13:43:30 PDT 2024
================
@@ -0,0 +1,1049 @@
+//===------ event_system.cpp - Concurrent MPI communication -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the implementation of the MPI Event System used by the MPI
+// target runtime for concurrent communication.
+//
+//===----------------------------------------------------------------------===//
+
+#include "EventSystem.h"
+
+#include <algorithm>
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <functional>
+#include <memory>
+
+#include <ffi.h>
+#include <mpi.h>
+
+#include "Shared/Debug.h"
+#include "Shared/EnvironmentVar.h"
+#include "Shared/Utils.h"
+#include "omptarget.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Error.h"
+
+#include "llvm/Support/DynamicLibrary.h"
+
+using llvm::sys::DynamicLibrary;
+
+#define CHECK(expr, msg, ...) \
+ if (!(expr)) { \
+ REPORT(msg, ##__VA_ARGS__); \
+ return false; \
+ }
+
+// Customizable parameters of the event system
+// =============================================================================
+// Number of execute event handlers to spawn.
+static IntEnvar NumExecEventHandlers("OMPTARGET_NUM_EXEC_EVENT_HANDLERS", 1);
+// Number of data event handlers to spawn.
+static IntEnvar NumDataEventHandlers("OMPTARGET_NUM_DATA_EVENT_HANDLERS", 1);
+// Polling interval (in microseconds) used by the event handlers.
+static IntEnvar EventPollingRate("OMPTARGET_EVENT_POLLING_RATE", 1);
+// Number of communicators to be spawned and distributed for the events.
+// Allows for parallel use of network resources.
+static Int64Envar NumMPIComms("OMPTARGET_NUM_MPI_COMMS", 10);
+// Maximum buffer size to use during data transfer.
+static Int64Envar MPIFragmentSize("OMPTARGET_MPI_FRAGMENT_SIZE", 100e6);
+
+// Helper functions
+// =============================================================================
+const char *toString(EventTypeTy Type) {
+ using enum EventTypeTy;
+
+ switch (Type) {
+ case ALLOC:
+ return "Alloc";
+ case DELETE:
+ return "Delete";
+ case RETRIEVE:
+ return "Retrieve";
+ case SUBMIT:
+ return "Submit";
+ case EXCHANGE:
+ return "Exchange";
+ case EXCHANGE_SRC:
+ return "exchangeSrc";
+ case EXCHANGE_DST:
+ return "ExchangeDst";
+ case EXECUTE:
+ return "Execute";
+ case SYNC:
+ return "Sync";
+ case LOAD_BINARY:
+ return "LoadBinary";
+ case EXIT:
+ return "Exit";
+ }
+
+ assert(false && "Every enum value must be checked on the switch above.");
+ return nullptr;
+}
+
+// Coroutine events implementation
+// =============================================================================
+void EventTy::resume() {
+  // Acquire the first handle in the chain that is not yet done.
+ const CoHandleTy &RootHandle = getHandle().promise().RootHandle;
+ auto &ResumableHandle = RootHandle.promise().PrevHandle;
+ while (ResumableHandle.done()) {
+ ResumableHandle = ResumableHandle.promise().PrevHandle;
+
+ if (ResumableHandle == RootHandle)
+ break;
+ }
+
+ if (!ResumableHandle.done())
+ ResumableHandle.resume();
+}
+
+void EventTy::wait() {
+ // Advance the event progress until it is completed.
+ while (!done()) {
+ resume();
+
+ std::this_thread::sleep_for(
+ std::chrono::microseconds(EventPollingRate.get()));
+ }
+}
+
+bool EventTy::done() const { return getHandle().done(); }
+
+bool EventTy::empty() const { return !getHandle(); }
+
+llvm::Error EventTy::getError() const {
+ auto &Error = getHandle().promise().CoroutineError;
+ if (Error)
+ return std::move(*Error);
+
+ return llvm::Error::success();
+}
+
+// Helpers
+// =============================================================================
+MPIRequestManagerTy::~MPIRequestManagerTy() {
+ assert(Requests.empty() && "Requests must be fulfilled and emptied before "
+ "destruction. Did you co_await on it?");
+}
+
+void MPIRequestManagerTy::send(const void *Buffer, int Size,
+ MPI_Datatype Datatype) {
+ MPI_Isend(Buffer, Size, Datatype, OtherRank, Tag, Comm,
+ &Requests.emplace_back(MPI_REQUEST_NULL));
+}
+
+void MPIRequestManagerTy::sendInBatchs(void *Buffer, int Size) {
+  // Send the buffer in fragments of at most MPIFragmentSize bytes.
+ char *BufferByteArray = reinterpret_cast<char *>(Buffer);
+ int64_t RemainingBytes = Size;
+ while (RemainingBytes > 0) {
+ send(&BufferByteArray[Size - RemainingBytes],
+ static_cast<int>(std::min(RemainingBytes, MPIFragmentSize.get())),
+ MPI_BYTE);
+ RemainingBytes -= MPIFragmentSize.get();
+ }
+}
+
+void MPIRequestManagerTy::receive(void *Buffer, int Size,
+ MPI_Datatype Datatype) {
+ MPI_Irecv(Buffer, Size, Datatype, OtherRank, Tag, Comm,
+ &Requests.emplace_back(MPI_REQUEST_NULL));
+}
+
+void MPIRequestManagerTy::receiveInBatchs(void *Buffer, int Size) {
+  // Receive the buffer in fragments of at most MPIFragmentSize bytes.
+ char *BufferByteArray = reinterpret_cast<char *>(Buffer);
+ int64_t RemainingBytes = Size;
+ while (RemainingBytes > 0) {
+ receive(&BufferByteArray[Size - RemainingBytes],
+ static_cast<int>(std::min(RemainingBytes, MPIFragmentSize.get())),
+ MPI_BYTE);
+ RemainingBytes -= MPIFragmentSize.get();
+ }
+}
+
+EventTy MPIRequestManagerTy::wait() {
+  int RequestsCompleted = 0;
+
+ while (!RequestsCompleted) {
+ int MPIError = MPI_Testall(Requests.size(), Requests.data(),
+ &RequestsCompleted, MPI_STATUSES_IGNORE);
+
+ if (MPIError != MPI_SUCCESS)
+ co_return createError("Waiting of MPI requests failed with code %d",
+ MPIError);
+
+ co_await std::suspend_always{};
+ }
+
+ Requests.clear();
+
+ co_return llvm::Error::success();
+}
+
+EventTy operator co_await(MPIRequestManagerTy &RequestManager) {
+ return RequestManager.wait();
+}
+
+// Device Image Storage
+// =============================================================================
+
+struct DeviceImage : __tgt_device_image {
+ llvm::SmallVector<unsigned char, 1> ImageBuffer;
+ llvm::SmallVector<__tgt_offload_entry, 16> Entries;
+ llvm::SmallVector<char> FlattenedEntryNames;
+
+ DeviceImage() {
+ ImageStart = nullptr;
+ ImageEnd = nullptr;
+ EntriesBegin = nullptr;
+ EntriesEnd = nullptr;
+ }
+
+ DeviceImage(size_t ImageSize, size_t EntryCount)
+ : ImageBuffer(ImageSize + alignof(void *)), Entries(EntryCount) {
+    // Align the image start to alignof(void *). The buffer is over-allocated
+    // by alignof(void *) bytes, so the aligned image always fits.
+    ImageStart = ImageBuffer.begin();
+    size_t Space = ImageBuffer.size();
+    std::align(alignof(void *), ImageSize, ImageStart, Space);
+    ImageEnd = (void *)((size_t)ImageStart + ImageSize);
+ }
+
+ void setImageEntries(llvm::SmallVector<size_t> EntryNameSizes) {
+ // Adjust the entry names to use the flattened name buffer.
+ size_t EntryCount = Entries.size();
+ size_t TotalNameSize = 0;
+ for (size_t I = 0; I < EntryCount; I++) {
+ TotalNameSize += EntryNameSizes[I];
+ }
+ FlattenedEntryNames.resize(TotalNameSize);
+
+ for (size_t I = EntryCount; I > 0; I--) {
+ TotalNameSize -= EntryNameSizes[I - 1];
+ Entries[I - 1].name = &FlattenedEntryNames[TotalNameSize];
+ }
+
+ // Set the entries pointers.
+ EntriesBegin = Entries.begin();
+ EntriesEnd = Entries.end();
+ }
+
+ /// Get the image size.
+ size_t getSize() const {
+ return llvm::omp::target::getPtrDiff(ImageEnd, ImageStart);
+ }
+
+ /// Getter and setter for the dynamic library.
+ DynamicLibrary &getDynamicLibrary() { return DynLib; }
+ void setDynamicLibrary(const DynamicLibrary &Lib) { DynLib = Lib; }
+
+private:
+ DynamicLibrary DynLib;
+};
+
+// Event Implementations
+// =============================================================================
+
+namespace OriginEvents {
+
+EventTy allocateBuffer(MPIRequestManagerTy RequestManager, int64_t Size,
+ void **Buffer) {
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+
+ RequestManager.receive(Buffer, sizeof(void *), MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy deleteBuffer(MPIRequestManagerTy RequestManager, void *Buffer) {
+ RequestManager.send(&Buffer, sizeof(void *), MPI_BYTE);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy submit(MPIRequestManagerTy RequestManager, EventDataHandleTy DataHandle,
+ void *DstBuffer, int64_t Size) {
+ RequestManager.send(&DstBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+
+ RequestManager.sendInBatchs(DataHandle.get(), Size);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy retrieve(MPIRequestManagerTy RequestManager, void *OrgBuffer,
+ const void *DstBuffer, int64_t Size) {
+ RequestManager.send(&DstBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+ RequestManager.receiveInBatchs(OrgBuffer, Size);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy exchange(MPIRequestManagerTy RequestManager, int SrcDevice,
+ const void *OrgBuffer, int DstDevice, void *DstBuffer,
+ int64_t Size) {
+ // Send data to SrcDevice
+ RequestManager.send(&OrgBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+ RequestManager.send(&DstDevice, 1, MPI_INT);
+
+ // Send data to DstDevice
+ RequestManager.OtherRank = DstDevice;
+ RequestManager.send(&DstBuffer, sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Size, 1, MPI_INT64_T);
+ RequestManager.send(&SrcDevice, 1, MPI_INT);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+ RequestManager.OtherRank = SrcDevice;
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+ co_return (co_await RequestManager);
+}
+
+EventTy execute(MPIRequestManagerTy RequestManager, EventDataHandleTy Args,
+ uint32_t NumArgs, void *Func) {
+ RequestManager.send(&NumArgs, 1, MPI_UINT32_T);
+ RequestManager.send(Args.get(), NumArgs * sizeof(void *), MPI_BYTE);
+ RequestManager.send(&Func, sizeof(void *), MPI_BYTE);
+
+ // Event completion notification
+ RequestManager.receive(nullptr, 0, MPI_BYTE);
+ co_return (co_await RequestManager);
----------------
hyviquel wrote:
No, it won't work. Our idea would be to enable C++20 for the plugin (or offload), since using coroutines and asynchronous MPI communication really improves the implementation of this plugin.
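To make that concrete, here is a minimal sketch of the pattern the plugin relies on (not the PR's actual EventTy/promise machinery; SimpleEvent and waitFor are hypothetical names used only for illustration): a C++20 coroutine suspends after each MPI_Test, so a single handler thread can poll many in-flight requests instead of blocking in MPI_Wait.

#include <coroutine>
#include <mpi.h>

// Hypothetical coroutine task, for illustration only.
struct SimpleEvent {
  struct promise_type {
    SimpleEvent get_return_object() {
      return SimpleEvent{
          std::coroutine_handle<promise_type>::from_promise(*this)};
    }
    std::suspend_always initial_suspend() noexcept { return {}; }
    std::suspend_always final_suspend() noexcept { return {}; }
    void return_void() {}
    void unhandled_exception() {}
  };

  std::coroutine_handle<promise_type> Handle;

  bool done() const { return Handle.done(); }
  void resume() const {
    if (!Handle.done())
      Handle.resume();
  }
  ~SimpleEvent() {
    if (Handle)
      Handle.destroy();
  }
};

// Poll a non-blocking MPI request, yielding control between tests so the
// caller can drive other events while this one is still in flight.
SimpleEvent waitFor(MPI_Request *Request) {
  int Completed = 0;
  while (!Completed) {
    MPI_Test(Request, &Completed, MPI_STATUS_IGNORE);
    if (!Completed)
      co_await std::suspend_always{};
  }
}

// Usage sketch: auto Ev = waitFor(&Req); while (!Ev.done()) Ev.resume();

None of this compiles without -std=c++20, which is exactly what enabling C++20 for the plugin (or offload) would provide.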
https://github.com/llvm/llvm-project/pull/90890