[llvm] [Offload] Add MPI Proxy Plugin (PR #114574)
Shilei Tian via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 1 10:24:13 PDT 2024
================
@@ -0,0 +1,1309 @@
+//===------RTLs/mpi/src/rtl.cpp - Target RTLs Implementation - C++ ------*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// RTL NextGen for MPI applications
+//
+//===----------------------------------------------------------------------===//
+
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <list>
+#include <optional>
+#include <string>
+#include <thread>
+#include <tuple>
+
+#include "Shared/APITypes.h"
+#include "Shared/Debug.h"
+#include "Utils/ELF.h"
+
+#include "EventSystem.h"
+#include "GlobalHandler.h"
+#include "OpenMP/OMPT/Callback.h"
+#include "PluginInterface.h"
+#include "omptarget.h"
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/BinaryFormat/ELF.h"
+#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
+#include "llvm/Support/Error.h"
+
+#if !defined(__BYTE_ORDER__) || !defined(__ORDER_LITTLE_ENDIAN__) || \
+ !defined(__ORDER_BIG_ENDIAN__)
+#error "Missing preprocessor definitions for endianness detection."
+#endif
+
+#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
+#define LITTLEENDIAN_CPU
+#elif defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
+#define BIGENDIAN_CPU
+#endif
+
+namespace llvm::omp::target::plugin {
+
+/// Forward declarations for all specialized data structures.
+struct MPIPluginTy;
+struct MPIDeviceTy;
+struct MPIDeviceImageTy;
+struct MPIKernelTy;
+class MPIGlobalHandlerTy;
+
+// TODO: Should this be defined inside the EventSystem?
+using MPIEventQueue = std::list<EventTy>;
+using MPIEventQueuePtr = MPIEventQueue *;
+
+/// Class implementing the MPI device images properties.
+struct MPIDeviceImageTy : public DeviceImageTy {
+ /// Create the MPI image with the id and the target image pointer.
+ MPIDeviceImageTy(int32_t ImageId, GenericDeviceTy &Device,
+ const __tgt_device_image *TgtImage)
+ : DeviceImageTy(ImageId, Device, TgtImage), DeviceImageAddrs(getSize()) {}
+
+ llvm::SmallVector<void *> DeviceImageAddrs;
+};
+
+class MPIGlobalHandlerTy final : public GenericGlobalHandlerTy {
+public:
+ Error getGlobalMetadataFromDevice(GenericDeviceTy &GenericDevice,
+ DeviceImageTy &Image,
+ GlobalTy &DeviceGlobal) override {
+ const char *GlobalName = DeviceGlobal.getName().data();
+ MPIDeviceImageTy &MPIImage = static_cast<MPIDeviceImageTy &>(Image);
+
+ if (GlobalName == nullptr) {
+ return Plugin::error("Failed to get name for global %p", &DeviceGlobal);
+ }
+
+ void *EntryAddress = nullptr;
+
+ __tgt_offload_entry *Begin = MPIImage.getTgtImage()->EntriesBegin;
+ __tgt_offload_entry *End = MPIImage.getTgtImage()->EntriesEnd;
+
+ int I = 0;
+ for (auto &Entry = Begin; Entry < End; ++Entry) {
+ if (!strcmp(Entry->name, GlobalName)) {
+ EntryAddress = MPIImage.DeviceImageAddrs[I];
+ break;
+ }
+ I++;
+ }
+
+ if (EntryAddress == nullptr) {
+ return Plugin::error("Failed to find global %s", GlobalName);
+ }
+
+ // Save the pointer to the symbol.
+ DeviceGlobal.setPtr(EntryAddress);
+
+ return Plugin::success();
+ }
+};
+
+struct MPIKernelTy : public GenericKernelTy {
+ /// Construct the kernel with a name and an execution mode.
+ MPIKernelTy(const char *Name) : GenericKernelTy(Name), Func(nullptr) {}
+
+ /// Initialize the kernel.
+ Error initImpl(GenericDeviceTy &Device, DeviceImageTy &Image) override {
+ // Functions have zero size.
+ GlobalTy Global(getName(), 0);
+
+ // Get the metadata (address) of the kernel function.
+ GenericGlobalHandlerTy &GHandler = Device.Plugin.getGlobalHandler();
+ if (auto Err = GHandler.getGlobalMetadataFromDevice(Device, Image, Global))
+ return Err;
+
+ // Check that the function pointer is valid.
+ if (!Global.getPtr())
+ return Plugin::error("Invalid function for kernel %s", getName());
+
+ // Save the function pointer.
+ Func = (void (*)())Global.getPtr();
+
+ // TODO: Check which settings are appropriate for the mpi plugin
+ // for now we are using the Elf64 plugin configuration
+ KernelEnvironment.Configuration.ExecMode = OMP_TGT_EXEC_MODE_GENERIC;
+ KernelEnvironment.Configuration.MayUseNestedParallelism = /* Unknown */ 2;
+ KernelEnvironment.Configuration.UseGenericStateMachine = /* Unknown */ 2;
+
+ // Set the maximum number of threads to a single.
+ MaxNumThreads = 1;
+ return Plugin::success();
+ }
+
+ /// Launch the kernel.
+ Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
+ uint64_t NumBlocks, KernelArgsTy &KernelArgs,
+ KernelLaunchParamsTy LaunchParams,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) const override;
+
+private:
+ /// The kernel function to execute.
+ void (*Func)(void);
+};
+
+/// MPI resource reference and queue. These are the objects handled by the
+/// MPIQueue Manager for the MPI plugin.
+template <typename ResourceTy>
+struct MPIResourceRef final : public GenericDeviceResourceRef {
+
+ /// The underlying handler type for the resource.
+ using HandleTy = ResourceTy *;
+
+ /// Create a empty reference to an invalid resource.
+ MPIResourceRef() : Resource(nullptr) {}
+
+ /// Create a reference to an existing resource.
+ MPIResourceRef(HandleTy Queue) : Resource(Queue) {}
+
+ /// Create a new resource and save the reference.
+ Error create(GenericDeviceTy &Device) override {
+ if (Resource)
+ return Plugin::error("Recreating an existing resource");
+
+ Resource = new ResourceTy;
+ if (!Resource)
+ return Plugin::error("Failed to allocated a new resource");
+
+ return Plugin::success();
+ }
+
+ /// Destroy the resource and invalidate the reference.
+ Error destroy(GenericDeviceTy &Device) override {
+ if (!Resource)
+ return Plugin::error("Destroying an invalid resource");
+
+ delete Resource;
+ Resource = nullptr;
+
+ return Plugin::success();
+ }
+
+ operator HandleTy() const { return Resource; }
+
+private:
+ HandleTy Resource;
+};
+
+/// Class implementing the device functionalities for remote x86_64 processes.
+struct MPIDeviceTy : public GenericDeviceTy {
+ /// Create a MPI Device with a device id and the default MPI grid values.
+ MPIDeviceTy(GenericPluginTy &Plugin, int32_t DeviceId, int32_t NumDevices)
+ : GenericDeviceTy(Plugin, DeviceId, NumDevices, MPIGridValues),
+ MPIEventQueueManager(*this), MPIEventManager(*this) {}
+
+ /// Initialize the device, its resources and get its properties.
+ Error initImpl(GenericPluginTy &Plugin) override {
+ if (auto Err = MPIEventQueueManager.init(OMPX_InitialNumStreams))
+ return Err;
+
+ if (auto Err = MPIEventManager.init(OMPX_InitialNumEvents))
+ return Err;
+
+ return Plugin::success();
+ }
+
+ /// Deinitizalize the device and release its resources.
+ Error deinitImpl() override {
+ if (auto Err = MPIEventQueueManager.deinit())
+ return Err;
+
+ if (auto Err = MPIEventManager.deinit())
+ return Err;
+
+ return Plugin::success();
+ }
+
+ Error setContext() override { return Plugin::success(); }
+
+ /// Load the binary image into the device and allocate an image object.
+ Expected<DeviceImageTy *> loadBinaryImpl(const __tgt_device_image *TgtImage,
+ int32_t ImageId) override {
+ // Allocate and initialize the image object.
+ MPIDeviceImageTy *Image = Plugin.allocate<MPIDeviceImageTy>();
+ new (Image) MPIDeviceImageTy(ImageId, *this, TgtImage);
+ return Image;
+ }
+
+ /// Allocate memory on the device or related to the device.
+ void *allocate(size_t Size, void *, TargetAllocTy Kind) override {
+ return nullptr;
+ }
+
+ /// Deallocate memory on the device or related to the device.
+ int free(void *TgtPtr, TargetAllocTy Kind) override {
+ return OFFLOAD_SUCCESS;
+ }
+
+ /// Submit data to the device (host to device transfer).
+ Error dataSubmitImpl(void *TgtPtr, const void *HstPtr, int64_t Size,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+ return Plugin::success();
+ }
+
+ /// Retrieve data from the device (device to host transfer).
+ Error dataRetrieveImpl(void *HstPtr, const void *TgtPtr, int64_t Size,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+ return Plugin::success();
+ }
+
+ /// Exchange data between two devices directly. In the MPI plugin, this
+ /// function will create an event for the host to tell the devices about the
+ /// exchange. Then, the devices will do the transfer themselves and let the
+ /// host know when it's done.
+ Error dataExchangeImpl(const void *SrcPtr, GenericDeviceTy &DstDev,
+ void *DstPtr, int64_t Size,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+ return Plugin::success();
+ }
+
+ /// Allocate and construct a MPI kernel.
+ Expected<GenericKernelTy &> constructKernel(const char *Name) override {
+ // Allocate and construct the kernel.
+ MPIKernelTy *MPIKernel = Plugin.allocate<MPIKernelTy>();
+
+ if (!MPIKernel)
+ return Plugin::error("Failed to allocate memory for MPI kernel");
+
+ new (MPIKernel) MPIKernelTy(Name);
+
+ return *MPIKernel;
+ }
+
+ /// Create an event.
+ Error createEventImpl(void **EventStoragePtr) override {
+ return Plugin::success();
+ }
+
+ /// Destroy a previously created event.
+ Error destroyEventImpl(void *Event) override {
+ return MPIEventManager.returnResource(reinterpret_cast<EventTy *>(Event));
+ }
+
+ /// Record the event.
+ Error recordEventImpl(void *Event,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+ return Plugin::success();
+ }
+
+ /// Make the queue wait on the event.
+ Error waitEventImpl(void *Event,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+ return Plugin::success();
+ }
+
+ /// Synchronize the current thread with the event
+ Error syncEventImpl(void *Event) override { return Plugin::success(); }
+
+ /// Synchronize current thread with the pending operations on the async info.
+ Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
+ return Plugin::success();
+ }
+
+ /// Query for the completion of the pending operations on the async info.
+ Error queryAsyncImpl(__tgt_async_info &AsyncInfo) override {
+ return Plugin::success();
+ }
+
+ Expected<void *> dataLockImpl(void *HstPtr, int64_t Size) override {
+ return HstPtr;
+ }
+
+ /// Indicate that the buffer is not pinned.
+ Expected<bool> isPinnedPtrImpl(void *HstPtr, void *&BaseHstPtr,
+ void *&BaseDevAccessiblePtr,
+ size_t &BaseSize) const override {
+ return false;
+ }
+
+ Error dataUnlockImpl(void *HstPtr) override { return Plugin::success(); }
+
+ /// This plugin should not setup the device environment or memory pool.
+ virtual bool shouldSetupDeviceEnvironment() const override { return false; };
+ virtual bool shouldSetupDeviceMemoryPool() const override { return false; };
+
+ /// Device memory limits are currently not applicable to the MPI plugin.
+ Error getDeviceStackSize(uint64_t &Value) override {
+ Value = 0;
+ return Plugin::success();
+ }
+
+ Error setDeviceStackSize(uint64_t Value) override {
+ return Plugin::success();
+ }
+
+ Error getDeviceHeapSize(uint64_t &Value) override {
+ Value = 0;
+ return Plugin::success();
+ }
+
+ Error setDeviceHeapSize(uint64_t Value) override { return Plugin::success(); }
+
+ /// Device interoperability. Not supported by MPI right now.
+ Error initAsyncInfoImpl(AsyncInfoWrapperTy &AsyncInfoWrapper) override {
+ return Plugin::error("initAsyncInfoImpl not supported");
+ }
+
+ /// This plugin does not support interoperability.
+ Error initDeviceInfoImpl(__tgt_device_info *DeviceInfo) override {
+ return Plugin::error("initDeviceInfoImpl not supported");
+ }
+
+ /// Print information about the device.
+ Error obtainInfoImpl(InfoQueueTy &Info) override {
+ // TODO: Add more information about the device.
+ Info.add("MPI plugin");
+ Info.add("MPI OpenMP Device Number", DeviceId);
+
+ return Plugin::success();
+ }
+
+ Error getQueue(AsyncInfoWrapperTy &AsyncInfoWrapper,
+ MPIEventQueuePtr &Queue) {
+ return Plugin::success();
+ }
+
+private:
+ using MPIEventQueueManagerTy =
+ GenericDeviceResourceManagerTy<MPIResourceRef<MPIEventQueue>>;
+ using MPIEventManagerTy =
+ GenericDeviceResourceManagerTy<MPIResourceRef<EventTy>>;
+
+ MPIEventQueueManagerTy MPIEventQueueManager;
+ MPIEventManagerTy MPIEventManager;
+
+ /// Grid values for the MPI plugin.
+ static constexpr GV MPIGridValues = {
+ 1, // GV_Slot_Size
+ 1, // GV_Warp_Size
+ 1, // GV_Max_Teams
+ 1, // GV_Default_Num_Teams
+ 1, // GV_SimpleBufferSize
+ 1, // GV_Max_WG_Size
+ 1, // GV_Default_WG_Size
+ };
+};
+
+Error MPIKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
+ uint32_t NumThreads, uint64_t NumBlocks,
+ KernelArgsTy &KernelArgs,
+ KernelLaunchParamsTy LaunchParams,
+ AsyncInfoWrapperTy &AsyncInfoWrapper) const {
+ return Plugin::success();
+}
+
+/// Class implementing the MPI plugin.
+struct MPIPluginTy : public GenericPluginTy {
+ MPIPluginTy() : GenericPluginTy(getTripleArch()) {}
+
+ /// This class should not be copied.
+ MPIPluginTy(const MPIPluginTy &) = delete;
+ MPIPluginTy(MPIPluginTy &&) = delete;
+
+ /// Initialize the plugin and return the number of devices.
+ Expected<int32_t> initImpl() override {
+ if (!EventSystem.is_initialized())
+ EventSystem.initialize();
+ int32_t NumRemoteDevices = getNumRemoteDevices();
+ assert(RemoteDevices.size() == 0 && "MPI Plugin already initialized");
+ RemoteDevices.resize(NumRemoteDevices, nullptr);
+ return NumRemoteDevices;
+ }
+
+ /// Deinitialize the plugin.
+ Error deinitImpl() override {
+ EventSystem.deinitialize();
+ return Plugin::success();
+ }
+
+ /// Creates a MPI device.
+ GenericDeviceTy *createDevice(GenericPluginTy &Plugin, int32_t DeviceId,
+ int32_t NumDevices) override {
+ return new MPIDeviceTy(Plugin, DeviceId, NumDevices);
+ }
+
+ /// Creates a MPI global handler.
+ GenericGlobalHandlerTy *createGlobalHandler() override {
+ return new MPIGlobalHandlerTy();
+ }
+
+ /// Get the ELF code to recognize the compatible binary images.
+ uint16_t getMagicElfBits() const override {
+ return utils::elf::getTargetMachine();
+ }
+
+ /// All images (ELF-compatible) should be compatible with this plugin.
+ Expected<bool> isELFCompatible(uint32_t DeviceID,
+ StringRef Image) const override {
+ return true;
+ }
+
+ Triple::ArchType getTripleArch() const override {
+#if defined(__x86_64__)
+ return llvm::Triple::x86_64;
+#elif defined(__s390x__)
+ return llvm::Triple::systemz;
+#elif defined(__aarch64__)
+#ifdef LITTLEENDIAN_CPU
+ return llvm::Triple::aarch64;
+#else
+ return llvm::Triple::aarch64_be;
+#endif
+#elif defined(__powerpc64__)
+#ifdef LITTLEENDIAN_CPU
+ return llvm::Triple::ppc64le;
+#else
+ return llvm::Triple::ppc64;
+#endif
+#else
+ return llvm::Triple::UnknownArch;
+#endif
+ }
+
+ Error getQueue(__tgt_async_info *AsyncInfoPtr, MPIEventQueuePtr &Queue) {
+ const std::lock_guard<std::mutex> Lock(MPIQueueMutex);
+ Queue = static_cast<MPIEventQueuePtr>(AsyncInfoPtr->Queue);
+ if (!Queue) {
+ Queue = new MPIEventQueue;
+ if (Queue == nullptr)
+ return Plugin::error("Failed to get Queue from AsyncInfoPtr %p\n",
+ AsyncInfoPtr);
+ // Modify the AsyncInfoWrapper to hold the new queue.
+ AsyncInfoPtr->Queue = Queue;
+ }
+ return Plugin::success();
+ }
+
+ Error returnQueue(MPIEventQueuePtr &Queue) {
+ const std::lock_guard<std::mutex> Lock(MPIQueueMutex);
+ if (Queue == nullptr)
+ return Plugin::error("Failed to return Queue: invalid Queue ptr");
+
+ delete Queue;
+
+ return Plugin::success();
+ }
+
+ const char *getName() const override { return GETNAME(TARGET_NAME); }
+
+ /// This plugin does not support exchanging data between two devices.
+ bool isDataExchangable(int32_t SrcDeviceId, int32_t DstDeviceId) override {
+ bool QueryResult = false;
+
+ int32_t SrcRank = -1, SrcDevId, DstRank = -1, DstDevId;
+
+ std::tie(SrcRank, SrcDevId) = EventSystem.mapDeviceId(SrcDeviceId);
+ std::tie(DstRank, DstDevId) = EventSystem.mapDeviceId(DstDeviceId);
+
+ // If the exchange is between different mpi processes, it is possible to
+ // perform the operation without consulting the devices
+ if ((SrcRank != -1) && (DstRank != -1) && (SrcRank != DstRank))
+ return true;
+
+ EventTy Event = EventSystem.createEvent(
+ OriginEvents::isDataExchangable, EventTypeTy::IS_DATA_EXCHANGABLE,
+ SrcDeviceId, DstDeviceId, &QueryResult);
+
+ if (Event.empty()) {
+ DP("Failed to create isDataExchangeble event in %d SrcDevice\n",
+ SrcDeviceId);
+ return false;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ DP("Failed to query isDataExchangeble from device %d SrcDevice: %s\n",
+ SrcDeviceId, toString(std::move(Error)).c_str());
+ return false;
+ }
+
+ return QueryResult;
+ }
+
+ /// Get the number of devices considering all devices per rank
+ int32_t getNumRemoteDevices() {
+ int32_t NumRemoteDevices = 0;
+ int32_t NumRanks = EventSystem.getNumWorkers();
+
+ for (int32_t RemoteRank = 0; RemoteRank < NumRanks; RemoteRank++) {
+ auto Event = EventSystem.createEvent(
+ OriginEvents::retrieveNumDevices, EventTypeTy::RETRIEVE_NUM_DEVICES,
+ RemoteRank, &EventSystem.DevicesPerRemote.emplace_back(0));
+
+ if (Event.empty()) {
+ DP("Error retrieving Num Devices from rank %d\n", RemoteRank);
+ return 0;
+ }
+
+ Event.wait();
+ if (auto Err = Event.getError())
+ DP("Error retrieving Num Devices from rank %d: %s\n", RemoteRank,
+ toString(std::move(Err)).c_str());
+
+ NumRemoteDevices += EventSystem.DevicesPerRemote[RemoteRank];
+ }
+
+ return NumRemoteDevices;
+ }
+
+ int32_t is_plugin_compatible(__tgt_device_image *Image) override {
+ if (!EventSystem.is_initialized())
+ EventSystem.initialize();
+
+ int NumRanks = EventSystem.getNumWorkers();
+ llvm::SmallVector<bool> QueryResults{};
+ bool QueryResult = true;
+ for (int RemoteRank = 0; RemoteRank < NumRanks; RemoteRank++) {
+ EventTy Event = EventSystem.createEvent(
+ OriginEvents::isPluginCompatible, EventTypeTy::IS_PLUGIN_COMPATIBLE,
+ RemoteRank, Image, &QueryResults.emplace_back(false));
+
+ if (Event.empty()) {
+ DP("Failed to create isPluginCompatible on Rank %d\n", RemoteRank);
+ QueryResults[RemoteRank] = false;
+ }
+
+ Event.wait();
+ if (auto Err = Event.getError()) {
+ DP("Error querying the binary compability on Rank %d\n", RemoteRank);
+ QueryResults[RemoteRank] = false;
+ }
+
+ QueryResult &= QueryResults[RemoteRank];
+ }
+
+ return QueryResult;
+ }
+
+ int32_t is_device_compatible(int32_t DeviceId,
+ __tgt_device_image *Image) override {
+ bool QueryResult = true;
+
+ EventTy Event = EventSystem.createEvent(OriginEvents::isDeviceCompatible,
+ EventTypeTy::IS_DEVICE_COMPATIBLE,
+ DeviceId, Image, &QueryResult);
+
+ if (Event.empty()) {
+ DP("Failed to create isDeviceCompatible on Device %d\n", DeviceId);
+ }
+
+ Event.wait();
+ if (auto Err = Event.getError()) {
+ DP("Error querying the binary compability on Device %d\n", DeviceId);
+ }
+
+ return QueryResult;
+ }
+
+ int32_t is_device_initialized(int32_t DeviceId) const override {
+ return isValidDeviceId(DeviceId) && RemoteDevices[DeviceId] != nullptr;
+ }
+
+ int32_t init_device(int32_t DeviceId) override {
+ void *DevicePtr = nullptr;
+
+ EventTy Event =
+ EventSystem.createEvent(OriginEvents::initDevice,
+ EventTypeTy::INIT_DEVICE, DeviceId, &DevicePtr);
+
+ if (Event.empty()) {
+ REPORT("Error to create InitDevice Event for device %d\n", DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ REPORT("Failure to initialize device %d: %s\n", DeviceId,
+ toString(std::move(Error)).data());
+ return 0;
+ }
+
+ RemoteDevices[DeviceId] = DevicePtr;
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t initialize_record_replay(int32_t DeviceId, int64_t MemorySize,
+ void *VAddr, bool isRecord, bool SaveOutput,
+ uint64_t &ReqPtrArgOffset) override {
+ EventTy Event = EventSystem.createEvent(
+ OriginEvents::initRecordReplay, EventTypeTy::INIT_RECORD_REPLAY,
+ DeviceId, MemorySize, VAddr, isRecord, SaveOutput, &ReqPtrArgOffset);
+
+ if (Event.empty()) {
+ REPORT("Error to create initRecordReplay Event for device %d\n",
+ DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ REPORT("WARNING RR did not intialize RR-properly with %lu bytes"
+ "(Error: %s)\n",
+ MemorySize, toString(std::move(Error)).data());
+ if (!isRecord) {
+ return OFFLOAD_FAIL;
+ }
+ }
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t load_binary(int32_t DeviceId, __tgt_device_image *TgtImage,
+ __tgt_device_binary *Binary) override {
+ EventTy Event = EventSystem.createEvent(OriginEvents::loadBinary,
+ EventTypeTy::LOAD_BINARY, DeviceId,
+ TgtImage, Binary);
+
+ if (Event.empty()) {
+ REPORT("Failed to create loadBinary event for image %p", TgtImage);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError(); Error) {
+ REPORT("Event failed during loadBinary. %s\n",
+ toString(std::move(Error)).c_str());
+ return OFFLOAD_FAIL;
+ }
+
+ DeviceImgPtrToDeviceId[Binary->handle] = DeviceId;
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ void *data_alloc(int32_t DeviceId, int64_t Size, void *HostPtr,
+ int32_t Kind) override {
+ if (Size == 0)
+ return nullptr;
+
+ void *TgtPtr = nullptr;
+ std::optional<Error> Err = std::nullopt;
+ EventTy Event;
+
+ switch (Kind) {
+ case TARGET_ALLOC_DEFAULT:
+ case TARGET_ALLOC_DEVICE:
+ case TARGET_ALLOC_DEVICE_NON_BLOCKING:
+ Event = EventSystem.createEvent(OriginEvents::allocateBuffer,
+ EventTypeTy::ALLOC, DeviceId, Size, Kind,
+ &TgtPtr);
+
+ if (Event.empty()) {
+ Err = Plugin::error("Failed to create alloc event with size %z", Size);
+ break;
+ }
+
+ Event.wait();
+ Err = Event.getError();
+ break;
+ case TARGET_ALLOC_HOST:
+ TgtPtr = memAllocHost(Size);
+ Err = Plugin::check(TgtPtr == nullptr, "Failed to allocate host memory");
+ break;
+ case TARGET_ALLOC_SHARED:
+ Err = Plugin::error("Incompatible memory type %d", Kind);
+ break;
+ }
+
+ if (*Err) {
+ REPORT("Failed to allocate data for HostPtr %p: %s\n", HostPtr,
+ toString(std::move(*Err)).c_str());
+ return nullptr;
+ }
+
+ return TgtPtr;
+ }
+
+ int32_t data_delete(int32_t DeviceId, void *TgtPtr, int32_t Kind) override {
+ if (TgtPtr == nullptr)
+ return OFFLOAD_SUCCESS;
+
+ std::optional<Error> Err = std::nullopt;
+ EventTy Event;
+
+ switch (Kind) {
+ case TARGET_ALLOC_DEFAULT:
+ case TARGET_ALLOC_DEVICE:
+ case TARGET_ALLOC_DEVICE_NON_BLOCKING:
+ Event =
+ EventSystem.createEvent(OriginEvents::deleteBuffer,
+ EventTypeTy::DELETE, DeviceId, TgtPtr, Kind);
+
+ if (Event.empty()) {
+ Err = Plugin::error("Failed to create data delete event for %p TgtPtr",
+ TgtPtr);
+ break;
+ }
+
+ Event.wait();
+ Err = Event.getError();
+ break;
+ case TARGET_ALLOC_HOST:
+ Err = Plugin::check(memFreeHost(TgtPtr), "Failed to free host memory");
+ break;
+ case TARGET_ALLOC_SHARED:
+ Err = createStringError(inconvertibleErrorCode(),
+ "Incompatible memory type %d", Kind);
+ break;
+ }
+
+ if (*Err) {
+ REPORT("Failed delete data at %p TgtPtr: %s\n", TgtPtr,
+ toString(std::move(*Err)).c_str());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t data_lock(int32_t DeviceId, void *Ptr, int64_t Size,
+ void **LockedPtr) override {
+ EventTy Event =
+ EventSystem.createEvent(OriginEvents::dataLock, EventTypeTy::DATA_LOCK,
+ DeviceId, Ptr, Size, LockedPtr);
+
+ if (Event.empty()) {
+ REPORT("Failed to create data lock event on device %d\n", DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ REPORT("Failure to lock memory %p: %s\n", Ptr,
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ if (!(*LockedPtr)) {
+ REPORT("Failure to lock memory %p: obtained a null locked pointer\n",
+ Ptr);
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t data_unlock(int32_t DeviceId, void *Ptr) override {
+ EventTy Event = EventSystem.createEvent(
+ OriginEvents::dataUnlock, EventTypeTy::DATA_UNLOCK, DeviceId, Ptr);
+
+ if (Event.empty()) {
+ REPORT("Failed to create data unlock event on device %d\n", DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ REPORT("Failure to unlock memory %p: %s\n", Ptr,
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t data_notify_mapped(int32_t DeviceId, void *HstPtr,
+ int64_t Size) override {
+ EventTy Event = EventSystem.createEvent(OriginEvents::dataNotifyMapped,
+ EventTypeTy::DATA_NOTIFY_MAPPED,
+ DeviceId, HstPtr, Size);
+
+ if (Event.empty()) {
+ REPORT("Failed to create data notify mapped event on device %d\n",
+ DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ REPORT("Failure to notify data mapped %p: %s\n", HstPtr,
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t data_notify_unmapped(int32_t DeviceId, void *HstPtr) override {
+ EventTy Event = EventSystem.createEvent(OriginEvents::dataNotifyUnmapped,
+ EventTypeTy::DATA_NOTIFY_UNMAPPED,
+ DeviceId, HstPtr);
+
+ if (Event.empty()) {
+ REPORT("Failed to create data notify unmapped event on device %d\n",
+ DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ REPORT("Failure to notify data unmapped %p: %s\n", HstPtr,
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t data_submit_async(int32_t DeviceId, void *TgtPtr, void *HstPtr,
+ int64_t Size,
+ __tgt_async_info *AsyncInfoPtr) override {
+ MPIEventQueuePtr Queue = nullptr;
+ if (auto Error = getQueue(AsyncInfoPtr, Queue)) {
+ REPORT("Failed to get async Queue: %s\n",
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ // Copy HstData to a buffer with event-managed lifetime.
+ memAllocHost(Size);
+ void *SubmitBuffer = memAllocHost(Size);
+ std::memcpy(SubmitBuffer, HstPtr, Size);
+ EventDataHandleTy DataHandle(SubmitBuffer, &memFreeHost);
+
+ EventTy Event = EventSystem.createEvent(
+ OriginEvents::submit, EventTypeTy::SUBMIT, DeviceId, TgtPtr, DataHandle,
+ Size, AsyncInfoPtr);
+
+ if (Event.empty()) {
+ REPORT("Failed to create dataSubmit event from %p HstPtr to %p TgtPtr\n",
+ HstPtr, TgtPtr);
+ return OFFLOAD_FAIL;
+ }
+
+ Queue->push_back(Event);
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t data_retrieve_async(int32_t DeviceId, void *HstPtr, void *TgtPtr,
+ int64_t Size,
+ __tgt_async_info *AsyncInfoPtr) override {
+ MPIEventQueuePtr Queue = nullptr;
+ if (auto Error = getQueue(AsyncInfoPtr, Queue)) {
+ REPORT("Failed to get async Queue: %s\n",
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ EventTy Event =
+ EventSystem.createEvent(OriginEvents::retrieve, EventTypeTy::RETRIEVE,
+ DeviceId, Size, HstPtr, TgtPtr, AsyncInfoPtr);
+
+ if (Event.empty()) {
+ REPORT(
+ "Failed to create dataRetrieve event from %p TgtPtr to %p HstPtr\n",
+ TgtPtr, HstPtr);
+ return OFFLOAD_FAIL;
+ }
+
+ Queue->push_back(Event);
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t data_exchange_async(int32_t SrcDeviceId, void *SrcPtr,
+ int DstDeviceId, void *DstPtr, int64_t Size,
+ __tgt_async_info *AsyncInfo) override {
+ MPIEventQueuePtr Queue = nullptr;
+ if (auto Error = getQueue(AsyncInfo, Queue)) {
+ REPORT("Failed to get async Queue: %s\n",
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ int32_t SrcRank, SrcDevId, DstRank, DstDevId;
+ EventTy Event;
+
+ std::tie(SrcRank, SrcDevId) = EventSystem.mapDeviceId(SrcDeviceId);
+ std::tie(DstRank, DstDevId) = EventSystem.mapDeviceId(DstDeviceId);
+
+ if (SrcRank == DstRank) {
+ Event = EventSystem.createEvent(
+ OriginEvents::localExchange, EventTypeTy::LOCAL_EXCHANGE, SrcDeviceId,
+ SrcPtr, DstDeviceId, DstPtr, Size, AsyncInfo);
+ }
+
+ else {
+ Event = EventSystem.createExchangeEvent(SrcDeviceId, SrcPtr, DstDeviceId,
+ DstPtr, Size, AsyncInfo);
+ }
+
+ if (Event.empty()) {
+ REPORT("Failed to create data exchange event from %d SrcDeviceId to %d "
+ "DstDeviceId\n",
+ SrcDeviceId, DstDeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Queue->push_back(Event);
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t launch_kernel(int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs,
+ ptrdiff_t *TgtOffsets, KernelArgsTy *KernelArgs,
+ __tgt_async_info *AsyncInfoPtr) override {
+ MPIEventQueuePtr Queue = nullptr;
+ if (auto Error = getQueue(AsyncInfoPtr, Queue)) {
+ REPORT("Failed to get async Queue: %s\n",
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ uint32_t NumArgs = KernelArgs->NumArgs;
+
+ void *Args = memAllocHost(sizeof(void *) * NumArgs);
+ std::memcpy(Args, TgtArgs, sizeof(void *) * NumArgs);
+ EventDataHandleTy ArgsHandle(Args, &memFreeHost);
+
+ void *Offsets = memAllocHost(sizeof(ptrdiff_t) * NumArgs);
+ std::memcpy(Offsets, TgtOffsets, sizeof(ptrdiff_t) * NumArgs);
+ EventDataHandleTy OffsetsHandle(Offsets, &memFreeHost);
+
+ void *KernelArgsPtr = memAllocHost(sizeof(KernelArgsTy));
+ std::memcpy(KernelArgsPtr, KernelArgs, sizeof(KernelArgsTy));
+ EventDataHandleTy KernelArgsHandle(KernelArgsPtr, &memFreeHost);
+
+ EventTy Event = EventSystem.createEvent(
+ OriginEvents::launchKernel, EventTypeTy::LAUNCH_KERNEL, DeviceId,
+ TgtEntryPtr, ArgsHandle, OffsetsHandle, KernelArgsHandle, AsyncInfoPtr);
+
+ if (Event.empty()) {
+ REPORT("Failed to create launchKernel event on device %d\n", DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Queue->push_back(Event);
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t synchronize(int32_t DeviceId,
+ __tgt_async_info *AsyncInfoPtr) override {
+ MPIEventQueuePtr Queue =
+ reinterpret_cast<MPIEventQueuePtr>(AsyncInfoPtr->Queue);
+
+ EventTy Event = EventSystem.createEvent(OriginEvents::synchronize,
+ EventTypeTy::SYNCHRONIZE, DeviceId,
+ AsyncInfoPtr);
+
+ if (Event.empty()) {
+ REPORT("Failed to create synchronize event on device %d\n", DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Queue->push_back(Event);
+
+ for (auto &Event : *Queue) {
+ Event.wait();
+
+ if (auto Error = Event.getError(); Error) {
+ REPORT("Event failed during synchronization. %s\n",
+ toString(std::move(Error)).c_str());
+ return OFFLOAD_FAIL;
+ }
+ }
+
+ // Once the queue is synchronized, return it to the pool and reset the
+ // AsyncInfo. This is to make sure that the synchronization only works
+ // for its own tasks.
+ AsyncInfoPtr->Queue = nullptr;
+ if (auto Error = returnQueue(Queue)) {
+ REPORT("Failed to return async Queue: %s\n",
+ toString(std::move(Error)).c_str());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t query_async(int32_t DeviceId,
+ __tgt_async_info *AsyncInfoPtr) override {
+ auto *Queue = reinterpret_cast<MPIEventQueue *>(AsyncInfoPtr->Queue);
+
+ // Returns success when there are pending operations in AsyncInfo, moving
+ // forward through the events on the queue until it is fully completed.
+ while (!Queue->empty()) {
+ auto &Event = Queue->front();
+
+ Event.resume();
+
+ if (!Event.done())
+ return OFFLOAD_SUCCESS;
+
+ if (auto Error = Event.getError(); Error) {
+ REPORT("Event failed during query. %s\n",
+ toString(std::move(Error)).c_str());
+ return OFFLOAD_FAIL;
+ }
+ Queue->pop_front();
+ }
+
+ // Once the queue is synchronized, return it to the pool and reset the
+ // AsyncInfo. This is to make sure that the synchronization only works
+ // for its own tasks.
+ AsyncInfoPtr->Queue = nullptr;
+ if (auto Error = returnQueue(Queue)) {
+ REPORT("Failed to return async Queue: %s\n",
+ toString(std::move(Error)).c_str());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ void print_device_info(int32_t DeviceId) override {
+ EventTy Event =
+ EventSystem.createEvent(OriginEvents::printDeviceInfo,
+ EventTypeTy::PRINT_DEVICE_INFO, DeviceId);
+
+ if (Event.empty()) {
+ REPORT("Failed to create printDeviceInfo event on device %d\n", DeviceId);
+ return;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ REPORT("Failure to print device %d info: %s\n", DeviceId,
+ toString(std::move(Error)).data());
+ }
+ }
+
+ int32_t create_event(int32_t DeviceId, void **EventPtr) override {
+ if (!EventPtr) {
+ REPORT("Failure to record event: Received invalid event pointer\n");
+ return OFFLOAD_FAIL;
+ }
+
+ EventTy *NewEvent = new EventTy;
+
+ if (NewEvent == nullptr) {
+ REPORT("Failed to createEvent\n");
+ return OFFLOAD_FAIL;
+ }
+
+ *EventPtr = reinterpret_cast<void *>(NewEvent);
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t record_event(int32_t DeviceId, void *EventPtr,
+ __tgt_async_info *AsyncInfoPtr) override {
+ if (!EventPtr) {
+ REPORT("Failure to record event: Received invalid event pointer\n");
+ return OFFLOAD_FAIL;
+ }
+
+ MPIEventQueuePtr Queue = nullptr;
+ if (auto Error = getQueue(AsyncInfoPtr, Queue)) {
+ REPORT("Failed to get async Queue: %s\n",
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ if (Queue->empty())
+ return OFFLOAD_SUCCESS;
+
+ auto &RecordedEvent = *reinterpret_cast<EventTy *>(EventPtr);
+ RecordedEvent = Queue->back();
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t wait_event(int32_t DeviceId, void *EventPtr,
+ __tgt_async_info *AsyncInfoPtr) override {
+ if (!EventPtr) {
+ REPORT("Failure to wait event: Received invalid event pointer\n");
+ return OFFLOAD_FAIL;
+ }
+
+ auto &RecordedEvent = *reinterpret_cast<EventTy *>(EventPtr);
+ auto SyncEvent = OriginEvents::sync(RecordedEvent);
+
+ MPIEventQueuePtr Queue = nullptr;
+ if (auto Error = getQueue(AsyncInfoPtr, Queue)) {
+ REPORT("Failed to get async Queue: %s\n",
+ toString(std::move(Error)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ Queue->push_back(SyncEvent);
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t sync_event(int32_t DeviceId, void *EventPtr) override {
+ if (!EventPtr) {
+ REPORT("Failure to wait event: Received invalid event pointer\n");
+ return OFFLOAD_FAIL;
+ }
+
+ auto &RecordedEvent = *reinterpret_cast<EventTy *>(EventPtr);
+ auto SyncEvent = OriginEvents::sync(RecordedEvent);
+
+ SyncEvent.wait();
+
+ if (auto Err = SyncEvent.getError()) {
+ REPORT("Failure to synchronize event %p: %s\n", EventPtr,
+ toString(std::move(Err)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t destroy_event(int32_t DeviceId, void *EventPtr) override {
+
+ if (!EventPtr) {
+ REPORT("Failure to destroy event: Received invalid event pointer\n");
+ return OFFLOAD_FAIL;
+ }
+
+ EventTy *MPIEventPtr = reinterpret_cast<EventTy *>(EventPtr);
+
+ delete MPIEventPtr;
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t init_async_info(int32_t DeviceId,
+ __tgt_async_info **AsyncInfoPtr) override {
+ assert(AsyncInfoPtr && "Invalid async info");
+
+ EventTy Event = EventSystem.createEvent(OriginEvents::initAsyncInfo,
+ EventTypeTy::INIT_ASYNC_INFO,
+ DeviceId, AsyncInfoPtr);
+
+ if (Event.empty()) {
+ REPORT("Failed to create initAsyncInfo on device %d\n", DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Err = Event.getError()) {
+ REPORT("Failure to initialize async info at " DPxMOD
+ " on device %d: %s\n",
+ DPxPTR(*AsyncInfoPtr), DeviceId, toString(std::move(Err)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t init_device_info(int32_t DeviceId, __tgt_device_info *DeviceInfo,
+ const char **ErrStr) override {
+ *ErrStr = "";
+
+ EventTy Event = EventSystem.createEvent(OriginEvents::initDeviceInfo,
+ EventTypeTy::INIT_DEVICE_INFO,
+ DeviceId, DeviceInfo);
+
+ if (Event.empty()) {
+ REPORT("Failed to create initDeviceInfo on device %d\n", DeviceId);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Err = Event.getError()) {
+ REPORT("Failure to initialize device info at " DPxMOD
+ " on device %d: %s\n",
+ DPxPTR(DeviceInfo), DeviceId, toString(std::move(Err)).data());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t use_auto_zero_copy(int32_t DeviceId) override { return false; }
+
+ int32_t get_global(__tgt_device_binary Binary, uint64_t Size,
+ const char *Name, void **DevicePtr) override {
+ int32_t DeviceId = DeviceImgPtrToDeviceId[Binary.handle];
+
+ EventTy Event = EventSystem.createEvent(OriginEvents::getGlobal,
+ EventTypeTy::GET_GLOBAL, DeviceId,
+ Binary, Size, Name, DevicePtr);
+ if (Event.empty()) {
+ REPORT("Failed to create getGlobal event on device %d\n", 0);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ REPORT("Failed to get Global on device %d: %s\n", 0,
+ toString(std::move(Error)).c_str());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+ int32_t get_function(__tgt_device_binary Binary, const char *Name,
+ void **KernelPtr) override {
+
+ int32_t DeviceId = DeviceImgPtrToDeviceId[Binary.handle];
+
+ EventTy Event = EventSystem.createEvent(OriginEvents::getFunction,
+ EventTypeTy::GET_FUNCTION, DeviceId,
+ Binary, Name, KernelPtr);
+ if (Event.empty()) {
+ REPORT("Failed to create getFunction event on device %d\n", 0);
+ return OFFLOAD_FAIL;
+ }
+
+ Event.wait();
+
+ if (auto Error = Event.getError()) {
+ REPORT("Failed to get function on device %d: %s\n", 0,
+ toString(std::move(Error)).c_str());
+ return OFFLOAD_FAIL;
+ }
+
+ return OFFLOAD_SUCCESS;
+ }
+
+private:
+ std::mutex MPIQueueMutex;
+ llvm::DenseMap<uintptr_t, int32_t> DeviceImgPtrToDeviceId;
+ llvm::SmallVector<void *> RemoteDevices;
+ EventSystemTy EventSystem;
+};
+
+template <typename... ArgsTy>
+static Error Plugin::check(int32_t ErrorCode, const char *ErrFmt,
+ ArgsTy... Args) {
+ if (ErrorCode == OFFLOAD_SUCCESS)
+ return Error::success();
+
+ return createStringError<ArgsTy..., const char *>(
+ inconvertibleErrorCode(), ErrFmt, Args...,
+ std::to_string(ErrorCode).data());
+}
+
+} // namespace llvm::omp::target::plugin
+
+extern "C" {
+llvm::omp::target::plugin::GenericPluginTy *createPlugin_mpi() {
+ return new llvm::omp::target::plugin::MPIPluginTy();
+}
+}
----------------
shiltian wrote:
empty line at the end of file
https://github.com/llvm/llvm-project/pull/114574
More information about the llvm-commits
mailing list