[llvm] [Offload] Add MPI Plugin (PR #90890)

Hervé Yviquel via llvm-commits llvm-commits at lists.llvm.org
Thu May 2 13:43:30 PDT 2024


================
@@ -0,0 +1,1049 @@
+//===------ event_system.cpp - Concurrent MPI communication -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the implementation of the MPI Event System used by the MPI
+// target runtime for concurrent communication.
+//
+//===----------------------------------------------------------------------===//
+
+#include "EventSystem.h"
+
+#include <algorithm>
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <functional>
+#include <memory>
+
+#include <ffi.h>
+#include <mpi.h>
+
+#include "Shared/Debug.h"
+#include "Shared/EnvironmentVar.h"
+#include "Shared/Utils.h"
+#include "omptarget.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Error.h"
+
+#include "llvm/Support/DynamicLibrary.h"
+
+using llvm::sys::DynamicLibrary;
+
+// CHECK evaluates `expr` in the calling function; on failure it reports the
+// printf-style `msg` and makes the caller return false.
+#define CHECK(expr, msg, ...)                                                  \
+  if (!(expr)) {                                                               \
+    REPORT(msg, ##__VA_ARGS__);                                                \
+    return false;                                                              \
+  }
+
+// Customizable parameters of the event system
+// =============================================================================
+// Number of execute event handlers to spawn.
+static IntEnvar NumExecEventHandlers("OMPTARGET_NUM_EXEC_EVENT_HANDLERS", 1);
+// Number of data event handlers to spawn.
+static IntEnvar NumDataEventHandlers("OMPTARGET_NUM_DATA_EVENT_HANDLERS", 1);
+// Polling rate period (us) used by event handlers between event resumptions.
+static IntEnvar EventPollingRate("OMPTARGET_EVENT_POLLING_RATE", 1);
+// Number of communicators to be spawned and distributed for the events.
+// Allows for parallel use of network resources.
+static Int64Envar NumMPIComms("OMPTARGET_NUM_MPI_COMMS", 10);
+// Maximum fragment size in bytes (default 100 MB) used when a single data
+// transfer is split into multiple MPI messages (see send/receiveInBatchs).
+static Int64Envar MPIFragmentSize("OMPTARGET_MPI_FRAGMENT_SIZE", 100e6);
+
+// Helper functions
+// =============================================================================
+/// Returns a human-readable name for the given event type, used in debug and
+/// error messages. Every enumerator of EventTypeTy must be covered by the
+/// switch below.
+const char *toString(EventTypeTy Type) {
+  using enum EventTypeTy;
+
+  switch (Type) {
+  case ALLOC:
+    return "Alloc";
+  case DELETE:
+    return "Delete";
+  case RETRIEVE:
+    return "Retrieve";
+  case SUBMIT:
+    return "Submit";
+  case EXCHANGE:
+    return "Exchange";
+  case EXCHANGE_SRC:
+    // Fixed capitalization: was "exchangeSrc", inconsistent with the
+    // PascalCase names of every other event type.
+    return "ExchangeSrc";
+  case EXCHANGE_DST:
+    return "ExchangeDst";
+  case EXECUTE:
+    return "Execute";
+  case SYNC:
+    return "Sync";
+  case LOAD_BINARY:
+    return "LoadBinary";
+  case EXIT:
+    return "Exit";
+  }
+
+  assert(false && "Every enum value must be checked on the switch above.");
+  return nullptr;
+}
+
+// Coroutine events implementation
+// =============================================================================
+// Advances the event by resuming the first not-yet-done coroutine handle in
+// the chain hanging off the root handle's promise. The search walks the
+// PrevHandle links and stops if it wraps back around to the root itself
+// (i.e. every handle in the chain is already done).
+void EventTy::resume() {
+  // Acquire first handle not done.
+  const CoHandleTy &RootHandle = getHandle().promise().RootHandle;
+  auto &ResumableHandle = RootHandle.promise().PrevHandle;
+  while (ResumableHandle.done()) {
+    ResumableHandle = ResumableHandle.promise().PrevHandle;
+
+    // Back at the root: nothing left to resume in this chain.
+    if (ResumableHandle == RootHandle)
+      break;
+  }
+
+  if (!ResumableHandle.done())
+    ResumableHandle.resume();
+}
+
+// Blocks until the event coroutine has run to completion, yielding the CPU
+// between progress attempts so handlers do not busy-wait.
+void EventTy::wait() {
+  for (; !done(); std::this_thread::sleep_for(
+           std::chrono::microseconds(EventPollingRate.get())))
+    resume();
+}
+
+// An event is done once its coroutine handle has finished executing.
+bool EventTy::done() const { return getHandle().done(); }
+
+// An event is empty when it holds no coroutine handle at all.
+bool EventTy::empty() const { return !getHandle(); }
+
+// Returns the error recorded by the coroutine, or success if none was set.
+// Note: the stored error is moved out of the promise, so a given error can
+// only be retrieved once.
+llvm::Error EventTy::getError() const {
+  auto &Error = getHandle().promise().CoroutineError;
+  if (Error)
+    return std::move(*Error);
+
+  return llvm::Error::success();
+}
+
+// Helpers
+// =============================================================================
+// The manager must not be destroyed while MPI requests are still pending;
+// co_awaiting the manager (see wait()) completes and clears them.
+MPIRequestManagerTy::~MPIRequestManagerTy() {
+  assert(Requests.empty() && "Requests must be fulfilled and emptied before "
+                             "destruction. Did you co_await on it?");
+}
+
+// Queues a non-blocking send of Size elements of Datatype to OtherRank using
+// this manager's tag and communicator. The MPI request is appended to
+// Requests and completed later by wait().
+void MPIRequestManagerTy::send(const void *Buffer, int Size,
+                               MPI_Datatype Datatype) {
+  MPI_Isend(Buffer, Size, Datatype, OtherRank, Tag, Comm,
+            &Requests.emplace_back(MPI_REQUEST_NULL));
+}
+
+// Sends Buffer as a sequence of non-blocking sends, each covering at most
+// MPIFragmentSize bytes of the original buffer.
+void MPIRequestManagerTy::sendInBatchs(void *Buffer, int Size) {
+  char *Bytes = reinterpret_cast<char *>(Buffer);
+  const int64_t FragmentSize = MPIFragmentSize.get();
+  int64_t Remaining = Size;
+  while (Remaining > 0) {
+    // Offset of the current fragment within the original buffer.
+    const int ChunkSize = static_cast<int>(std::min(Remaining, FragmentSize));
+    send(&Bytes[Size - Remaining], ChunkSize, MPI_BYTE);
+    Remaining -= FragmentSize;
+  }
+}
+
+// Queues a non-blocking receive of Size elements of Datatype from OtherRank
+// using this manager's tag and communicator. The MPI request is appended to
+// Requests and completed later by wait().
+void MPIRequestManagerTy::receive(void *Buffer, int Size,
+                                  MPI_Datatype Datatype) {
+  MPI_Irecv(Buffer, Size, Datatype, OtherRank, Tag, Comm,
+            &Requests.emplace_back(MPI_REQUEST_NULL));
+}
+
+// Receives Buffer as a sequence of non-blocking receives, each covering at
+// most MPIFragmentSize bytes of the original buffer. Mirrors sendInBatchs.
+void MPIRequestManagerTy::receiveInBatchs(void *Buffer, int Size) {
+  char *Bytes = reinterpret_cast<char *>(Buffer);
+  const int64_t FragmentSize = MPIFragmentSize.get();
+  int64_t Remaining = Size;
+  while (Remaining > 0) {
+    // Offset of the current fragment within the original buffer.
+    const int ChunkSize = static_cast<int>(std::min(Remaining, FragmentSize));
+    receive(&Bytes[Size - Remaining], ChunkSize, MPI_BYTE);
+    Remaining -= FragmentSize;
+  }
+}
+
+// Cooperatively waits for all pending MPI requests of this manager. The
+// coroutine suspends between MPI_Testall polls so the event handler thread
+// can make progress on other events. On success the request list is cleared,
+// satisfying the destructor's precondition.
+EventTy MPIRequestManagerTy::wait() {
+  // MPI_Testall expects an int flag, so use an int (not bool) accumulator.
+  int RequestsCompleted = 0;
+
+  while (!RequestsCompleted) {
+    int MPIError = MPI_Testall(Requests.size(), Requests.data(),
+                               &RequestsCompleted, MPI_STATUSES_IGNORE);
+
+    if (MPIError != MPI_SUCCESS)
+      co_return createError("Waiting for MPI requests failed with code %d",
+                            MPIError);
+
+    co_await std::suspend_always{};
+  }
+
+  Requests.clear();
+
+  co_return llvm::Error::success();
+}
+
+// Allows `co_await RequestManager` as shorthand for awaiting completion of
+// all of the manager's pending MPI requests.
+EventTy operator co_await(MPIRequestManagerTy &RequestManager) {
+  return RequestManager.wait();
+}
+
+// Device Image Storage
+// =============================================================================
+
+/// Owning storage for a device image, including the raw image bytes, its
+/// offload entries, and a flattened buffer holding all entry names.
+struct DeviceImage : __tgt_device_image {
+  llvm::SmallVector<unsigned char, 1> ImageBuffer;
+  llvm::SmallVector<__tgt_offload_entry, 16> Entries;
+  llvm::SmallVector<char> FlattenedEntryNames;
+
+  DeviceImage() {
+    ImageStart = nullptr;
+    ImageEnd = nullptr;
+    EntriesBegin = nullptr;
+    EntriesEnd = nullptr;
+  }
+
+  DeviceImage(size_t ImageSize, size_t EntryCount)
+      : ImageBuffer(ImageSize + alignof(void *)), Entries(EntryCount) {
+    // Align the image start to alignof(void *). The buffer is over-allocated
+    // by alignof(void *) bytes, so the available space passed to std::align
+    // must include that slack: passing only ImageSize would make std::align
+    // fail (return nullptr) whenever the buffer start is not already aligned.
+    ImageStart = ImageBuffer.begin();
+    size_t Space = ImageBuffer.size();
+    ImageStart = std::align(alignof(void *), ImageSize, ImageStart, Space);
+    assert(ImageStart && "Failed to align the device image buffer.");
+    ImageEnd = (void *)((uintptr_t)ImageStart + ImageSize);
+  }
+
+  /// Points each entry's name into the flattened name buffer and publishes
+  /// the entries through the base struct. EntryNameSizes[I] is the number of
+  /// bytes reserved for the name of entry I.
+  void setImageEntries(const llvm::SmallVector<size_t> &EntryNameSizes) {
+    size_t EntryCount = Entries.size();
+    size_t TotalNameSize = 0;
+    for (size_t I = 0; I < EntryCount; I++) {
+      TotalNameSize += EntryNameSizes[I];
+    }
+    FlattenedEntryNames.resize(TotalNameSize);
+
+    // Walk backwards so TotalNameSize becomes the offset of each name.
+    for (size_t I = EntryCount; I > 0; I--) {
+      TotalNameSize -= EntryNameSizes[I - 1];
+      Entries[I - 1].name = &FlattenedEntryNames[TotalNameSize];
+    }
+
+    // Set the entries pointers.
+    EntriesBegin = Entries.begin();
+    EntriesEnd = Entries.end();
+  }
+
+  /// Get the image size.
+  size_t getSize() const {
+    return llvm::omp::target::getPtrDiff(ImageEnd, ImageStart);
+  }
+
+  /// Getter and setter for the dynamic library.
+  DynamicLibrary &getDynamicLibrary() { return DynLib; }
+  void setDynamicLibrary(const DynamicLibrary &Lib) { DynLib = Lib; }
+
+private:
+  DynamicLibrary DynLib;
+};
+
+// Event Implementations
+// =============================================================================
+
+namespace OriginEvents {
+
+// Origin-side allocation event: sends the requested size to the remote rank
+// and receives back the address of the remotely allocated buffer into
+// *Buffer. Message order here is the wire protocol with the remote side.
+EventTy allocateBuffer(MPIRequestManagerTy RequestManager, int64_t Size,
+                       void **Buffer) {
+  RequestManager.send(&Size, 1, MPI_INT64_T);
+
+  RequestManager.receive(Buffer, sizeof(void *), MPI_BYTE);
+
+  co_return (co_await RequestManager);
+}
+
+// Origin-side deletion event: sends the remote buffer address to be freed
+// and waits for an empty message signaling completion.
+EventTy deleteBuffer(MPIRequestManagerTy RequestManager, void *Buffer) {
+  RequestManager.send(&Buffer, sizeof(void *), MPI_BYTE);
+
+  // Event completion notification
+  RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+  co_return (co_await RequestManager);
+}
+
+// Origin-side submit event: sends the remote destination address and the
+// size, then the payload itself in MPIFragmentSize-bounded batches, and
+// finally waits for an empty completion message. DataHandle holds the source
+// data (its get() pointer is what gets sent) — presumably keeping it alive
+// across the asynchronous sends; confirm against EventDataHandleTy.
+EventTy submit(MPIRequestManagerTy RequestManager, EventDataHandleTy DataHandle,
+               void *DstBuffer, int64_t Size) {
+  RequestManager.send(&DstBuffer, sizeof(void *), MPI_BYTE);
+  RequestManager.send(&Size, 1, MPI_INT64_T);
+
+  RequestManager.sendInBatchs(DataHandle.get(), Size);
+
+  // Event completion notification
+  RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+  co_return (co_await RequestManager);
+}
+
+// Origin-side retrieve event: sends the remote source address and the size,
+// then receives the payload into OrgBuffer in MPIFragmentSize-bounded
+// batches, followed by an empty completion message.
+EventTy retrieve(MPIRequestManagerTy RequestManager, void *OrgBuffer,
+                 const void *DstBuffer, int64_t Size) {
+  RequestManager.send(&DstBuffer, sizeof(void *), MPI_BYTE);
+  RequestManager.send(&Size, 1, MPI_INT64_T);
+  RequestManager.receiveInBatchs(OrgBuffer, Size);
+
+  // Event completion notification
+  RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+  co_return (co_await RequestManager);
+}
+
+// Origin-side device-to-device exchange event: instructs SrcDevice to send
+// Size bytes from OrgBuffer to DstBuffer on DstDevice, then waits for a
+// completion notification from each rank. RequestManager.OtherRank is
+// retargeted mid-function (it presumably starts at SrcDevice — confirm
+// against the caller), so the send/receive order below is significant.
+EventTy exchange(MPIRequestManagerTy RequestManager, int SrcDevice,
+                 const void *OrgBuffer, int DstDevice, void *DstBuffer,
+                 int64_t Size) {
+  // Send data to SrcDevice
+  RequestManager.send(&OrgBuffer, sizeof(void *), MPI_BYTE);
+  RequestManager.send(&Size, 1, MPI_INT64_T);
+  RequestManager.send(&DstDevice, 1, MPI_INT);
+
+  // Send data to DstDevice
+  RequestManager.OtherRank = DstDevice;
+  RequestManager.send(&DstBuffer, sizeof(void *), MPI_BYTE);
+  RequestManager.send(&Size, 1, MPI_INT64_T);
+  RequestManager.send(&SrcDevice, 1, MPI_INT);
+
+  // Event completion notification from both ranks.
+  RequestManager.receive(nullptr, 0, MPI_BYTE);
+  RequestManager.OtherRank = SrcDevice;
+  RequestManager.receive(nullptr, 0, MPI_BYTE);
+
+  co_return (co_await RequestManager);
+}
+
+EventTy execute(MPIRequestManagerTy RequestManager, EventDataHandleTy Args,
+                uint32_t NumArgs, void *Func) {
+  RequestManager.send(&NumArgs, 1, MPI_UINT32_T);
+  RequestManager.send(Args.get(), NumArgs * sizeof(void *), MPI_BYTE);
+  RequestManager.send(&Func, sizeof(void *), MPI_BYTE);
+
+  // Event completion notification
+  RequestManager.receive(nullptr, 0, MPI_BYTE);
+  co_return (co_await RequestManager);
----------------
hyviquel wrote:

No, it won't work; our idea would be to enable C++20 for the plugin (or for offload in general), since using coroutines and asynchronous MPI communications really improves the implementation of this plugin.

https://github.com/llvm/llvm-project/pull/90890


More information about the llvm-commits mailing list