[llvm] [Offload] Add MPI Plugin (PR #90890)

Shilei Tian via llvm-commits llvm-commits at lists.llvm.org
Thu May 2 12:57:46 PDT 2024


================
@@ -0,0 +1,685 @@
+//===------RTLs/mpi/src/rtl.cpp - Target RTLs Implementation - C++ ------*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// RTL NextGen for MPI applications
+//
+//===----------------------------------------------------------------------===//
+
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <new>
+#include <optional>
+#include <string>
+
+#include "GlobalHandler.h"
+#include "OpenMP/OMPT/Callback.h"
+#include "PluginInterface.h"
+#include "Shared/Debug.h"
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/BinaryFormat/ELF.h"
+#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
+#include "llvm/Support/Error.h"
+#include "llvm/TargetParser/Triple.h"
+
+#include "EventSystem.h"
+
+namespace llvm::omp::target::plugin {
+
+/// Forward declarations for all specialized data structures.
+struct MPIPluginTy;
+struct MPIDeviceTy;
+struct MPIDeviceImageTy;
+struct MPIKernelTy;
+class MPIGlobalHandlerTy;
+
+// TODO: Should this be defined inside the EventSystem?
+/// Sequence of events. The device methods in this file append the events they
+/// create to such a queue, and synchronization waits on them in order.
+using MPIEventQueue = SmallVector<EventTy>;
+/// Raw pointer handle to an event queue.
+using MPIEventQueuePtr = MPIEventQueue *;
+
+/// Class implementing the MPI device images properties.
+/// Device image specialization used by the MPI plugin.
+struct MPIDeviceImageTy : public DeviceImageTy {
+  /// Table of addresses associated with this image. A pointer to it is handed
+  /// to the loadBinary event, and globals are resolved by indexing into it.
+  llvm::SmallVector<void *> DeviceImageAddrs;
+
+  /// Build the MPI image from its id, its device and the target image pointer.
+  /// The address table is created with getSize() entries.
+  MPIDeviceImageTy(int32_t ImageId, GenericDeviceTy &Device,
+                   const __tgt_device_image *TgtImage)
+      : DeviceImageTy(ImageId, Device, TgtImage), DeviceImageAddrs(getSize()) {
+  }
+};
+
+/// Global handler that resolves device globals through the address table kept
+/// in the MPI device image.
+class MPIGlobalHandlerTy final : public GenericGlobalHandlerTy {
+public:
+  /// Search the offload entries of \p Image for an entry whose name equals the
+  /// name of \p DeviceGlobal and store the address found at the same position
+  /// of the image's DeviceImageAddrs in the global.
+  ///
+  /// Returns an error if the global has no name, if no entry matches, or if
+  /// the matching address is null.
+  Error getGlobalMetadataFromDevice(GenericDeviceTy &GenericDevice,
+                                    DeviceImageTy &Image,
+                                    GlobalTy &DeviceGlobal) override {
+    const char *GlobalName = DeviceGlobal.getName().data();
+    if (!GlobalName)
+      return Plugin::error("Failed to get name for global %p", &DeviceGlobal);
+
+    auto &MPIImage = static_cast<MPIDeviceImageTy &>(Image);
+    __tgt_offload_entry *const First = MPIImage.getTgtImage()->EntriesBegin;
+    __tgt_offload_entry *const Last = MPIImage.getTgtImage()->EntriesEnd;
+
+    // The position of the entry is also the position of its address.
+    void *EntryAddress = nullptr;
+    for (int Idx = 0; First + Idx < Last; ++Idx) {
+      if (strcmp(First[Idx].name, GlobalName) != 0)
+        continue;
+
+      EntryAddress = MPIImage.DeviceImageAddrs[Idx];
+      break;
+    }
+
+    if (!EntryAddress)
+      return Plugin::error("Failed to find global %s", GlobalName);
+
+    // Save the pointer to the symbol.
+    DeviceGlobal.setPtr(EntryAddress);
+    return Plugin::success();
+  }
+};
+
+/// Kernel specialization used by the MPI plugin.
+struct MPIKernelTy : public GenericKernelTy {
+  /// Construct the kernel from its name and the event system it refers to.
+  MPIKernelTy(const char *Name, EventSystemTy &EventSystem)
+      : GenericKernelTy(Name), Func(nullptr), EventSystem(EventSystem) {}
+
+  /// Resolve the address of the kernel function through the plugin's global
+  /// handler and fill in the kernel environment configuration.
+  Error initImpl(GenericDeviceTy &Device, DeviceImageTy &Image) override {
+    // A function is looked up as a global of size zero.
+    GlobalTy KernelGlobal(getName(), 0);
+
+    // Ask the global handler for the metadata (address) of the function.
+    GenericGlobalHandlerTy &Handler = Device.Plugin.getGlobalHandler();
+    if (auto Err =
+            Handler.getGlobalMetadataFromDevice(Device, Image, KernelGlobal))
+      return Err;
+
+    // A null address means there is nothing to call.
+    if (KernelGlobal.getPtr() == nullptr)
+      return Plugin::error("Invalid function for kernel %s", getName());
+
+    // Remember the entry point of the kernel.
+    Func = reinterpret_cast<void (*)()>(KernelGlobal.getPtr());
+
+    // TODO: Check which settings are appropriate for the mpi plugin
+    // for now we are using the Elf64 plugin configuration
+    auto &Config = KernelEnvironment.Configuration;
+    Config.ExecMode = OMP_TGT_EXEC_MODE_GENERIC;
+    Config.MayUseNestedParallelism = /* Unknown */ 2;
+    Config.UseGenericStateMachine = /* Unknown */ 2;
+
+    // Kernels run with a single thread at most.
+    MaxNumThreads = 1;
+    return Plugin::success();
+  }
+
+  /// Launch the kernel.
+  Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
+                   uint64_t NumBlocks, KernelArgsTy &KernelArgs, void *Args,
+                   AsyncInfoWrapperTy &AsyncInfoWrapper) const override;
+
+private:
+  /// The kernel function to execute.
+  void (*Func)(void);
+  /// Event system passed in at construction.
+  EventSystemTy &EventSystem;
+};
+
+// MPI resource reference and queue
+// =============================================================================
+/// Reference to a heap-allocated resource of type \p ResourceTy that exposes
+/// the create/destroy interface of GenericDeviceResourceRef.
+template <typename ResourceTy>
+struct MPIResourceRef final : public GenericDeviceResourceRef {
+
+  // The underlying handler type for the resource.
+  using HandleTy = ResourceTy *;
+
+  // Create an empty reference to an invalid resource.
+  MPIResourceRef() : Resource(nullptr) {}
+
+  // Create a reference to an existing resource.
+  MPIResourceRef(HandleTy Queue) : Resource(Queue) {}
+
+  // Create a new resource and save the reference. Fails if the reference is
+  // already bound to a resource or if the allocation fails.
+  Error create(GenericDeviceTy &Device) override {
+    if (Resource)
+      return Plugin::error("Recreating an existing resource");
+
+    // A plain new-expression never yields nullptr, which would make the check
+    // below dead code. Use the non-throwing form so the failure is reported.
+    Resource = new (std::nothrow) ResourceTy;
+    if (!Resource)
+      return Plugin::error("Failed to allocated a new resource");
+
+    return Plugin::success();
+  }
+
+  // Destroy the resource and invalidate the reference. Fails if the reference
+  // is not bound to a resource.
+  Error destroy(GenericDeviceTy &Device) override {
+    if (!Resource)
+      return Plugin::error("Destroying an invalid resource");
+
+    delete Resource;
+    Resource = nullptr;
+
+    return Plugin::success();
+  }
+
+  // Implicit access to the underlying handle (nullptr when invalid).
+  operator HandleTy() const { return Resource; }
+
+private:
+  // The referenced resource, or nullptr.
+  HandleTy Resource;
+};
+
+// Device class
+// =============================================================================
+struct MPIDeviceTy : public GenericDeviceTy {
+  /// Construct the device. MPIGridValues is forwarded to the generic device,
+  /// both resource managers are constructed with a reference to this device,
+  /// and a reference to \p EventSystem is kept.
+  MPIDeviceTy(GenericPluginTy &Plugin, int32_t DeviceId, int32_t NumDevices,
+              EventSystemTy &EventSystem)
+      : GenericDeviceTy(Plugin, DeviceId, NumDevices, MPIGridValues),
+        MPIEventQueueManager(*this), MPIEventManager(*this),
+        EventSystem(EventSystem) {}
+
+  /// Initialize the event queue manager and then the event manager, stopping
+  /// at the first failure.
+  Error initImpl(GenericPluginTy &Plugin) override {
+    // TODO: Check if EventQueueManager is equivalent to StreamManager.
+    auto QueuesErr = MPIEventQueueManager.init(OMPX_InitialNumStreams);
+    if (QueuesErr)
+      return QueuesErr;
+
+    // The outcome of the second step is the outcome of the whole function.
+    return MPIEventManager.init(OMPX_InitialNumEvents);
+  }
+
+  /// Deinitialize the event queue manager and then the event manager, stopping
+  /// at the first failure.
+  Error deinitImpl() override {
+    auto QueuesErr = MPIEventQueueManager.deinit();
+    if (QueuesErr)
+      return QueuesErr;
+
+    // The outcome of the second step is the outcome of the whole function.
+    return MPIEventManager.deinit();
+  }
+
+  /// Nothing is done here for this device; always reports success.
+  Error setContext() override { return Plugin::success(); }
+
+  /// Load the binary image into the device and allocate an image object.
+  ///
+  /// Allocates the image object from the plugin, creates a loadBinary event
+  /// for \p TgtImage and waits for it. Returns the image, or an error if the
+  /// allocation fails, the event cannot be created, or the event fails.
+  Expected<DeviceImageTy *> loadBinaryImpl(const __tgt_device_image *TgtImage,
+                                           int32_t ImageId) override {
+
+    // Allocate the image object. Check the storage before constructing into
+    // it, as constructKernel does for kernels.
+    MPIDeviceImageTy *Image = Plugin.allocate<MPIDeviceImageTy>();
+    if (!Image)
+      return Plugin::error("Failed to allocate memory for MPI device image");
+
+    new (Image) MPIDeviceImageTy(ImageId, *this, TgtImage);
+
+    // The event receives a pointer to the address table of the image.
+    auto Event = EventSystem.createEvent(OriginEvents::loadBinary, DeviceId,
+                                         TgtImage, &(Image->DeviceImageAddrs));
+
+    if (Event.empty()) {
+      return Plugin::error("Failed to create loadBinary event for image %p",
+                           TgtImage);
+    }
+
+    Event.wait();
+
+    // Named Err rather than Error so the llvm::Error type is not shadowed.
+    if (auto Err = Event.getError()) {
+      return Plugin::error("Event failed during loadBinary. %s\n",
+                           toString(std::move(Err)).c_str());
+    }
+
+    return Image;
+  }
+
+  // Data management
+  // ===========================================================================
+  /// Allocate \p Size bytes through an allocateBuffer event and wait for it.
+  ///
+  /// Only the default and device allocation kinds are served; host and shared
+  /// kinds are reported as incompatible. Returns the buffer address, or
+  /// nullptr when \p Size is zero or the allocation failed (the failure is
+  /// reported through REPORT).
+  void *allocate(size_t Size, void *, TargetAllocTy Kind) override {
+    if (Size == 0)
+      return nullptr;
+
+    void *BufferAddress = nullptr;
+    std::optional<Error> Err = std::nullopt;
+    EventTy Event{nullptr};
+
+    switch (Kind) {
+    case TARGET_ALLOC_DEFAULT:
+    case TARGET_ALLOC_DEVICE:
+    case TARGET_ALLOC_DEVICE_NON_BLOCKING:
+      Event = EventSystem.createEvent(OriginEvents::allocateBuffer, DeviceId,
+                                      Size, &BufferAddress);
+
+      if (Event.empty()) {
+        // size_t is printed with %zu; a bare %z is not a valid conversion.
+        Err =
+            Plugin::error("Failed to create alloc event with size %zu", Size);
+        break;
+      }
+
+      Event.wait();
+      Err = Event.getError();
+      break;
+    case TARGET_ALLOC_HOST:
+    case TARGET_ALLOC_SHARED:
+      Err = Plugin::error("Incompatible memory type %d", Kind);
+      break;
+    }
+
+    // A kind not listed above leaves Err empty; it must not be dereferenced.
+    if (!Err) {
+      REPORT("Failed to allocate memory: Unknown memory type %d\n",
+             static_cast<int>(Kind));
+      return nullptr;
+    }
+
+    if (*Err) {
+      REPORT("Failed to allocate memory: %s\n",
+             toString(std::move(*Err)).c_str());
+      return nullptr;
+    }
+
+    return BufferAddress;
+  }
+
+  /// Release \p TgtPtr through a deleteBuffer event and wait for it.
+  ///
+  /// Only the default and device allocation kinds are served; host and shared
+  /// kinds are reported as incompatible. Returns OFFLOAD_SUCCESS when
+  /// \p TgtPtr is null or the event succeeded, OFFLOAD_FAIL otherwise (the
+  /// failure is reported through REPORT).
+  int free(void *TgtPtr, TargetAllocTy Kind) override {
+    if (TgtPtr == nullptr)
+      return OFFLOAD_SUCCESS;
+
+    std::optional<Error> Err = std::nullopt;
+    EventTy Event{nullptr};
+
+    switch (Kind) {
+    case TARGET_ALLOC_DEFAULT:
+    case TARGET_ALLOC_DEVICE:
+    case TARGET_ALLOC_DEVICE_NON_BLOCKING:
+      Event =
+          EventSystem.createEvent(OriginEvents::deleteBuffer, DeviceId, TgtPtr);
+
+      if (Event.empty()) {
+        Err = Plugin::error("Failed to create delete event");
+        break;
+      }
+
+      Event.wait();
+      Err = Event.getError();
+      break;
+    case TARGET_ALLOC_HOST:
+    case TARGET_ALLOC_SHARED:
+      // Same error helper as the matching case in allocate().
+      Err = Plugin::error("Incompatible memory type %d", Kind);
+      break;
+    }
+
+    // A kind not listed above leaves Err empty; it must not be dereferenced.
+    if (!Err) {
+      REPORT("Failed to free memory: Unknown memory type %d\n",
+             static_cast<int>(Kind));
+      return OFFLOAD_FAIL;
+    }
+
+    if (*Err) {
+      REPORT("Failed to free memory: %s\n", toString(std::move(*Err)).c_str());
+      return OFFLOAD_FAIL;
+    }
+
+    return OFFLOAD_SUCCESS;
+  }
+
+  // Data transfer
+  // ===========================================================================
+  /// Create a submit event for \p Size bytes read from \p HstPtr with target
+  /// address \p TgtPtr and append it to the queue of \p AsyncInfoWrapper.
+  ///
+  /// The host data is first copied into a temporary buffer owned by the data
+  /// handle passed to the event, which releases it with std::free. Returns an
+  /// error if the queue cannot be obtained, the temporary buffer cannot be
+  /// allocated, or the event cannot be created.
+  Error dataSubmitImpl(void *TgtPtr, const void *HstPtr, int64_t Size,
+                       AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+    MPIEventQueuePtr Queue = nullptr;
+    if (auto Err = getQueue(AsyncInfoWrapper, Queue))
+      return Err;
+
+    // Copy HstData to a buffer with event-managed lifetime.
+    void *SubmitBuffer = std::malloc(Size);
+
+    // malloc may legitimately return nullptr for a zero-sized request; any
+    // other nullptr is an allocation failure and must not reach memcpy.
+    if (!SubmitBuffer && Size != 0)
+      return Plugin::error("Failed to allocate buffer for submit event");
+
+    if (SubmitBuffer)
+      std::memcpy(SubmitBuffer, HstPtr, Size);
+    EventDataHandleTy DataHandle(SubmitBuffer, &std::free);
+
+    auto Event = EventSystem.createEvent(OriginEvents::submit, DeviceId,
+                                         DataHandle, TgtPtr, Size);
+
+    if (Event.empty())
+      return Plugin::error("Failed to create submit event");
+
+    Queue->push_back(Event);
+
+    return Plugin::success();
+  }
+
+  /// Create a retrieve event for (\p HstPtr, \p TgtPtr, \p Size) and append it
+  /// to the queue of \p AsyncInfoWrapper. Fails if the queue cannot be
+  /// obtained or the event cannot be created.
+  Error dataRetrieveImpl(void *HstPtr, const void *TgtPtr, int64_t Size,
+                         AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+    MPIEventQueuePtr EventQueue = nullptr;
+    if (auto QueueErr = getQueue(AsyncInfoWrapper, EventQueue))
+      return QueueErr;
+
+    auto RetrieveEvent = EventSystem.createEvent(OriginEvents::retrieve,
+                                                 DeviceId, HstPtr, TgtPtr, Size);
+    if (!RetrieveEvent.empty()) {
+      EventQueue->push_back(RetrieveEvent);
+      return Plugin::success();
+    }
+
+    return Plugin::error("Failed to create retrieve event");
+  }
+
+  /// Create an exchange event between this device (\p SrcPtr) and \p DstDev
+  /// (\p DstPtr) for \p Size bytes and append it to the queue of
+  /// \p AsyncInfoWrapper. Fails if the queue cannot be obtained or the event
+  /// cannot be created.
+  Error dataExchangeImpl(const void *SrcPtr, GenericDeviceTy &DstDev,
+                         void *DstPtr, int64_t Size,
+                         AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+    MPIEventQueuePtr EventQueue = nullptr;
+    if (auto QueueErr = getQueue(AsyncInfoWrapper, EventQueue))
+      return QueueErr;
+
+    auto ExchangeEvent = EventSystem.createExchangeEvent(
+        DeviceId, SrcPtr, DstDev.getDeviceId(), DstPtr, Size);
+    if (!ExchangeEvent.empty()) {
+      EventQueue->push_back(ExchangeEvent);
+      return Plugin::success();
+    }
+
+    return Plugin::error("Failed to create exchange event");
+  }
+
+  // Allocate and construct a MPI kernel.
+  // ===========================================================================
+  /// Obtain storage for an MPI kernel from the plugin and construct the kernel
+  /// named \p Name in it. Fails if no storage could be obtained.
+  Expected<GenericKernelTy &> constructKernel(const char *Name) override {
+    MPIKernelTy *KernelStorage = Plugin.allocate<MPIKernelTy>();
+    if (KernelStorage == nullptr)
+      return Plugin::error("Failed to allocate memory for MPI kernel");
+
+    // Placement-new yields the address of the kernel it just constructed.
+    MPIKernelTy *NewKernel = new (KernelStorage) MPIKernelTy(Name, EventSystem);
+    return *NewKernel;
+  }
+
+  // External event management
+  // ===========================================================================
+  /// Obtain an event from the event manager and store its address in
+  /// \p EventStoragePtr.
+  ///
+  /// Returns an error if \p EventStoragePtr is null or the event manager
+  /// cannot provide an event.
+  Error createEventImpl(void **EventStoragePtr) override {
+    if (!EventStoragePtr)
+      return Plugin::error("Received invalid event storage pointer");
+
+    EventTy **NewEvent = reinterpret_cast<EventTy **>(EventStoragePtr);
+
+    // A failed llvm::Error has to be consumed before it is destroyed. Fold its
+    // message into the error returned to the caller instead of dropping it.
+    if (auto Err = MPIEventManager.getResource(*NewEvent))
+      return Plugin::error(
+          "Could not allocate a new synchronization event: %s",
+          toString(std::move(Err)).c_str());
+
+    return Plugin::success();
+  }
+
+  /// Hand \p Event back to the event manager. Fails if \p Event is null.
+  Error destroyEventImpl(void *Event) override {
+    if (Event == nullptr)
+      return Plugin::error("Received invalid event pointer");
+
+    auto *EventPtr = reinterpret_cast<EventTy *>(Event);
+    return MPIEventManager.returnResource(EventPtr);
+  }
+
+  /// Assign the last event of the queue of \p AsyncInfoWrapper to \p Event.
+  /// An empty queue leaves \p Event untouched. Fails if \p Event is null or
+  /// the queue cannot be obtained.
+  Error recordEventImpl(void *Event,
+                        AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+    if (Event == nullptr)
+      return Plugin::error("Received invalid event pointer");
+
+    MPIEventQueuePtr EventQueue = nullptr;
+    if (auto QueueErr = getQueue(AsyncInfoWrapper, EventQueue))
+      return QueueErr;
+
+    // Only a non-empty queue has a last event to record.
+    if (!EventQueue->empty())
+      *reinterpret_cast<EventTy *>(Event) = EventQueue->back();
+
+    return Plugin::success();
+  }
+
+  /// Build a sync event from \p Event and append it to the queue of
+  /// \p AsyncInfoWrapper. Fails if \p Event is null or the queue cannot be
+  /// obtained.
+  Error waitEventImpl(void *Event,
+                      AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+    if (Event == nullptr)
+      return Plugin::error("Received invalid event pointer");
+
+    // The sync event is created before the queue is looked up.
+    EventTy &Recorded = *reinterpret_cast<EventTy *>(Event);
+    auto WaitEvent = OriginEvents::sync(Recorded);
+
+    MPIEventQueuePtr EventQueue = nullptr;
+    if (auto QueueErr = getQueue(AsyncInfoWrapper, EventQueue))
+      return QueueErr;
+
+    EventQueue->push_back(WaitEvent);
+    return Plugin::success();
+  }
+
+  Error syncEventImpl(void *Event) override {
+    if (!Event)
+      return Plugin::error("Received invalid event pointer");
+
+    auto &RecordedEvent = *reinterpret_cast<EventTy *>(Event);
+    auto SyncEvent = OriginEvents::sync(RecordedEvent);
+
+    SyncEvent.wait();
+
+    return SyncEvent.getError();
+  }
+
+  // Asynchronous queue management
+  // ===========================================================================
+  /// Wait for every event in the queue of \p AsyncInfo, in order. The first
+  /// failing event ends the wait with an error. On success the queue pointer
+  /// in \p AsyncInfo is cleared and the queue is returned to its manager.
+  Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
+    auto *EventQueue = reinterpret_cast<MPIEventQueue *>(AsyncInfo.Queue);
+
+    for (auto It = EventQueue->begin(), End = EventQueue->end(); It != End;
+         ++It) {
+      It->wait();
+
+      auto Err = It->getError();
+      if (!Err)
+        continue;
+
+      return Plugin::error("Event failed during synchronization. %s\n",
+                           toString(std::move(Err)).c_str());
+    }
+
+    // With all events done, detach the queue from the AsyncInfo and give it
+    // back to the pool, so that a later synchronization only covers its own
+    // tasks.
+    AsyncInfo.Queue = nullptr;
+    return MPIEventQueueManager.returnResource(EventQueue);
+  }
+
+  Error queryAsyncImpl(__tgt_async_info &AsyncInfo) override {
+    auto *Queue = reinterpret_cast<MPIEventQueue *>(AsyncInfo.Queue);
+
+    // Returns success when there are pending operations in the AsyncInfo.
+    if (!Queue->empty() && !Queue->back().done()) {
+      return Plugin::success();
+    }
----------------
shiltian wrote:

```suggestion
    if (!Queue->empty() && !Queue->back().done())
      return Plugin::success();
```

https://github.com/llvm/llvm-project/pull/90890


More information about the llvm-commits mailing list