[llvm] [Offload] Add `MAX_WORK_GROUP_SIZE` device info query (PR #143718)

Ross Brunton via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 12 08:15:22 PDT 2025


https://github.com/RossBrunton updated https://github.com/llvm/llvm-project/pull/143718

>From e47c8edcfb92eba9d42238d3be5809161a3fd690 Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Wed, 11 Jun 2025 15:55:09 +0100
Subject: [PATCH 1/3] [Offload] Add `MAX_WORK_GROUP_SIZE` device info query

This adds a new device info query for the maximum workgroup/block size
for each dimension. Since this returns three values, the result is
reported as an `ol_dimensions_t` `{x, y, z}` triplet. Device info handling
and struct printing were also updated to handle it.
---
 offload/liboffload/API/Device.td              |  3 +-
 offload/liboffload/src/OffloadImpl.cpp        | 69 +++++++++++++++----
 offload/tools/offload-tblgen/PrintGen.cpp     |  5 ++
 .../OffloadAPI/device/olGetDeviceInfo.cpp     |  9 +++
 .../OffloadAPI/device/olGetDeviceInfoSize.cpp |  8 +++
 5 files changed, 81 insertions(+), 13 deletions(-)

diff --git a/offload/liboffload/API/Device.td b/offload/liboffload/API/Device.td
index 4abc24f3ba27f..94bd6cbf0e5be 100644
--- a/offload/liboffload/API/Device.td
+++ b/offload/liboffload/API/Device.td
@@ -31,7 +31,8 @@ def : Enum {
     TaggedEtor<"PLATFORM", "ol_platform_handle_t", "the platform associated with the device">,
     TaggedEtor<"NAME", "char[]", "Device name">,
     TaggedEtor<"VENDOR", "char[]", "Device vendor">,
-    TaggedEtor<"DRIVER_VERSION", "char[]", "Driver version">
+    TaggedEtor<"DRIVER_VERSION", "char[]", "Driver version">,
+    TaggedEtor<"MAX_WORK_GROUP_SIZE", "ol_dimensions_t", "Maximum work group size in each dimension">,
   ];
 }
 
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 0a784cddeaecb..1781fe79cdce5 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -228,16 +228,13 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
   ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet);
 
   // Find the info if it exists under any of the given names
-  auto GetInfo = [&](std::vector<std::string> Names) {
-    InfoQueueTy DevInfo;
-    if (Device == HostDevice())
-      return std::string("Host");
-
+  auto FindInfo = [&](InfoQueueTy &DevInfo, std::vector<std::string> &Names)
+      -> std::optional<decltype(DevInfo.getQueue().begin())> {
     if (!Device->Device)
-      return std::string("");
+      return std::nullopt;
 
     if (auto Err = Device->Device->obtainInfoImpl(DevInfo))
-      return std::string("");
+      return std::nullopt;
 
     for (auto Name : Names) {
       auto InfoKeyMatches = [&](const InfoQueueTy::InfoQueueEntryTy &Info) {
@@ -247,11 +244,50 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
                                DevInfo.getQueue().end(), InfoKeyMatches);
 
       if (Item != std::end(DevInfo.getQueue())) {
-        return Item->Value;
+        return Item;
       }
     }
 
-    return std::string("");
+    return std::nullopt;
+  };
+  auto GetInfoString = [&](std::vector<std::string> Names) {
+    InfoQueueTy DevInfo;
+
+    if (auto Item = FindInfo(DevInfo, Names)) {
+      return (*Item)->Value.c_str();
+    } else {
+      return "";
+    }
+  };
+  auto GetInfoXyz = [&](std::vector<std::string> Names) {
+    InfoQueueTy DevInfo;
+
+    if (auto Item = FindInfo(DevInfo, Names)) {
+      auto Iter = *Item;
+      ol_dimensions_t Out{0, 0, 0};
+      auto Level = Iter->Level + 1;
+
+      while ((++Iter)->Level == Level) {
+        switch (Iter->Key[0]) {
+        case 'x':
+          Out.x = std::stoi(Iter->Value);
+          break;
+        case 'y':
+          Out.y = std::stoi(Iter->Value);
+          break;
+        case 'z':
+          Out.z = std::stoi(Iter->Value);
+          break;
+        default:
+          // Ignore any extra values
+          (void)0;
+        }
+      }
+
+      return Out;
+    } else {
+      return ol_dimensions_t{0, 0, 0};
+    }
   };
 
   switch (PropName) {
@@ -261,12 +297,21 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
     return Device == HostDevice() ? ReturnValue(OL_DEVICE_TYPE_HOST)
                                   : ReturnValue(OL_DEVICE_TYPE_GPU);
   case OL_DEVICE_INFO_NAME:
-    return ReturnValue(GetInfo({"Device Name"}).c_str());
+    if (Device == HostDevice())
+      return ReturnValue("Host");
+    return ReturnValue(GetInfoString({"Device Name"}));
   case OL_DEVICE_INFO_VENDOR:
-    return ReturnValue(GetInfo({"Vendor Name"}).c_str());
+    if (Device == HostDevice())
+      return ReturnValue("Host");
+    return ReturnValue(GetInfoString({"Vendor Name"}));
   case OL_DEVICE_INFO_DRIVER_VERSION:
+    if (Device == HostDevice())
+      return ReturnValue("Host");
     return ReturnValue(
-        GetInfo({"CUDA Driver Version", "HSA Runtime Version"}).c_str());
+        GetInfoString({"CUDA Driver Version", "HSA Runtime Version"}));
+  case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
+    return ReturnValue(GetInfoXyz({"Workgroup Max Size per Dimension" /*AMD*/,
+                                   "Maximum Block Dimensions" /*CUDA*/}));
   default:
     return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                               "getDeviceInfo enum '%i' is invalid", PropName);
diff --git a/offload/tools/offload-tblgen/PrintGen.cpp b/offload/tools/offload-tblgen/PrintGen.cpp
index a964ff09d0f6e..d1189688a90a3 100644
--- a/offload/tools/offload-tblgen/PrintGen.cpp
+++ b/offload/tools/offload-tblgen/PrintGen.cpp
@@ -213,6 +213,11 @@ template <typename T> inline void printTagged(llvm::raw_ostream &os, const void
                   "enum {0} value);\n",
                   EnumRec{R}.getName());
   }
+  for (auto *R : Records.getAllDerivedDefinitions("Struct")) {
+    OS << formatv("inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, "
+                  "const struct {0} param);\n",
+                  StructRec{R}.getName());
+  }
   OS << "\n";
 
   // Create definitions
diff --git a/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp b/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp
index 0247744911eaa..a6f0145ab39b4 100644
--- a/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp
+++ b/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp
@@ -77,6 +77,15 @@ TEST_P(olGetDeviceInfoTest, SuccessDriverVersion) {
   ASSERT_EQ(std::strlen(DriverVersion.data()), Size - 1);
 }
 
+TEST_P(olGetDeviceInfoTest, SuccessMaxWorkGroupSize) {
+  ol_dimensions_t Value{0, 0, 0};
+  ASSERT_SUCCESS(olGetDeviceInfo(Device, OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE,
+                                 sizeof(Value), &Value));
+  ASSERT_GT(Value.x, 0);
+  ASSERT_GT(Value.y, 0);
+  ASSERT_GT(Value.z, 0);
+}
+
 TEST_P(olGetDeviceInfoTest, InvalidNullHandleDevice) {
   ol_device_type_t DeviceType;
   ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
diff --git a/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp b/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp
index edd2704a722dd..a908078a25211 100644
--- a/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp
+++ b/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp
@@ -44,6 +44,14 @@ TEST_P(olGetDeviceInfoSizeTest, SuccessDriverVersion) {
   ASSERT_NE(Size, 0ul);
 }
 
+TEST_P(olGetDeviceInfoSizeTest, SuccessMaxWorkGroupSize) {
+  size_t Size = 0;
+  ASSERT_SUCCESS(
+      olGetDeviceInfoSize(Device, OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE, &Size));
+  ASSERT_EQ(Size, sizeof(ol_dimensions_t));
+  ASSERT_EQ(Size, sizeof(uint32_t) * 3);
+}
+
 TEST_P(olGetDeviceInfoSizeTest, InvalidNullHandle) {
   size_t Size = 0;
   ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,

>From 8c9d99e3e9eee53de8b7fd459e21769726bb3878 Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Wed, 11 Jun 2025 16:43:12 +0100
Subject: [PATCH 2/3] Use SmallVector<StringRef> for info names; compare against unsigned zero in tests

---
 offload/liboffload/src/OffloadImpl.cpp                  | 6 +++---
 offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 1781fe79cdce5..31e0e1e5e892c 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -228,7 +228,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
   ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet);
 
   // Find the info if it exists under any of the given names
-  auto FindInfo = [&](InfoQueueTy &DevInfo, std::vector<std::string> &Names)
+  auto FindInfo = [&](InfoQueueTy &DevInfo, llvm::SmallVector<StringRef> &Names)
       -> std::optional<decltype(DevInfo.getQueue().begin())> {
     if (!Device->Device)
       return std::nullopt;
@@ -250,7 +250,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
 
     return std::nullopt;
   };
-  auto GetInfoString = [&](std::vector<std::string> Names) {
+  auto GetInfoString = [&](llvm::SmallVector<StringRef> Names) {
     InfoQueueTy DevInfo;
 
     if (auto Item = FindInfo(DevInfo, Names)) {
@@ -259,7 +259,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
       return "";
     }
   };
-  auto GetInfoXyz = [&](std::vector<std::string> Names) {
+  auto GetInfoXyz = [&](llvm::SmallVector<StringRef> Names) {
     InfoQueueTy DevInfo;
 
     if (auto Item = FindInfo(DevInfo, Names)) {
diff --git a/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp b/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp
index a6f0145ab39b4..c534c45205993 100644
--- a/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp
+++ b/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp
@@ -81,9 +81,9 @@ TEST_P(olGetDeviceInfoTest, SuccessMaxWorkGroupSize) {
   ol_dimensions_t Value{0, 0, 0};
   ASSERT_SUCCESS(olGetDeviceInfo(Device, OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE,
                                  sizeof(Value), &Value));
-  ASSERT_GT(Value.x, 0);
-  ASSERT_GT(Value.y, 0);
-  ASSERT_GT(Value.z, 0);
+  ASSERT_GT(Value.x, 0u);
+  ASSERT_GT(Value.y, 0u);
+  ASSERT_GT(Value.z, 0u);
 }
 
 TEST_P(olGetDeviceInfoTest, InvalidNullHandleDevice) {

>From b99d5cb989d72c1ab8ebe988e2f589a7e6256e6d Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Thu, 12 Jun 2025 16:14:32 +0100
Subject: [PATCH 3/3] Make comment less confusing

---
 offload/liboffload/src/OffloadImpl.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 31e0e1e5e892c..f48cc2c1b2b43 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -267,6 +267,8 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
       ol_dimensions_t Out{0, 0, 0};
       auto Level = Iter->Level + 1;
 
+      // Check the "children" of the current info for x/y/z components.
+      // We ignore any components that don't match.
       while ((++Iter)->Level == Level) {
         switch (Iter->Key[0]) {
         case 'x':
@@ -278,9 +280,6 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
         case 'z':
           Out.z = std::stoi(Iter->Value);
           break;
-        default:
-          // Ignore any extra values
-          (void)0;
         }
       }
 



More information about the llvm-commits mailing list