[llvm] [Offload] Add `MAX_WORK_GROUP_SIZE` device info query (PR #143718)
Ross Brunton via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 11 07:58:39 PDT 2025
https://github.com/RossBrunton created https://github.com/llvm/llvm-project/pull/143718
This adds a new device info query for the maximum workgroup/block size
for each dimension. Since this returns three values, a new `ol_range_t`
type was added as an `{x, y, z}` triplet. Device info handling and
struct printing were also updated to handle it.
From e25b535eeebfb90ac526ea46da5b036e57d5b0ef Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Wed, 11 Jun 2025 15:55:09 +0100
Subject: [PATCH] [Offload] Add `MAX_WORK_GROUP_SIZE` device info query
This adds a new device info query for the maximum workgroup/block size
for each dimension. Since this returns three values, a new `ol_range_t`
type was added as an `{x, y, z}` triplet. Device info handling and
struct printing were also updated to handle it.
---
offload/liboffload/API/Common.td | 10 +++
offload/liboffload/API/Device.td | 3 +-
offload/liboffload/src/OffloadImpl.cpp | 69 +++++++++++++++----
offload/tools/offload-tblgen/PrintGen.cpp | 5 ++
.../OffloadAPI/device/olGetDeviceInfo.cpp | 9 +++
.../OffloadAPI/device/olGetDeviceInfoSize.cpp | 8 +++
6 files changed, 91 insertions(+), 13 deletions(-)
diff --git a/offload/liboffload/API/Common.td b/offload/liboffload/API/Common.td
index 7674da0438c29..e2e6d6e452671 100644
--- a/offload/liboffload/API/Common.td
+++ b/offload/liboffload/API/Common.td
@@ -148,6 +148,16 @@ def : Struct {
];
}
+def : Struct {
+ let name = "ol_range_t";
+ let desc = "A three element vector";
+ let members = [
+ StructMember<"size_t", "x", "X">,
+ StructMember<"size_t", "y", "Y">,
+ StructMember<"size_t", "z", "Z">,
+ ];
+}
+
def : Function {
let name = "olInit";
let desc = "Perform initialization of the Offload library and plugins";
diff --git a/offload/liboffload/API/Device.td b/offload/liboffload/API/Device.td
index 4abc24f3ba27f..1c7d1aaee8d59 100644
--- a/offload/liboffload/API/Device.td
+++ b/offload/liboffload/API/Device.td
@@ -31,7 +31,8 @@ def : Enum {
TaggedEtor<"PLATFORM", "ol_platform_handle_t", "the platform associated with the device">,
TaggedEtor<"NAME", "char[]", "Device name">,
TaggedEtor<"VENDOR", "char[]", "Device vendor">,
- TaggedEtor<"DRIVER_VERSION", "char[]", "Driver version">
+ TaggedEtor<"DRIVER_VERSION", "char[]", "Driver version">,
+ TaggedEtor<"MAX_WORK_GROUP_SIZE", "ol_range_t", "Maximum work group size in each dimension">,
];
}
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index d2b331905ab77..b89845c387e65 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -228,16 +228,13 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet);
// Find the info if it exists under any of the given names
- auto GetInfo = [&](std::vector<std::string> Names) {
- InfoQueueTy DevInfo;
- if (Device == HostDevice())
- return std::string("Host");
-
+ auto FindInfo = [&](InfoQueueTy &DevInfo, std::vector<std::string> &Names)
+ -> std::optional<decltype(DevInfo.getQueue().begin())> {
if (!Device->Device)
- return std::string("");
+ return std::nullopt;
if (auto Err = Device->Device->obtainInfoImpl(DevInfo))
- return std::string("");
+ return std::nullopt;
for (auto Name : Names) {
auto InfoKeyMatches = [&](const InfoQueueTy::InfoQueueEntryTy &Info) {
@@ -247,11 +244,50 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
DevInfo.getQueue().end(), InfoKeyMatches);
if (Item != std::end(DevInfo.getQueue())) {
- return Item->Value;
+ return Item;
}
}
- return std::string("");
+ return std::nullopt;
+ };
+ auto GetInfoString = [&](std::vector<std::string> Names) {
+ InfoQueueTy DevInfo;
+
+ if (auto Item = FindInfo(DevInfo, Names)) {
+ return (*Item)->Value.c_str();
+ } else {
+ return "";
+ }
+ };
+ auto GetInfoXyz = [&](std::vector<std::string> Names) {
+ InfoQueueTy DevInfo;
+
+ if (auto Item = FindInfo(DevInfo, Names)) {
+ auto Iter = *Item;
+ ol_range_t Out{0, 0, 0};
+ auto Level = Iter->Level + 1;
+
+ while ((++Iter)->Level == Level) {
+ switch (Iter->Key[0]) {
+ case 'x':
+ Out.x = std::stoi(Iter->Value);
+ break;
+ case 'y':
+ Out.y = std::stoi(Iter->Value);
+ break;
+ case 'z':
+ Out.z = std::stoi(Iter->Value);
+ break;
+ default:
+ // Ignore any extra values
+ (void)0;
+ }
+ }
+
+ return Out;
+ } else {
+ return ol_range_t{0, 0, 0};
+ }
};
switch (PropName) {
@@ -261,12 +297,21 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
return Device == HostDevice() ? ReturnValue(OL_DEVICE_TYPE_HOST)
: ReturnValue(OL_DEVICE_TYPE_GPU);
case OL_DEVICE_INFO_NAME:
- return ReturnValue(GetInfo({"Device Name"}).c_str());
+ if (Device == HostDevice())
+ return ReturnValue("Host");
+ return ReturnValue(GetInfoString({"Device Name"}));
case OL_DEVICE_INFO_VENDOR:
- return ReturnValue(GetInfo({"Vendor Name"}).c_str());
+ if (Device == HostDevice())
+ return ReturnValue("Host");
+ return ReturnValue(GetInfoString({"Vendor Name"}));
case OL_DEVICE_INFO_DRIVER_VERSION:
+ if (Device == HostDevice())
+ return ReturnValue("Host");
return ReturnValue(
- GetInfo({"CUDA Driver Version", "HSA Runtime Version"}).c_str());
+ GetInfoString({"CUDA Driver Version", "HSA Runtime Version"}));
+ case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
+ return ReturnValue(GetInfoXyz({"Workgroup Max Size per Dimension" /*AMD*/,
+ "Maximum Block Dimensions" /*CUDA*/}));
default:
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
"getDeviceInfo enum '%i' is invalid", PropName);
diff --git a/offload/tools/offload-tblgen/PrintGen.cpp b/offload/tools/offload-tblgen/PrintGen.cpp
index a964ff09d0f6e..d1189688a90a3 100644
--- a/offload/tools/offload-tblgen/PrintGen.cpp
+++ b/offload/tools/offload-tblgen/PrintGen.cpp
@@ -213,6 +213,11 @@ template <typename T> inline void printTagged(llvm::raw_ostream &os, const void
"enum {0} value);\n",
EnumRec{R}.getName());
}
+ for (auto *R : Records.getAllDerivedDefinitions("Struct")) {
+ OS << formatv("inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, "
+ "const struct {0} param);\n",
+ StructRec{R}.getName());
+ }
OS << "\n";
// Create definitions
diff --git a/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp b/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp
index 0247744911eaa..ef7baf9e91275 100644
--- a/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp
+++ b/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp
@@ -77,6 +77,15 @@ TEST_P(olGetDeviceInfoTest, SuccessDriverVersion) {
ASSERT_EQ(std::strlen(DriverVersion.data()), Size - 1);
}
+TEST_P(olGetDeviceInfoTest, SuccessMaxWorkGroupSize) {
+ ol_range_t Value{0, 0, 0};
+ ASSERT_SUCCESS(olGetDeviceInfo(Device, OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE,
+ sizeof(Value), &Value));
+ ASSERT_GT(Value.x, 0);
+ ASSERT_GT(Value.y, 0);
+ ASSERT_GT(Value.z, 0);
+}
+
TEST_P(olGetDeviceInfoTest, InvalidNullHandleDevice) {
ol_device_type_t DeviceType;
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
diff --git a/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp b/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp
index edd2704a722dd..a2caad8650c79 100644
--- a/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp
+++ b/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp
@@ -44,6 +44,14 @@ TEST_P(olGetDeviceInfoSizeTest, SuccessDriverVersion) {
ASSERT_NE(Size, 0ul);
}
+TEST_P(olGetDeviceInfoSizeTest, SuccessMaxWorkGroupSize) {
+ size_t Size = 0;
+ ASSERT_SUCCESS(
+ olGetDeviceInfoSize(Device, OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE, &Size));
+ ASSERT_EQ(Size, sizeof(ol_range_t));
+ ASSERT_EQ(Size, sizeof(size_t) * 3);
+}
+
TEST_P(olGetDeviceInfoSizeTest, InvalidNullHandle) {
size_t Size = 0;
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
More information about the llvm-commits
mailing list