[llvm-branch-commits] [llvm] [Offload] Refactor device information queries to use new tagging (PR #147318)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Jul 7 08:18:47 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-offload
Author: Ross Brunton (RossBrunton)
<details>
<summary>Changes</summary>
Instead of using strings to look up device information (which is brittle
and slow), use the new tags that the plugins specify when building the
info nodes.
---
Full diff: https://github.com/llvm/llvm-project/pull/147318.diff
2 Files Affected:
- (modified) offload/liboffload/src/Helpers.hpp (+6-13)
- (modified) offload/liboffload/src/OffloadImpl.cpp (+48-63)
``````````diff
diff --git a/offload/liboffload/src/Helpers.hpp b/offload/liboffload/src/Helpers.hpp
index 8b85945508b98..62e55e500fac7 100644
--- a/offload/liboffload/src/Helpers.hpp
+++ b/offload/liboffload/src/Helpers.hpp
@@ -75,23 +75,16 @@ class InfoWriter {
InfoWriter(InfoWriter &) = delete;
~InfoWriter() = default;
- template <typename T> llvm::Error write(llvm::Expected<T> &&Val) {
- if (Val)
- return getInfo(Size, Target, SizeRet, *Val);
- return Val.takeError();
+ template <typename T> llvm::Error write(T Val) {
+ return getInfo(Size, Target, SizeRet, Val);
}
- template <typename T>
- llvm::Error writeArray(llvm::Expected<T> &&Val, size_t Elems) {
- if (Val)
- return getInfoArray(Elems, Size, Target, SizeRet, *Val);
- return Val.takeError();
+ template <typename T> llvm::Error writeArray(T Val, size_t Elems) {
+ return getInfoArray(Elems, Size, Target, SizeRet, Val);
}
- llvm::Error writeString(llvm::Expected<llvm::StringRef> &&Val) {
- if (Val)
- return getInfoString(Size, Target, SizeRet, *Val);
- return Val.takeError();
+ llvm::Error writeString(llvm::StringRef Val) {
+ return getInfoString(Size, Target, SizeRet, Val);
}
private:
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index f9da638436705..c84bf01460252 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -286,78 +286,63 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str());
};
- // Find the info if it exists under any of the given names
- auto getInfoString =
- [&](std::vector<std::string> Names) -> llvm::Expected<const char *> {
- for (auto &Name : Names) {
- if (auto Entry = Device->Info.get(Name)) {
- if (!std::holds_alternative<std::string>((*Entry)->Value))
- return makeError(ErrorCode::BACKEND_FAILURE,
- "plugin returned incorrect type");
- return std::get<std::string>((*Entry)->Value).c_str();
- }
- }
-
- return makeError(ErrorCode::UNIMPLEMENTED,
- "plugin did not provide a response for this information");
- };
-
- auto getInfoXyz =
- [&](std::vector<std::string> Names) -> llvm::Expected<ol_dimensions_t> {
- for (auto &Name : Names) {
- if (auto Entry = Device->Info.get(Name)) {
- auto Node = *Entry;
- ol_dimensions_t Out{0, 0, 0};
-
- auto getField = [&](StringRef Name, uint32_t &Dest) {
- if (auto F = Node->get(Name)) {
- if (!std::holds_alternative<size_t>((*F)->Value))
- return makeError(
- ErrorCode::BACKEND_FAILURE,
- "plugin returned incorrect type for dimensions element");
- Dest = std::get<size_t>((*F)->Value);
- } else
- return makeError(ErrorCode::BACKEND_FAILURE,
- "plugin didn't provide all values for dimensions");
- return Plugin::success();
- };
-
- if (auto Res = getField("x", Out.x))
- return Res;
- if (auto Res = getField("y", Out.y))
- return Res;
- if (auto Res = getField("z", Out.z))
- return Res;
-
- return Out;
- }
- }
+ // These are not implemented by the plugin interface
+ if (PropName == OL_DEVICE_INFO_PLATFORM)
+ return Info.write<void *>(Device->Platform);
+ if (PropName == OL_DEVICE_INFO_TYPE)
+ return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
+ // TODO: Update when https://github.com/llvm/llvm-project/pull/147314 is merged
+ if (PropName > OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE)
+ return createOffloadError(ErrorCode::INVALID_ENUMERATION,
+ "getDeviceInfo enum '%i' is invalid", PropName);
+ auto EntryOpt = Device->Info.get(static_cast<DeviceInfo>(PropName));
+ if (!EntryOpt)
return makeError(ErrorCode::UNIMPLEMENTED,
"plugin did not provide a response for this information");
- };
+ auto Entry = *EntryOpt;
switch (PropName) {
- case OL_DEVICE_INFO_PLATFORM:
- return Info.write<void *>(Device->Platform);
- case OL_DEVICE_INFO_TYPE:
- return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
case OL_DEVICE_INFO_NAME:
- return Info.writeString(getInfoString({"Device Name"}));
case OL_DEVICE_INFO_VENDOR:
- return Info.writeString(getInfoString({"Vendor Name"}));
- case OL_DEVICE_INFO_DRIVER_VERSION:
- return Info.writeString(
- getInfoString({"CUDA Driver Version", "HSA Runtime Version"}));
- case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
- return Info.write(getInfoXyz({"Workgroup Max Size per Dimension" /*AMD*/,
- "Maximum Block Dimensions" /*CUDA*/}));
- default:
- return createOffloadError(ErrorCode::INVALID_ENUMERATION,
- "getDeviceInfo enum '%i' is invalid", PropName);
+ case OL_DEVICE_INFO_DRIVER_VERSION: {
+ // String values
+ if (!std::holds_alternative<std::string>(Entry->Value))
+ return makeError(ErrorCode::BACKEND_FAILURE,
+ "plugin returned incorrect type");
+ return Info.writeString(std::get<std::string>(Entry->Value).c_str());
}
- return Error::success();
+ case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: {
+ // {x, y, z} triples
+ ol_dimensions_t Out{0, 0, 0};
+
+ auto getField = [&](StringRef Name, uint32_t &Dest) {
+ if (auto F = Entry->get(Name)) {
+ if (!std::holds_alternative<size_t>((*F)->Value))
+ return makeError(
+ ErrorCode::BACKEND_FAILURE,
+ "plugin returned incorrect type for dimensions element");
+ Dest = std::get<size_t>((*F)->Value);
+ } else
+ return makeError(ErrorCode::BACKEND_FAILURE,
+ "plugin didn't provide all values for dimensions");
+ return Plugin::success();
+ };
+
+ if (auto Res = getField("x", Out.x))
+ return Res;
+ if (auto Res = getField("y", Out.y))
+ return Res;
+ if (auto Res = getField("z", Out.z))
+ return Res;
+
+ return Info.write(Out);
+ }
+
+ default:
+ llvm_unreachable("Unimplemented device info");
+ }
}
Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
``````````
</details>
https://github.com/llvm/llvm-project/pull/147318
More information about the llvm-branch-commits
mailing list