[llvm] [Offload] Add MPI Plugin (PR #90890)

Jhonatan Cléto via llvm-commits llvm-commits at lists.llvm.org
Fri May 3 10:55:02 PDT 2024


================
@@ -0,0 +1,1049 @@
+//===------ event_system.cpp - Concurrent MPI communication -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the implementation of the MPI Event System used by the MPI
+// target runtime for concurrent communication.
+//
+//===----------------------------------------------------------------------===//
+
+#include "EventSystem.h"
+
+#include <algorithm>
+#include <chrono>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <functional>
+#include <memory>
+
+#include <ffi.h>
+#include <mpi.h>
+
+#include "Shared/Debug.h"
+#include "Shared/EnvironmentVar.h"
+#include "Shared/Utils.h"
+#include "omptarget.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Error.h"
+
+#include "llvm/Support/DynamicLibrary.h"
+
+using llvm::sys::DynamicLibrary;
+
+#define CHECK(expr, msg, ...)                                                  \
+  if (!(expr)) {                                                               \
+    REPORT(msg, ##__VA_ARGS__);                                                \
+    return false;                                                              \
+  }
+
+// Customizable parameters of the event system
+// =============================================================================
+// Number of execute event handlers to spawn.
+static IntEnvar NumExecEventHandlers("OMPTARGET_NUM_EXEC_EVENT_HANDLERS", 1);
+// Number of data event handlers to spawn.
+static IntEnvar NumDataEventHandlers("OMPTARGET_NUM_DATA_EVENT_HANDLERS", 1);
+// Polling rate period (us) used by event handlers.
+static IntEnvar EventPollingRate("OMPTARGET_EVENT_POLLING_RATE", 1);
+// Number of communicators to be spawned and distributed for the events.
+// Allows for parallel use of network resources.
+static Int64Envar NumMPIComms("OMPTARGET_NUM_MPI_COMMS", 10);
+// Maximum size (in bytes) of each fragment used during data transfer.
+static Int64Envar MPIFragmentSize("OMPTARGET_MPI_FRAGMENT_SIZE", 100e6);
+
+// Helper functions
+// =============================================================================
+/// Returns a human-readable label for the event type \p Type.
+///
+/// All labels use the same capitalized spelling so that log output is
+/// uniform across event types.
+const char *toString(EventTypeTy Type) {
+  using enum EventTypeTy;
+
+  switch (Type) {
+  case ALLOC:
+    return "Alloc";
+  case DELETE:
+    return "Delete";
+  case RETRIEVE:
+    return "Retrieve";
+  case SUBMIT:
+    return "Submit";
+  case EXCHANGE:
+    return "Exchange";
+  case EXCHANGE_SRC:
+    // Capitalized like every other label (it used to start in lower case).
+    return "ExchangeSrc";
+  case EXCHANGE_DST:
+    return "ExchangeDst";
+  case EXECUTE:
+    return "Execute";
+  case SYNC:
+    return "Sync";
+  case LOAD_BINARY:
+    return "LoadBinary";
+  case EXIT:
+    return "Exit";
+  }
+
+  // Only reachable if Type holds a value that is not a listed enumerator.
+  assert(false && "Every enum value must be checked on the switch above.");
+  return nullptr;
+}
+
+// Coroutine events implementation
+// =============================================================================
+/// Resumes the first not-yet-finished coroutine found by following the
+/// PrevHandle links starting at the root handle, if there is one.
+void EventTy::resume() {
+  const CoHandleTy &Root = getHandle().promise().RootHandle;
+
+  // This binds a reference to the link stored in the root promise, so the
+  // walk below updates that stored link while skipping finished handles.
+  auto &Candidate = Root.promise().PrevHandle;
+
+  // Skip handles that are done; stop as soon as the walk reaches the root.
+  while (Candidate.done()) {
+    Candidate = Candidate.promise().PrevHandle;
+    if (Candidate == Root)
+      break;
+  }
+
+  // Nothing left to advance.
+  if (Candidate.done())
+    return;
+
+  Candidate.resume();
+}
+
+/// Blocks the calling thread until the event is done. Each iteration resumes
+/// the event once and then sleeps for the configured polling period.
+void EventTy::wait() {
+  for (;;) {
+    if (done())
+      return;
+
+    resume();
+
+    // The polling period is re-read from the environment variable wrapper on
+    // every iteration, and is interpreted as microseconds.
+    const std::chrono::microseconds PollingPeriod(EventPollingRate.get());
+    std::this_thread::sleep_for(PollingPeriod);
+  }
+}
+
+/// Reports whether the handle returned by getHandle() has finished.
+bool EventTy::done() const {
+  return getHandle().done();
+}
+
+/// Reports whether the event holds no handle, i.e. whether the handle
+/// returned by getHandle() converts to false.
+bool EventTy::empty() const {
+  const bool HasHandle = static_cast<bool>(getHandle());
+  return !HasHandle;
+}
+
+llvm::Error EventTy::getError() const {
+  auto &Error = getHandle().promise().CoroutineError;
+  if (Error)
+    return std::move(*Error);
+
+  return llvm::Error::success();
+}
+
+// Helpers
+// =============================================================================
+MPIRequestManagerTy::~MPIRequestManagerTy() { // Only checks that no request is still pending at destruction.
+  assert(Requests.empty() && "Requests must be fulfilled and emptied before "
+                             "destruction. Did you co_await on it?");
+}
+
+/// Starts a non-blocking send (MPI_Isend) of \p Size elements of \p Datatype
+/// from \p Buffer to OtherRank, using this manager's Tag and Comm. The new
+/// request is appended to Requests. The MPI return code is not inspected.
+void MPIRequestManagerTy::send(const void *Buffer, int Size,
+                               MPI_Datatype Datatype) {
+  // Create the request slot first so that MPI fills it in place.
+  auto &Request = Requests.emplace_back(MPI_REQUEST_NULL);
+  MPI_Isend(Buffer, Size, Datatype, OtherRank, Tag, Comm, &Request);
+}
+
+/// Sends \p Size bytes starting at \p Buffer as a sequence of non-blocking
+/// sends of at most OMPTARGET_MPI_FRAGMENT_SIZE bytes each.
+///
+/// Nothing is sent when \p Size is not positive. If the configured fragment
+/// size is not positive, the whole buffer is sent as a single fragment; the
+/// loop would otherwise never make progress. The receiving side has to split
+/// the buffer the same way for the messages to match.
+void MPIRequestManagerTy::sendInBatchs(void *Buffer, int Size) {
+  char *Bytes = static_cast<char *>(Buffer);
+  const int64_t TotalBytes = Size;
+
+  // Read the setting once so that every fragment uses the same value.
+  int64_t FragmentBytes = MPIFragmentSize.get();
+  if (FragmentBytes <= 0)
+    FragmentBytes = TotalBytes;
+
+  for (int64_t Offset = 0; Offset < TotalBytes; Offset += FragmentBytes) {
+    // The last fragment may be shorter than FragmentBytes. The value always
+    // fits in an int because it never exceeds Size.
+    const int64_t ChunkBytes = std::min(TotalBytes - Offset, FragmentBytes);
+    send(Bytes + Offset, static_cast<int>(ChunkBytes), MPI_BYTE);
+  }
+}
+
+/// Starts a non-blocking receive (MPI_Irecv) of \p Size elements of
+/// \p Datatype into \p Buffer from OtherRank, using this manager's Tag and
+/// Comm. The new request is appended to Requests. The MPI return code is not
+/// inspected.
+void MPIRequestManagerTy::receive(void *Buffer, int Size,
+                                  MPI_Datatype Datatype) {
+  // Create the request slot first so that MPI fills it in place.
+  auto &Request = Requests.emplace_back(MPI_REQUEST_NULL);
+  MPI_Irecv(Buffer, Size, Datatype, OtherRank, Tag, Comm, &Request);
+}
+
+/// Receives \p Size bytes into \p Buffer as a sequence of non-blocking
+/// receives of at most OMPTARGET_MPI_FRAGMENT_SIZE bytes each.
+///
+/// Nothing is received when \p Size is not positive. If the configured
+/// fragment size is not positive, the whole buffer is received as a single
+/// fragment; the loop would otherwise never make progress. The sending side
+/// has to split the buffer the same way for the messages to match.
+void MPIRequestManagerTy::receiveInBatchs(void *Buffer, int Size) {
+  char *Bytes = static_cast<char *>(Buffer);
+  const int64_t TotalBytes = Size;
+
+  // Read the setting once so that every fragment uses the same value.
+  int64_t FragmentBytes = MPIFragmentSize.get();
+  if (FragmentBytes <= 0)
+    FragmentBytes = TotalBytes;
+
+  for (int64_t Offset = 0; Offset < TotalBytes; Offset += FragmentBytes) {
+    // The last fragment may be shorter than FragmentBytes. The value always
+    // fits in an int because it never exceeds Size.
+    const int64_t ChunkBytes = std::min(TotalBytes - Offset, FragmentBytes);
+    receive(Bytes + Offset, static_cast<int>(ChunkBytes), MPI_BYTE);
+  }
+}
+
+/// Coroutine that polls all pending requests with MPI_Testall until they are
+/// complete, suspending after every poll. Once complete, the request list is
+/// cleared and success is returned. If MPI_Testall fails, an error is
+/// returned immediately and the request list is left as it is.
+EventTy MPIRequestManagerTy::wait() {
+  int AllCompleted = 0;
+
+  do {
+    const int TestStatus = MPI_Testall(Requests.size(), Requests.data(),
+                                       &AllCompleted, MPI_STATUSES_IGNORE);
+
+    if (TestStatus != MPI_SUCCESS)
+      co_return createError("Waiting of MPI requests failed with code %d",
+                            TestStatus);
+
+    // Hand control back after each poll, including the one that reported
+    // completion.
+    co_await std::suspend_always{};
+  } while (!AllCompleted);
+
+  Requests.clear();
+
+  co_return llvm::Error::success();
+}
+
+EventTy operator co_await(MPIRequestManagerTy &RequestManager) { // Awaiting a manager awaits the event returned by its wait().
+  return RequestManager.wait();
+}
+
+// Device Image Storage
+// =============================================================================
+
+/// A __tgt_device_image that owns the storage its pointers refer to: the
+/// image bytes, the offload entry table and the entry names.
+struct DeviceImage : __tgt_device_image {
+  /// Backing storage for the image bytes (over-allocated for alignment).
+  llvm::SmallVector<unsigned char, 1> ImageBuffer;
+  /// Backing storage for the offload entry table.
+  llvm::SmallVector<__tgt_offload_entry, 16> Entries;
+  /// All entry names stored back to back; Entries[I].name points into it.
+  llvm::SmallVector<char> FlattenedEntryNames;
+
+  /// Creates an empty image with every inherited pointer set to null.
+  DeviceImage() {
+    ImageStart = nullptr;
+    ImageEnd = nullptr;
+    EntriesBegin = nullptr;
+    EntriesEnd = nullptr;
+  }
+
+  /// Allocates storage for an image of \p ImageSize bytes and a table of
+  /// \p EntryCount entries. ImageStart is aligned to alignof(void *) inside
+  /// the buffer. The entry pointers stay null until setImageEntries() runs.
+  DeviceImage(size_t ImageSize, size_t EntryCount)
+      : ImageBuffer(ImageSize + alignof(void *)), Entries(EntryCount) {
+    ImageStart = ImageBuffer.begin();
+
+    // std::align needs the number of bytes actually available, which is the
+    // padded buffer size. Passing ImageSize itself leaves no room for any
+    // adjustment, so the call would fail whenever one is required.
+    size_t AvailableBytes = ImageBuffer.size();
+    void *AlignedStart =
+        std::align(alignof(void *), ImageSize, ImageStart, AvailableBytes);
+    assert(AlignedStart && "Padded image buffer must fit the aligned image.");
+    (void)AlignedStart;
+
+    ImageEnd = static_cast<unsigned char *>(ImageStart) + ImageSize;
+    EntriesBegin = nullptr;
+    EntriesEnd = nullptr;
+  }
+
+  /// Points every entry name at its slot in the flattened name buffer and
+  /// publishes the entry table through EntriesBegin / EntriesEnd.
+  ///
+  /// \p EntryNameSizes holds one size per entry in Entries; entry I gets the
+  /// slot that starts after the names of entries 0..I-1.
+  void setImageEntries(const llvm::SmallVector<size_t> &EntryNameSizes) {
+    const size_t EntryCount = Entries.size();
+
+    // Size the name buffer before taking addresses into it, since resizing
+    // may move the storage.
+    size_t TotalNameSize = 0;
+    for (size_t I = 0; I < EntryCount; I++)
+      TotalNameSize += EntryNameSizes[I];
+    FlattenedEntryNames.resize(TotalNameSize);
+
+    size_t NameOffset = 0;
+    for (size_t I = 0; I < EntryCount; I++) {
+      Entries[I].name = FlattenedEntryNames.data() + NameOffset;
+      NameOffset += EntryNameSizes[I];
+    }
+
+    // Set the entries pointers.
+    EntriesBegin = Entries.begin();
+    EntriesEnd = Entries.end();
+  }
+
+  /// Get the image size.
+  size_t getSize() const {
+    return llvm::omp::target::getPtrDiff(ImageEnd, ImageStart);
+  }
+
+  /// Getter and setter for the dynamic library.
+  DynamicLibrary &getDynamicLibrary() { return DynLib; }
+  void setDynamicLibrary(const DynamicLibrary &Lib) { DynLib = Lib; }
+
+private:
+  DynamicLibrary DynLib;
+};
+
+// Event Implementations
+// =============================================================================
+
+namespace OriginEvents {
+
+/// Sends the allocation size \p Size to the other rank and receives a
+/// pointer-sized reply into \p *Buffer.
+EventTy allocateBuffer(MPIRequestManagerTy RequestManager, int64_t Size,
+                       void **Buffer) {
+  // Request: how many bytes to allocate.
+  RequestManager.send(&Size, /*Size=*/1, MPI_INT64_T);
+
+  // Reply: written through the caller-provided pointer.
+  RequestManager.receive(Buffer, sizeof(*Buffer), MPI_BYTE);
+
+  co_return co_await RequestManager;
+}
+
+/// Sends the address \p Buffer to the other rank and waits for its
+/// zero-length completion message.
+EventTy deleteBuffer(MPIRequestManagerTy RequestManager, void *Buffer) {
+  // Only the pointer value is transferred, not the memory it refers to.
+  RequestManager.send(&Buffer, sizeof(Buffer), MPI_BYTE);
+
+  // Completion notification from the other rank.
+  RequestManager.receive(/*Buffer=*/nullptr, /*Size=*/0, MPI_BYTE);
+
+  co_return co_await RequestManager;
+}
+
+/// Sends the destination address \p DstBuffer and the byte count \p Size,
+/// followed by the data held by \p DataHandle in fragments, and then waits
+/// for the zero-length completion message.
+EventTy submit(MPIRequestManagerTy RequestManager, EventDataHandleTy DataHandle,
+               void *DstBuffer, int64_t Size) {
+  // Header: where the data goes and how large it is.
+  RequestManager.send(&DstBuffer, sizeof(DstBuffer), MPI_BYTE);
+  RequestManager.send(&Size, /*Size=*/1, MPI_INT64_T);
+
+  // Payload.
+  RequestManager.sendInBatchs(DataHandle.get(), Size);
+
+  // Completion notification from the other rank.
+  RequestManager.receive(/*Buffer=*/nullptr, /*Size=*/0, MPI_BYTE);
+
+  co_return co_await RequestManager;
+}
+
+/// Sends the remote address \p DstBuffer and the byte count \p Size, receives
+/// the data into \p OrgBuffer in fragments, and then waits for the
+/// zero-length completion message.
+EventTy retrieve(MPIRequestManagerTy RequestManager, void *OrgBuffer,
+                 const void *DstBuffer, int64_t Size) {
+  // Header: which remote buffer to read and how much of it.
+  RequestManager.send(&DstBuffer, sizeof(DstBuffer), MPI_BYTE);
+  RequestManager.send(&Size, /*Size=*/1, MPI_INT64_T);
+
+  // Payload, written into the local buffer.
+  RequestManager.receiveInBatchs(OrgBuffer, Size);
+
+  // Completion notification from the other rank.
+  RequestManager.receive(/*Buffer=*/nullptr, /*Size=*/0, MPI_BYTE);
+
+  co_return co_await RequestManager;
+}
+
+/// Sends one exchange header to each of two ranks and waits for a zero-length
+/// completion message from both.
+///
+/// The first header (source address, size, destination device) goes to the
+/// rank currently stored in OtherRank, presumably \p SrcDevice; the caller
+/// should confirm this. The second header (destination address, size, source
+/// device) goes to \p DstDevice.
+EventTy exchange(MPIRequestManagerTy RequestManager, int SrcDevice,
+                 const void *OrgBuffer, int DstDevice, void *DstBuffer,
+                 int64_t Size) {
+  // Header for the rank that owns the source buffer.
+  RequestManager.send(&OrgBuffer, sizeof(OrgBuffer), MPI_BYTE);
+  RequestManager.send(&Size, /*Size=*/1, MPI_INT64_T);
+  RequestManager.send(&DstDevice, /*Size=*/1, MPI_INT);
+
+  // Header for the rank that owns the destination buffer.
+  RequestManager.OtherRank = DstDevice;
+  RequestManager.send(&DstBuffer, sizeof(DstBuffer), MPI_BYTE);
+  RequestManager.send(&Size, /*Size=*/1, MPI_INT64_T);
+  RequestManager.send(&SrcDevice, /*Size=*/1, MPI_INT);
+
+  // Completion notifications: first from DstDevice, then from SrcDevice.
+  RequestManager.receive(/*Buffer=*/nullptr, /*Size=*/0, MPI_BYTE);
+  RequestManager.OtherRank = SrcDevice;
+  RequestManager.receive(/*Buffer=*/nullptr, /*Size=*/0, MPI_BYTE);
+
+  co_return co_await RequestManager;
+}
+
+/// Sends the argument count \p NumArgs, the pointer-sized arguments held by
+/// \p Args and the function address \p Func, and then waits for the
+/// zero-length completion message.
+EventTy execute(MPIRequestManagerTy RequestManager, EventDataHandleTy Args,
+                uint32_t NumArgs, void *Func) {
+  RequestManager.send(&NumArgs, /*Size=*/1, MPI_UINT32_T);
+
+  // Each argument occupies one pointer-sized slot.
+  const size_t ArgsBytes = NumArgs * sizeof(void *);
+  RequestManager.send(Args.get(), ArgsBytes, MPI_BYTE);
+
+  RequestManager.send(&Func, sizeof(Func), MPI_BYTE);
+
+  // Completion notification from the other rank.
+  RequestManager.receive(/*Buffer=*/nullptr, /*Size=*/0, MPI_BYTE);
+
+  co_return co_await RequestManager;
+}
+
+/// Suspends repeatedly until \p Event reports done, then returns success.
+/// This coroutine only observes \p Event; it does not resume it.
+EventTy sync(EventTy Event) {
+  for (;;) {
+    if (Event.done())
+      break;
+
+    co_await std::suspend_always{};
+  }
+
+  co_return llvm::Error::success();
+}
+
+/// Transfers \p Image to the other rank and collects one address per entry.
+///
+/// The messages are, in order: image size, entry count, all entry name sizes,
+/// the image bytes, and then addr / name / size / flags / data for every
+/// entry. Afterwards one value per entry is received into
+/// \p DeviceImageAddrs, which is indexed without being resized and therefore
+/// must already hold at least one element per entry.
+EventTy loadBinary(MPIRequestManagerTy RequestManager,
+                   const __tgt_device_image *Image,
+                   llvm::SmallVector<void *> *DeviceImageAddrs) {
+  auto &[ImageStart, ImageEnd, EntriesBegin, EntriesEnd] = *Image;
+
+  // Compute the table sizes.
+  size_t ImageSize =
+      reinterpret_cast<size_t>(ImageEnd) - reinterpret_cast<size_t>(ImageStart);
+  size_t EntryCount = EntriesEnd - EntriesBegin;
+
+  // Name lengths include the terminating null character.
+  llvm::SmallVector<size_t> EntryNameSizes(EntryCount);
+  for (size_t Idx = 0; Idx < EntryCount; ++Idx)
+    EntryNameSizes[Idx] = std::strlen(EntriesBegin[Idx].name) + 1;
+
+  // Send the sizes first.
+  RequestManager.send(&ImageSize, /*Size=*/1, MPI_UINT64_T);
+  RequestManager.send(&EntryCount, /*Size=*/1, MPI_UINT64_T);
+  RequestManager.send(EntryNameSizes.begin(), EntryCount, MPI_UINT64_T);
+
+  // Then the image bytes.
+  RequestManager.send(ImageStart, ImageSize, MPI_BYTE);
+
+  // Then every table entry, field by field.
+  for (size_t Idx = 0; Idx < EntryCount; ++Idx) {
+    const auto &Entry = EntriesBegin[Idx];
+    RequestManager.send(&Entry.addr, /*Size=*/1, MPI_UINT64_T);
+    RequestManager.send(Entry.name, EntryNameSizes[Idx], MPI_CHAR);
+    RequestManager.send(&Entry.size, /*Size=*/1, MPI_UINT64_T);
+    RequestManager.send(&Entry.flags, /*Size=*/1, MPI_INT32_T);
+    RequestManager.send(&Entry.data, /*Size=*/1, MPI_INT32_T);
+  }
+
+  // Finally collect one address per entry from the other rank.
+  for (size_t Idx = 0; Idx < EntryCount; ++Idx) {
+    void **AddrSlot = &((*DeviceImageAddrs)[Idx]);
+    RequestManager.receive(AddrSlot, /*Size=*/1, MPI_UINT64_T);
+  }
+
+  co_return co_await RequestManager;
+}
+
+/// Waits for the zero-length completion message from the other rank.
+EventTy exit(MPIRequestManagerTy RequestManager) {
+  RequestManager.receive(/*Buffer=*/nullptr, /*Size=*/0, MPI_BYTE);
+
+  co_return co_await RequestManager;
+}
+
+} // namespace OriginEvents
+
+namespace DestinationEvents {
+
+/// Receives an allocation size, allocates that many bytes with malloc and
+/// sends the resulting pointer value back. The malloc result is sent as it
+/// is, without a null check.
+EventTy allocateBuffer(MPIRequestManagerTy RequestManager) {
+  // Wait for the requested size.
+  int64_t RequestedBytes = 0;
+  RequestManager.receive(&RequestedBytes, /*Size=*/1, MPI_INT64_T);
+
+  if (auto Error = co_await RequestManager)
+    co_return Error;
+
+  // Reply with the address of the new allocation.
+  void *Allocation = malloc(RequestedBytes);
+  RequestManager.send(&Allocation, sizeof(Allocation), MPI_BYTE);
+
+  co_return co_await RequestManager;
+}
+
+/// Receives a pointer value, releases it with free and sends a zero-length
+/// completion message back.
+EventTy deleteBuffer(MPIRequestManagerTy RequestManager) {
+  // Wait for the address to release.
+  void *Allocation = nullptr;
+  RequestManager.receive(&Allocation, sizeof(Allocation), MPI_BYTE);
+
+  if (auto Error = co_await RequestManager)
+    co_return Error;
+
+  free(Allocation);
+
+  // Completion notification for the other rank.
+  RequestManager.send(/*Buffer=*/nullptr, /*Size=*/0, MPI_BYTE);
+
+  co_return co_await RequestManager;
+}
+
+/// Receives a target address and a byte count, receives that many bytes into
+/// the address in fragments, and sends a zero-length completion message.
+EventTy submit(MPIRequestManagerTy RequestManager) {
+  // Header: where to store the data and how large it is.
+  void *Target = nullptr;
+  int64_t ByteCount = 0;
+  RequestManager.receive(&Target, sizeof(Target), MPI_BYTE);
+  RequestManager.receive(&ByteCount, /*Size=*/1, MPI_INT64_T);
+
+  if (auto Error = co_await RequestManager)
+    co_return Error;
+
+  // Payload.
+  RequestManager.receiveInBatchs(Target, ByteCount);
+
+  // Completion notification for the other rank.
+  RequestManager.send(/*Buffer=*/nullptr, /*Size=*/0, MPI_BYTE);
+
+  co_return co_await RequestManager;
+}
+
+/// Receives a source address and a byte count, sends that many bytes from
+/// the address in fragments, and sends a zero-length completion message.
+EventTy retrieve(MPIRequestManagerTy RequestManager) {
+  // Header: which local buffer to read and how much of it.
+  void *Source = nullptr;
+  int64_t ByteCount = 0;
+  RequestManager.receive(&Source, sizeof(Source), MPI_BYTE);
+  RequestManager.receive(&ByteCount, /*Size=*/1, MPI_INT64_T);
+
+  if (auto Error = co_await RequestManager)
+    co_return Error;
+
+  // Payload.
+  RequestManager.sendInBatchs(Source, ByteCount);
+
+  // Completion notification for the other rank.
+  RequestManager.send(/*Buffer=*/nullptr, /*Size=*/0, MPI_BYTE);
+
+  co_return co_await RequestManager;
+}
+
+/// Source side of an exchange: receives a header (source address, size and
+/// destination rank) from the rank stored in OtherRank, sends the data to the
+/// destination rank in fragments, and finally sends a zero-length completion
+/// message back to the rank the header came from.
+EventTy exchangeSrc(MPIRequestManagerTy RequestManager) {
+  // Initialized like the receive targets of the other handlers, so that they
+  // never hold indeterminate values.
+  void *SrcBuffer = nullptr;
+  int64_t Size = 0;
+  int DstDevice = 0;
+  // Save head node rank
+  int HeadNodeRank = RequestManager.OtherRank;
+
+  RequestManager.receive(&SrcBuffer, sizeof(void *), MPI_BYTE);
+  RequestManager.receive(&Size, 1, MPI_INT64_T);
+  RequestManager.receive(&DstDevice, 1, MPI_INT);
+
+  if (auto Error = co_await RequestManager; Error)
+    co_return Error;
+
+  // Set the Destination Rank in RequestManager
+  RequestManager.OtherRank = DstDevice;
+
+  // Send buffer to target device
+  RequestManager.sendInBatchs(SrcBuffer, Size);
+
+  if (auto Error = co_await RequestManager; Error)
+    co_return Error;
+
+  // Set the HeadNode Rank to send the final notification
+  RequestManager.OtherRank = HeadNodeRank;
+
+  // Event completion notification
+  RequestManager.send(nullptr, 0, MPI_BYTE);
+
+  co_return (co_await RequestManager);
+}
+
+/// Destination side of an exchange: receives a header (destination address,
+/// size and source rank) from the rank stored in OtherRank, receives the data
+/// from the source rank in fragments, and finally sends a zero-length
+/// completion message back to the rank the header came from.
+EventTy exchangeDst(MPIRequestManagerTy RequestManager) {
+  // Initialized like the receive targets of the other handlers, so that they
+  // never hold indeterminate values.
+  void *DstBuffer = nullptr;
+  int64_t Size = 0;
+  int SrcDevice = 0;
+  // Save head node rank
+  int HeadNodeRank = RequestManager.OtherRank;
+
+  RequestManager.receive(&DstBuffer, sizeof(void *), MPI_BYTE);
+  RequestManager.receive(&Size, 1, MPI_INT64_T);
+  RequestManager.receive(&SrcDevice, 1, MPI_INT);
+
+  if (auto Error = co_await RequestManager; Error)
+    co_return Error;
+
+  // Set the Source Rank in RequestManager
+  RequestManager.OtherRank = SrcDevice;
+
+  // Receive buffer from the Source device
+  RequestManager.receiveInBatchs(DstBuffer, Size);
+
+  if (auto Error = co_await RequestManager; Error)
+    co_return Error;
+
+  // Set the HeadNode Rank to send the final notification
+  RequestManager.OtherRank = HeadNodeRank;
+
+  // Event completion notification
+  RequestManager.send(nullptr, 0, MPI_BYTE);
+
+  co_return (co_await RequestManager);
+}
+
+/// Receives an argument count, the pointer-sized arguments and a function
+/// address, calls the function through libffi with every argument typed as a
+/// pointer and a void return type, and sends a zero-length completion
+/// message.
+///
+/// \p DeviceImage is not used by this function.
+EventTy execute(MPIRequestManagerTy RequestManager,
+                __tgt_device_image &DeviceImage) {
+
+  uint32_t NumArgs = 0;
+  RequestManager.receive(&NumArgs, 1, MPI_UINT32_T);
+
+  if (auto Error = co_await RequestManager; Error)
+    co_return Error;
+
+  llvm::SmallVector<void *> Args(NumArgs);
+  llvm::SmallVector<void *> ArgPtrs(NumArgs);
+
+  RequestManager.receive(Args.data(), NumArgs * sizeof(uintptr_t), MPI_BYTE);
+  void (*TargetFunc)(void) = nullptr;
+  RequestManager.receive(&TargetFunc, sizeof(uintptr_t), MPI_BYTE);
+
+  if (auto Error = co_await RequestManager; Error)
+    co_return Error;
+
+  // libffi expects an array of pointers to the argument values.
+  for (unsigned I = 0; I < NumArgs; I++) {
+    ArgPtrs[I] = &Args[I];
+  }
+
+  // data() is used instead of &Vec[0] so that a call with zero arguments does
+  // not index an empty vector.
+  ffi_cif Cif{};
+  llvm::SmallVector<ffi_type *> ArgsTypes(NumArgs, &ffi_type_pointer);
+  ffi_status Status = ffi_prep_cif(&Cif, FFI_DEFAULT_ABI, NumArgs,
+                                   &ffi_type_void, ArgsTypes.data());
+
+  if (Status != FFI_OK) {
+    co_return createError("Error in ffi_prep_cif: %d", Status);
+  }
+
+  // Storage for a return value; the call interface declares a void return.
+  long Return = 0;
+  ffi_call(&Cif, TargetFunc, &Return, ArgPtrs.data());
+
+  // Event completion notification
+  RequestManager.send(nullptr, 0, MPI_BYTE);
+
+  co_return (co_await RequestManager);
+}
+
+EventTy loadBinary(MPIRequestManagerTy RequestManager, DeviceImage &Image) {
+  // Receive the target table sizes.
+  size_t ImageSize = 0;
+  size_t EntryCount = 0;
+  RequestManager.receive(&ImageSize, 1, MPI_UINT64_T);
+  RequestManager.receive(&EntryCount, 1, MPI_UINT64_T);
+
+  if (auto Error = co_await RequestManager; Error)
+    co_return Error;
+
+  llvm::SmallVector<size_t> EntryNameSizes(EntryCount);
+
+  RequestManager.receive(EntryNameSizes.begin(), EntryCount, MPI_UINT64_T);
+
+  if (auto Error = co_await RequestManager; Error)
+    co_return Error;
+
+  // Create the device name with the appropriate sizes and receive its content.
+  Image = DeviceImage(ImageSize, EntryCount);
+  Image.setImageEntries(EntryNameSizes);
+
+  // Received the image bytes and the table entries.
+  RequestManager.receive(Image.ImageStart, ImageSize, MPI_BYTE);
+
+  for (size_t I = 0; I < EntryCount; I++) {
+    RequestManager.receive(&Image.Entries[I].addr, 1, MPI_UINT64_T);
+    RequestManager.receive(Image.Entries[I].name, EntryNameSizes[I], MPI_CHAR);
+    RequestManager.receive(&Image.Entries[I].size, 1, MPI_UINT64_T);
+    RequestManager.receive(&Image.Entries[I].flags, 1, MPI_INT32_T);
+    RequestManager.receive(&Image.Entries[I].data, 1, MPI_INT32_T);
+  }
+
+  if (auto Error = co_await RequestManager; Error)
+    co_return Error;
+
+  // The code below is for CPU plugin only
+  // Create a temporary file.
+  char TmpFileName[] = "/tmp/tmpfile_XXXXXX";
----------------
cl3to wrote:

The temporary files are only removed on system restart, and only if the system is configured to do so. I'll add file deletion at the end of the dynlib loading process.

https://github.com/llvm/llvm-project/pull/90890


More information about the llvm-commits mailing list