[llvm-branch-commits] [llvm] [Offload] Refactor device information queries to use new tagging (PR #147318)

Ross Brunton via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Jul 7 08:18:13 PDT 2025


https://github.com/RossBrunton created https://github.com/llvm/llvm-project/pull/147318

Instead of using strings to look up device information (which is brittle
and slow), use the new tags that the plugins specify when building the
nodes.


>From 4cce1eec173637a0e50655e10ad520a9821b9960 Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Mon, 7 Jul 2025 16:13:32 +0100
Subject: [PATCH] [Offload] Refactor device information queries to use new
 tagging

Instead of using strings to look up device information (which is brittle
and slow), use the new tags that the plugins specify when building the
nodes.
---
 offload/liboffload/src/Helpers.hpp     |  19 ++---
 offload/liboffload/src/OffloadImpl.cpp | 111 +++++++++++--------------
 2 files changed, 54 insertions(+), 76 deletions(-)

diff --git a/offload/liboffload/src/Helpers.hpp b/offload/liboffload/src/Helpers.hpp
index 8b85945508b98..62e55e500fac7 100644
--- a/offload/liboffload/src/Helpers.hpp
+++ b/offload/liboffload/src/Helpers.hpp
@@ -75,23 +75,16 @@ class InfoWriter {
   InfoWriter(InfoWriter &) = delete;
   ~InfoWriter() = default;
 
-  template <typename T> llvm::Error write(llvm::Expected<T> &&Val) {
-    if (Val)
-      return getInfo(Size, Target, SizeRet, *Val);
-    return Val.takeError();
+  template <typename T> llvm::Error write(T Val) {
+    return getInfo(Size, Target, SizeRet, Val);
   }
 
-  template <typename T>
-  llvm::Error writeArray(llvm::Expected<T> &&Val, size_t Elems) {
-    if (Val)
-      return getInfoArray(Elems, Size, Target, SizeRet, *Val);
-    return Val.takeError();
+  template <typename T> llvm::Error writeArray(T Val, size_t Elems) {
+    return getInfoArray(Elems, Size, Target, SizeRet, Val);
   }
 
-  llvm::Error writeString(llvm::Expected<llvm::StringRef> &&Val) {
-    if (Val)
-      return getInfoString(Size, Target, SizeRet, *Val);
-    return Val.takeError();
+  llvm::Error writeString(llvm::StringRef Val) {
+    return getInfoString(Size, Target, SizeRet, Val);
   }
 
 private:
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index f9da638436705..c84bf01460252 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -286,78 +286,63 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
     return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str());
   };
 
-  // Find the info if it exists under any of the given names
-  auto getInfoString =
-      [&](std::vector<std::string> Names) -> llvm::Expected<const char *> {
-    for (auto &Name : Names) {
-      if (auto Entry = Device->Info.get(Name)) {
-        if (!std::holds_alternative<std::string>((*Entry)->Value))
-          return makeError(ErrorCode::BACKEND_FAILURE,
-                           "plugin returned incorrect type");
-        return std::get<std::string>((*Entry)->Value).c_str();
-      }
-    }
-
-    return makeError(ErrorCode::UNIMPLEMENTED,
-                     "plugin did not provide a response for this information");
-  };
-
-  auto getInfoXyz =
-      [&](std::vector<std::string> Names) -> llvm::Expected<ol_dimensions_t> {
-    for (auto &Name : Names) {
-      if (auto Entry = Device->Info.get(Name)) {
-        auto Node = *Entry;
-        ol_dimensions_t Out{0, 0, 0};
-
-        auto getField = [&](StringRef Name, uint32_t &Dest) {
-          if (auto F = Node->get(Name)) {
-            if (!std::holds_alternative<size_t>((*F)->Value))
-              return makeError(
-                  ErrorCode::BACKEND_FAILURE,
-                  "plugin returned incorrect type for dimensions element");
-            Dest = std::get<size_t>((*F)->Value);
-          } else
-            return makeError(ErrorCode::BACKEND_FAILURE,
-                             "plugin didn't provide all values for dimensions");
-          return Plugin::success();
-        };
-
-        if (auto Res = getField("x", Out.x))
-          return Res;
-        if (auto Res = getField("y", Out.y))
-          return Res;
-        if (auto Res = getField("z", Out.z))
-          return Res;
-
-        return Out;
-      }
-    }
+  // These are not implemented by the plugin interface
+  if (PropName == OL_DEVICE_INFO_PLATFORM)
+    return Info.write<void *>(Device->Platform);
+  if (PropName == OL_DEVICE_INFO_TYPE)
+    return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
+  // TODO: Update when https://github.com/llvm/llvm-project/pull/147314 is merged
+  if (PropName > OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE)
+    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
+                              "getDeviceInfo enum '%i' is invalid", PropName);
 
+  auto EntryOpt = Device->Info.get(static_cast<DeviceInfo>(PropName));
+  if (!EntryOpt)
     return makeError(ErrorCode::UNIMPLEMENTED,
                      "plugin did not provide a response for this information");
-  };
+  auto Entry = *EntryOpt;
 
   switch (PropName) {
-  case OL_DEVICE_INFO_PLATFORM:
-    return Info.write<void *>(Device->Platform);
-  case OL_DEVICE_INFO_TYPE:
-    return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
   case OL_DEVICE_INFO_NAME:
-    return Info.writeString(getInfoString({"Device Name"}));
   case OL_DEVICE_INFO_VENDOR:
-    return Info.writeString(getInfoString({"Vendor Name"}));
-  case OL_DEVICE_INFO_DRIVER_VERSION:
-    return Info.writeString(
-        getInfoString({"CUDA Driver Version", "HSA Runtime Version"}));
-  case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
-    return Info.write(getInfoXyz({"Workgroup Max Size per Dimension" /*AMD*/,
-                                  "Maximum Block Dimensions" /*CUDA*/}));
-  default:
-    return createOffloadError(ErrorCode::INVALID_ENUMERATION,
-                              "getDeviceInfo enum '%i' is invalid", PropName);
+  case OL_DEVICE_INFO_DRIVER_VERSION: {
+    // String values
+    if (!std::holds_alternative<std::string>(Entry->Value))
+      return makeError(ErrorCode::BACKEND_FAILURE,
+                       "plugin returned incorrect type");
+    return Info.writeString(std::get<std::string>(Entry->Value).c_str());
   }
 
-  return Error::success();
+  case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: {
+    // {x, y, z} triples
+    ol_dimensions_t Out{0, 0, 0};
+
+    auto getField = [&](StringRef Name, uint32_t &Dest) {
+      if (auto F = Entry->get(Name)) {
+        if (!std::holds_alternative<size_t>((*F)->Value))
+          return makeError(
+              ErrorCode::BACKEND_FAILURE,
+              "plugin returned incorrect type for dimensions element");
+        Dest = std::get<size_t>((*F)->Value);
+      } else
+        return makeError(ErrorCode::BACKEND_FAILURE,
+                         "plugin didn't provide all values for dimensions");
+      return Plugin::success();
+    };
+
+    if (auto Res = getField("x", Out.x))
+      return Res;
+    if (auto Res = getField("y", Out.y))
+      return Res;
+    if (auto Res = getField("z", Out.z))
+      return Res;
+
+    return Info.write(Out);
+  }
+
+  default:
+    llvm_unreachable("Unimplemented device info");
+  }
 }
 
 Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,



More information about the llvm-branch-commits mailing list