[llvm] [Offload] Add framework for math conformance tests (PR #149242)
Joseph Huber via llvm-commits
llvm-commits at lists.llvm.org
Thu Jul 24 09:29:01 PDT 2025
================
@@ -0,0 +1,286 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file contains the implementation of helpers and non-template member
+/// functions for the DeviceContext class.
+///
+//===----------------------------------------------------------------------===//
+
+#include "mathtest/DeviceContext.hpp"
+
+#include "mathtest/ErrorHandling.hpp"
+
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/ErrorOr.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/Path.h"
+
+#include <OffloadAPI.h>
+#include <cstddef>
+#include <memory>
+#include <optional>
+#include <string>
+#include <system_error>
+#include <vector>
+
+using namespace mathtest;
+
+//===----------------------------------------------------------------------===//
+// Helpers
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Ties the Offload runtime to the lifetime of the file-local 'Wrapper' object:
+// constructing it calls olInit() and destroying it calls olShutDown().
+struct OffloadInitWrapper {
+  OffloadInitWrapper() {
+    OL_CHECK(olInit());
+  }
+
+  ~OffloadInitWrapper() {
+    OL_CHECK(olShutDown());
+  }
+};
+static OffloadInitWrapper Wrapper{};
+
+// Queries the name of the given device. Returns an empty string when the
+// reported name size is zero.
+[[nodiscard]] std::string getDeviceName(ol_device_handle_t DeviceHandle) {
+  // First ask how many bytes the name property occupies
+  std::size_t PropSize = 0;
+  OL_CHECK(olGetDeviceInfoSize(DeviceHandle, OL_DEVICE_INFO_NAME, &PropSize));
+
+  std::string PropValue;
+  if (PropSize != 0) {
+    PropValue.assign(PropSize, '\0');
+    OL_CHECK(olGetDeviceInfo(DeviceHandle, OL_DEVICE_INFO_NAME, PropSize,
+                             PropValue.data()));
+    // The final byte is treated as the null terminator and dropped
+    PropValue.pop_back();
+  }
+
+  return PropValue;
+}
+
+// Returns the value of the OL_DEVICE_INFO_PLATFORM property of the given
+// device.
+[[nodiscard]] ol_platform_handle_t
+getDevicePlatform(ol_device_handle_t DeviceHandle) noexcept {
+  ol_platform_handle_t PlatformHandle{nullptr};
+
+  // The property is written straight into the local handle
+  OL_CHECK(olGetDeviceInfo(DeviceHandle, OL_DEVICE_INFO_PLATFORM,
+                           sizeof(PlatformHandle), &PlatformHandle));
+
+  return PlatformHandle;
+}
+
+// Queries the name of the given platform and returns it in lower case.
+// Returns an empty string when the reported name size is zero.
+[[nodiscard]] std::string getPlatformName(ol_platform_handle_t PlatformHandle) {
+  // First ask how many bytes the name property occupies
+  std::size_t PropSize = 0;
+  OL_CHECK(
+      olGetPlatformInfoSize(PlatformHandle, OL_PLATFORM_INFO_NAME, &PropSize));
+
+  std::string PropValue;
+  if (PropSize != 0) {
+    PropValue.assign(PropSize, '\0');
+    OL_CHECK(olGetPlatformInfo(PlatformHandle, OL_PLATFORM_INFO_NAME, PropSize,
+                               PropValue.data()));
+    // The final byte is treated as the null terminator and dropped
+    PropValue.pop_back();
+  }
+
+  // An empty name lower-cases to an empty string
+  return llvm::StringRef(PropValue).lower();
+}
+
+// Returns the value of the OL_PLATFORM_INFO_BACKEND property of the given
+// platform. The local starts out as OL_PLATFORM_BACKEND_UNKNOWN.
+[[nodiscard]] ol_platform_backend_t
+getPlatformBackend(ol_platform_handle_t PlatformHandle) noexcept {
+  ol_platform_backend_t Backend{OL_PLATFORM_BACKEND_UNKNOWN};
+
+  // The property is written straight into the local enum value
+  OL_CHECK(olGetPlatformInfo(PlatformHandle, OL_PLATFORM_INFO_BACKEND,
+                             sizeof(Backend), &Backend));
+
+  return Backend;
+}
+
+struct Device { // One discovered device together with the properties queried for it
+ ol_device_handle_t Handle; // Handle received by the device-iteration callback
+ std::string Name; // Device name as returned by getDeviceName()
+ std::string Platform; // Lower-cased platform name from getPlatformName()
+ ol_platform_backend_t Backend; // Backend as returned by getPlatformBackend()
+};
+
+// Returns the list of devices whose platform backend is not the host backend.
+// The list is built the first time this function is called and the same
+// object is returned on every later call.
+const std::vector<Device> &getDevices() {
+  // Thread-safe initialization of a static local variable
+  static auto Devices = []() {
+    // Callback invoked for each device; 'Data' points at the vector to fill
+    auto CollectDevice = [](ol_device_handle_t DeviceHandle, void *Data) {
+      ol_platform_handle_t PlatformHandle = getDevicePlatform(DeviceHandle);
+      ol_platform_backend_t Backend = getPlatformBackend(PlatformHandle);
+
+      // Host devices are not recorded
+      if (Backend == OL_PLATFORM_BACKEND_HOST)
+        return true;
+
+      auto Name = getDeviceName(DeviceHandle);
+      auto Platform = getPlatformName(PlatformHandle);
+
+      auto &Collected = *static_cast<std::vector<Device> *>(Data);
+      Collected.push_back({DeviceHandle, Name, Platform, Backend});
+
+      return true;
+    };
+
+    std::vector<Device> TmpDevices;
+    const auto *const ResultFromIterate =
+        olIterateDevices(CollectDevice, &TmpDevices);
+
+    OL_CHECK(ResultFromIterate);
+
+    return TmpDevices;
+  }();
+
+  return Devices;
+}
+} // namespace
+
+// Returns the set of platform names found among the devices returned by
+// getDevices(). The stored StringRefs refer to the 'Platform' strings of those
+// devices. The set is built on the first call and reused afterwards.
+const llvm::SetVector<llvm::StringRef> &mathtest::getPlatforms() {
+  // Thread-safe initialization of a static local variable
+  static llvm::SetVector<llvm::StringRef> Platforms = []() {
+    llvm::SetVector<llvm::StringRef> UniqueNames;
+
+    for (const Device &Dev : getDevices())
+      UniqueNames.insert(Dev.Platform);
+
+    return UniqueNames;
+  }();
+
+  return Platforms;
+}
+
+// Requests 'Size' bytes of OL_ALLOC_TYPE_MANAGED memory for the given device
+// via olMemAlloc() and stores the resulting pointer in '*AllocationOut'.
+void detail::allocManagedMemory(ol_device_handle_t DeviceHandle,
+                                std::size_t Size,
+                                void **AllocationOut) noexcept {
+  // The result of the call is validated by OL_CHECK
+  OL_CHECK(
+      olMemAlloc(DeviceHandle, OL_ALLOC_TYPE_MANAGED, Size, AllocationOut));
+}
+
+//===----------------------------------------------------------------------===//
+// DeviceContext
+//===----------------------------------------------------------------------===//
+
+// Binds the context to the device at index 'GlobalDeviceId' in the list
+// returned by getDevices(). An index outside that list is reported through
+// FATAL_ERROR.
+DeviceContext::DeviceContext(std::size_t GlobalDeviceId)
+    : GlobalDeviceId(GlobalDeviceId), DeviceHandle(nullptr) {
+  const std::vector<Device> &Devices = getDevices();
+  const bool IsInRange = GlobalDeviceId < Devices.size();
+
+  if (!IsInRange)
+    FATAL_ERROR("Invalid GlobalDeviceId: " + llvm::Twine(GlobalDeviceId) +
+                ", but the number of available devices is " +
+                llvm::Twine(Devices.size()));
+
+  DeviceHandle = Devices[GlobalDeviceId].Handle;
+}
+
+// Binds the context to the 'DeviceId'-th device (counting from zero) whose
+// platform name equals the lower-cased 'Platform'. An unknown platform or a
+// 'DeviceId' beyond the matching devices is reported through FATAL_ERROR.
+DeviceContext::DeviceContext(llvm::StringRef Platform, std::size_t DeviceId)
+    : DeviceHandle(nullptr) {
+  // Device platform names are stored lower-cased, so compare in lower case
+  std::string NormalizedPlatform = Platform.lower();
+  const auto &Platforms = getPlatforms();
+
+  if (!Platforms.contains(NormalizedPlatform))
+    FATAL_ERROR("There is no platform that matches with '" +
+                llvm::Twine(Platform) +
+                "'. Available platforms are: " + llvm::join(Platforms, ", "));
+
+  const std::vector<Device> &Devices = getDevices();
+
+  std::optional<std::size_t> FoundGlobalDeviceId;
+  std::size_t MatchCount = 0;
+
+  // Walk the global list, counting only devices on the requested platform
+  for (std::size_t Index = 0, End = Devices.size(); Index != End; ++Index) {
+    if (Devices[Index].Platform != NormalizedPlatform)
+      continue;
+
+    if (MatchCount == DeviceId) {
+      FoundGlobalDeviceId = Index;
+      break;
+    }
+
+    ++MatchCount;
+  }
+
+  // Without a match, MatchCount is the number of devices on the platform
+  if (!FoundGlobalDeviceId)
+    FATAL_ERROR("Invalid DeviceId: " + llvm::Twine(DeviceId) +
+                ", but the number of available devices on '" + Platform +
+                "' is " + llvm::Twine(MatchCount));
+
+  GlobalDeviceId = FoundGlobalDeviceId.value();
+  DeviceHandle = Devices[GlobalDeviceId].Handle;
+}
+
+[[nodiscard]] std::shared_ptr<DeviceImage>
+DeviceContext::loadBinary(llvm::StringRef Directory, llvm::StringRef BinaryName,
+ llvm::StringRef Extension) const {
+ llvm::SmallString<128> FullPath(Directory);
+ llvm::sys::path::append(FullPath, llvm::Twine(BinaryName) + Extension);
+
+ // For simplicity, this implementation intentionally reads the binary from
+ // disk on every call.
+ //
+ // Other use cases could benefit from a global, thread-safe cache to avoid
+ // redundant file I/O and GPU program creation.
+
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
+ llvm::MemoryBuffer::getFile(FullPath);
+ if (std::error_code ErrorCode = FileOrErr.getError())
+ FATAL_ERROR(llvm::Twine("Failed to read device binary file '") + FullPath +
+ "': " + ErrorCode.message());
+
+ std::unique_ptr<llvm::MemoryBuffer> &BinaryData = *FileOrErr;
+
+ ol_program_handle_t ProgramHandle = nullptr;
+ OL_CHECK(olCreateProgram(DeviceHandle, BinaryData->getBufferStart(),
----------------
jhuber6 wrote:
You can do this manually with the LLVM ELF interface: simply load it as an ELF and check the header flags for EM_AMDGPU or EM_CUDA.
https://github.com/llvm/llvm-project/pull/149242
More information about the llvm-commits
mailing list