[llvm] [Offload] Introduce `offload-tblgen` (PR #88923)
Callum Fare via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 16 09:19:21 PDT 2024
https://github.com/callumfare created https://github.com/llvm/llvm-project/pull/88923
Introduce `offload-tblgen`, which will be used for automatically generating some source files for the Offload project. The tablegen files are intended to be the 'single source of truth' for the new API, with the header file automatically generated, along with other useful source and .inc files.
This is an initial implementation intended to be a useful starting point for discussion, and is based on tooling from Unified Runtime as previously discussed. I'm expecting a lot of iteration and discussion before anything is merged, so all feedback is welcome.
See `offload/API/README.md` for more documentation.
Notable things still to be done (not necessarily before this PR is merged):
* Decide on exactly what features are desirable. A lot of this is lifted from the equivalent templates and scripts in Unified Runtime, and obviously not everything will be relevant.
* Decide on a prefix to use for API naming (like `cu` for CUDA or `cl` for OpenCL). Currently set to `ol` for the sake of brevity. Could also be `offl`, `offload` or something else. This is plumbed into quite a few places so makes sense to decide early.
* Check that the Doxygen comments in the generated header can actually be used to generate documentation
* Implement tracing/printing backend
* Finish the validation backend
* Implement some kind of test generation
>From 5a4a251286e4852a159d9b23a3055a6e06489de3 Mon Sep 17 00:00:00 2001
From: Callum Fare <callum at codeplay.com>
Date: Tue, 16 Apr 2024 16:49:40 +0100
Subject: [PATCH] [Offload] Introduce `offload-tblgen`
offload/API/APIDefs.td | 151 ++++++
offload/API/Example.td | 481 ++++++++++++++++++
offload/API/OffloadAPI.td | 5 +
offload/API/README.md | 103 ++++
offload/CMakeLists.txt | 15 +
offload/tools/offload-tblgen/APIGen.cpp | 156 ++++++
offload/tools/offload-tblgen/CMakeLists.txt | 12 +
offload/tools/offload-tblgen/GenCommon.hpp | 20 +
offload/tools/offload-tblgen/Generators.hpp | 4 +
offload/tools/offload-tblgen/RecordTypes.hpp | 186 +++++++
.../tools/offload-tblgen/ValidationGen.cpp | 121 +++++
.../tools/offload-tblgen/offload-tblgen.cpp | 61 +++
12 files changed, 1315 insertions(+)
create mode 100644 offload/API/APIDefs.td
create mode 100644 offload/API/Example.td
create mode 100644 offload/API/OffloadAPI.td
create mode 100644 offload/API/README.md
create mode 100644 offload/CMakeLists.txt
create mode 100644 offload/tools/offload-tblgen/APIGen.cpp
create mode 100644 offload/tools/offload-tblgen/CMakeLists.txt
create mode 100644 offload/tools/offload-tblgen/GenCommon.hpp
create mode 100644 offload/tools/offload-tblgen/Generators.hpp
create mode 100644 offload/tools/offload-tblgen/RecordTypes.hpp
create mode 100644 offload/tools/offload-tblgen/ValidationGen.cpp
create mode 100644 offload/tools/offload-tblgen/offload-tblgen.cpp
diff --git a/offload/API/APIDefs.td b/offload/API/APIDefs.td
new file mode 100644
index 00000000000000..3b0266dc2b3f24
--- /dev/null
+++ b/offload/API/APIDefs.td
@@ -0,0 +1,151 @@
+// See offload/API/README.md for documentation.
+// Bit flags for Param's `flags` field (a bits<3> value). They are combined
+// with !or, e.g. !or(PARAM_OUT, PARAM_OPTIONAL).
+defvar PARAM_IN = 0x1;
+defvar PARAM_OUT = 0x2;
+defvar PARAM_OPTIONAL = 0x4;
+// Upper-case API prefix, used to build return-code names such as
+// PREFIX # "_RESULT_ERROR_INVALID_NULL_HANDLE".
+defvar PREFIX = "OL";
+// Lower-case form of PREFIX, used to build type names such as
+// prefix # "_base_properties_t".
+defvar prefix = !tolower(PREFIX);
+// ret is 1 when Type ends with the "_handle_t" suffix (e.g. "ol_queue_handle_t").
+// Names shorter than the suffix short-circuit to 0 instead of searching from a
+// negative offset. A pointer-to-handle such as "ol_queue_handle_t*" does not
+// match, because it ends in "*".
+class IsHandleType<string Type> {
+ // size("_handle_t") == 9
+ bit ret = !if(!lt(!size(Type), 9), 0, !ne(!find(Type, "_handle_t", !sub(!size(Type), 9)), -1));
+// ret is 1 when the last character of Type is "*" (e.g. "size_t*",
+// "const ol_queue_properties_t*").
+// NOTE(review): unlike IsHandleType there is no length guard, so an empty Type
+// string would search from offset -1.
+class IsPointerType<string Type> {
+ bit ret = !ne(!find(Type, "*", !sub(!size(Type), 1)), -1);
+// A function parameter. `_Flags` is a combination of PARAM_IN, PARAM_OUT and
+// PARAM_OPTIONAL. IsHandle/IsPointer are derived from the type string and feed
+// the automatic NULL checks (ShouldCheckHandle / ShouldCheckPointer).
+class Param<string Type, string Name, string Desc, bits<3> _Flags = 0> {
+ string type = Type;
+ string name = Name;
+ string desc = Desc;
+ bits<3> flags = _Flags;
+ bit IsHandle = IsHandleType<type>.ret;
+ bit IsPointer = IsPointerType<type>.ret;
+// A possible return code of a function. `Value` is the full code name (e.g.
+// PREFIX # "_RESULT_ERROR_INVALID_NULL_HANDLE"). `Conditions` describes when
+// the code is returned; conditions wrapped in backticks are treated as code
+// by the validation generator (see README.md).
+class Return<string Value, list<string> Conditions = []> {
+ string value = Value;
+ list<string> conditions = Conditions;
+// ret is 1 for a handle-typed parameter that is not flagged PARAM_OPTIONAL,
+// i.e. one that should get an automatic `NULL == <param>` check.
+class ShouldCheckHandle<Param P> {
+ bit ret = !and(P.IsHandle, !eq(!and(PARAM_OPTIONAL, P.flags), 0));
+// ret is 1 for a pointer-typed parameter that is not flagged PARAM_OPTIONAL.
+class ShouldCheckPointer<Param P> {
+ bit ret = !and(P.IsPointer, !eq(!and(PARAM_OPTIONAL, P.flags), 0));
+// ret is a copy of Returns in which every entry whose value equals ReturnValue
+// has Conditions appended to its condition list; other entries are unchanged.
+class AppendConditionsToReturn<list<Return> Returns, string ReturnValue, list<string> Conditions> {
+ list<Return> ret = !foreach(Ret, Returns, !if(!eq(Ret.value, ReturnValue), Return<Ret.value, Ret.conditions # Conditions>, Ret));
+// Adds an automatic `NULL == <param>` condition for every non-optional handle
+// parameter to the PREFIX_RESULT_ERROR_INVALID_NULL_HANDLE return. If Returns
+// already lists that code, the conditions are appended to it; otherwise a new
+// Return is added. When no parameter needs a check, Returns is passed through
+// unchanged, so functions without handle parameters do not advertise a return
+// code that carries no condition.
+class AddHandleChecksToReturns<list<Param> Params, list<Return> Returns> {
+  list<string> handle_params = !foreach(P, Params, !if(ShouldCheckHandle<P>.ret, P.name, ""));
+  list<string> handle_params_filt = !filter(param, handle_params, !ne(param, ""));
+  list<string> handle_param_conds = !foreach(handle, handle_params_filt, "`NULL == "#handle#"`");
+  // Does the list of returns already contain ERROR_INVALID_NULL_HANDLE?
+  bit returns_has_inv_handle = !foldl(0, Returns, HasErr, Ret, !or(HasErr, !eq(Ret.value, PREFIX # "_RESULT_ERROR_INVALID_NULL_HANDLE")));
+  list<Return> returns_out = !if(!empty(handle_param_conds), Returns,
+    !if(returns_has_inv_handle,
+      AppendConditionsToReturn<Returns, PREFIX # "_RESULT_ERROR_INVALID_NULL_HANDLE", handle_param_conds>.ret,
+      !listconcat(Returns, [Return<PREFIX # "_RESULT_ERROR_INVALID_NULL_HANDLE", handle_param_conds>])));
+}
+
+// Same as AddHandleChecksToReturns, but for non-optional pointer parameters
+// and the PREFIX_RESULT_ERROR_INVALID_NULL_POINTER return code. Returns is
+// left unchanged when no parameter needs a NULL-pointer check.
+class AddPointerChecksToReturns<list<Param> Params, list<Return> Returns> {
+  list<string> ptr_params = !foreach(P, Params, !if(ShouldCheckPointer<P>.ret, P.name, ""));
+  list<string> ptr_params_filt = !filter(param, ptr_params, !ne(param, ""));
+  list<string> ptr_param_conds = !foreach(ptr, ptr_params_filt, "`NULL == "#ptr#"`");
+  // Does the list of returns already contain ERROR_INVALID_NULL_POINTER?
+  bit returns_has_inv_ptr = !foldl(0, Returns, HasErr, Ret, !or(HasErr, !eq(Ret.value, PREFIX # "_RESULT_ERROR_INVALID_NULL_POINTER")));
+  list<Return> returns_out = !if(!empty(ptr_param_conds), Returns,
+    !if(returns_has_inv_ptr,
+      AppendConditionsToReturn<Returns, PREFIX # "_RESULT_ERROR_INVALID_NULL_POINTER", ptr_param_conds>.ret,
+      !listconcat(Returns, [Return<PREFIX # "_RESULT_ERROR_INVALID_NULL_POINTER", ptr_param_conds>])));
+}
+defvar DefaultReturns = [
+// Common base of every API object: each has a name and a description.
+class APIObject {
+ string name;
+ string desc;
+// An API entry point. `Class` (e.g. "olQueue") and `name` (e.g. "GetInfo")
+// presumably combine into the C function name (olQueueGetInfo) -- TODO confirm
+// against APIGen.
+// all_returns = DefaultReturns, then the author's `returns`, then the
+// automatically generated NULL-handle and NULL-pointer checks.
+class Function<string Class> : APIObject {
+ string api_class = Class;
+ list<Param> params;
+ list<Return> returns;
+ list<string> details = [];
+ list<string> analogues = [];
+ list<Return> returns_with_def = !listconcat(DefaultReturns, returns);
+ list<Return> all_returns = AddPointerChecksToReturns<params,
+ AddHandleChecksToReturns<params, returns_with_def>.returns_out>.returns_out;
+// A single enumerator of an Enum. `Name` is the suffix only (e.g. "SUCCESS");
+// the generator is expected to add the enum-specific prefix.
+class Etor<string Name, string Desc> {
+ string name = Name;
+ string desc = Desc;
+// A C-style enum. Per README.md every enum also receives a
+// `<enum_name>_FORCE_UINT32 = 0x7fffffff` enumerator.
+class Enum : APIObject {
+ // This refers to whether the enumerator descriptions specify a return
+ // type for functions where this enum may be used as an input type.
+ // The format is "[$x_some_return_t] Description text"
+ // (TODO: This is lifted from UR, is it relevant?)
+ bit is_typed = 0;
+ list<Etor> etors = [];
+// One member of a Struct.
+class StructMember<string Type, string Name, string Desc> {
+ string type = Type;
+ string name = Name;
+ string desc = Desc;
+// Members copied into structs whose base_class is <prefix>_base_properties_t.
+// Property chains are read and written, so pNext is a mutable `void*`.
+defvar DefaultPropStructMembers = [
+  StructMember<prefix#"_structure_type_t", "stype", "type of this structure">,
+  StructMember<"void*", "pNext", "pointer to extension-specific structure">
+];
+
+// Members copied into structs whose base_class is <prefix>_base_desc_t. The
+// base descriptor struct declares pNext as `const void*`, so derived
+// descriptors use the same type instead of the mutable properties variant.
+defvar DefaultDescStructMembers = [
+  StructMember<prefix#"_structure_type_t", "stype", "type of this structure">,
+  StructMember<"const void*", "pNext", "pointer to extension-specific structure">
+];
+
+// ret is 1 when BaseClass names one of the two special base structs whose
+// members are copied into derived structs.
+class StructHasInheritedMembers<string BaseClass> {
+  bit ret = !or(!eq(BaseClass, prefix#"_base_properties_t"), !eq(BaseClass, prefix#"_base_desc_t"));
+}
+
+// ret is the member list inherited from BaseClass: the descriptor members,
+// the properties members, or an empty list for any other base class.
+class InheritedStructMembers<string BaseClass> {
+  list<StructMember> ret = !if(!eq(BaseClass, prefix#"_base_desc_t"), DefaultDescStructMembers,
+                           !if(!eq(BaseClass, prefix#"_base_properties_t"), DefaultPropStructMembers, []));
+}
+
+// A C struct. Inherited members are copied in ahead of `members`, rather than
+// using C++ inheritance, to keep the generated header C-compatible.
+class Struct : APIObject {
+  string base_class = "";
+  list<StructMember> members;
+  list<StructMember> all_members = InheritedStructMembers<base_class>.ret # members;
+}
+// A typedef; APIGen emits it as `typedef <value> <name>;`.
+class Typedef : APIObject {
+ string value;
+// A function-pointer typedef. No instance appears in Example.td.
+class FptrTypedef : APIObject {
+ list<Param> params;
+ list<Return> returns;
+// A preprocessor #define. `condition` and `alt_value` are optional and are
+// left unset when not needed. When set, `condition` wraps the definition in
+// `#if` and `alt_value` supplies the `#else` definition.
+class Macro : APIObject {
+ string value;
+ string condition;
+ string alt_value;
+// An opaque handle type; APIGen emits `typedef struct <name>_ *<name>;`.
+class Handle : APIObject;
diff --git a/offload/API/Example.td b/offload/API/Example.td
new file mode 100644
index 00000000000000..04987292aa853c
--- /dev/null
+++ b/offload/API/Example.td
@@ -0,0 +1,481 @@
+// This file serves as an example for the Offload tablegen framework.
+// It is NOT an actual representation of the API. It is based off a random
+// selection of features from Unified Runtime.
+// Version helpers: OL_MAKE_VERSION packs the major version into the upper 16
+// bits and the minor version into the lower 16 bits; OL_MAJOR_VERSION and
+// OL_MINOR_VERSION extract them again.
+// NOTE(review): APIGen's ProcessMacro emits `#ifndef <name>` with the full
+// `name` string, which for function-like macros includes the parameter list
+// (e.g. "OL_MAKE_VERSION( _major, _minor )"). Check the generated guard.
+def : Macro {
+ let name = "OL_MAKE_VERSION( _major, _minor )";
+ let desc = "Generates generic API versions";
+ let value = "(( _major << 16 )|( _minor & 0x0000ffff))";
+def : Macro {
+ let name = "OL_MAJOR_VERSION( _ver )";
+ let desc = "Extracts API major version";
+ let value = "( _ver >> 16 )";
+def : Macro {
+ let name = "OL_MINOR_VERSION( _ver )";
+ let desc = "Extracts API minor version";
+ let value = "( _ver & 0x0000ffff )";
+// Platform-dependent attribute macros: `condition` becomes an `#if` and
+// `alt_value` (when set) the `#else` definition in the generated header.
+def : Macro {
+ let name = "OL_APICALL";
+ let desc = "Calling convention for all API functions";
+ let condition = "defined(_WIN32)";
+ let value = "__cdecl";
+ let alt_value = "";
+def : Macro {
+ let name = "OL_APIEXPORT";
+ let desc = "Microsoft-specific dllexport storage-class attribute";
+ let condition = "defined(_WIN32)";
+ let value = "__declspec(dllexport)";
+ let alt_value = "";
+// OL_DLLEXPORT is declared twice with different conditions. Each generated
+// definition is guarded by `#ifndef OL_DLLEXPORT`, so presumably the Windows
+// form wins under _WIN32 and the GCC form applies otherwise. The first entry
+// has no alt_value, so it defines nothing outside Windows.
+def : Macro {
+ let name = "OL_DLLEXPORT";
+ let desc = "Microsoft-specific dllexport storage-class attribute";
+ let condition = "defined(_WIN32)";
+ let value = "__declspec(dllexport)";
+def : Macro {
+ let name = "OL_DLLEXPORT";
+ let desc = "GCC-specific dllexport storage-class attribute";
+ let condition = "__GNUC__ >= 4";
+ let value = "__attribute__ ((visibility (\"default\")))";
+ let alt_value = "";
+// Boolean type of fixed width (uint8_t) used by the API.
+def : Typedef {
+ let name = "ol_bool_t";
+ let value = "uint8_t";
+ let desc = "compiler-independent type";
+// Opaque handle types. The "_handle_t" suffix is what IsHandleType keys on,
+// so non-optional parameters of these types get automatic NULL checks.
+def : Handle {
+ let name = "ol_loader_config_handle_t";
+ let desc = "Handle of a loader config object";
+def : Handle {
+ let name = "ol_adapter_handle_t";
+ let desc = "Handle of an adapter instance";
+def : Handle {
+ let name = "ol_platform_handle_t";
+ let desc = "Handle of a platform instance";
+def : Handle {
+ let name = "ol_device_handle_t";
+ let desc = "Handle of platform's device object";
+def : Handle {
+ let name = "ol_context_handle_t";
+ let desc = "Handle of context object";
+def : Handle {
+ let name = "ol_event_handle_t";
+ let desc = "Handle of event object";
+def : Handle {
+ let name = "ol_program_handle_t";
+ let desc = "Handle of Program object";
+def : Handle {
+ let name = "ol_kernel_handle_t";
+ let desc = "Handle of program's Kernel object";
+def : Handle {
+ let name = "ol_queue_handle_t";
+ let desc = "Handle of a queue object";
+def : Handle {
+ let name = "ol_native_handle_t";
+ let desc = "Handle of a native object";
+def : Handle {
+ let name = "ol_sampler_handle_t";
+ let desc = "Handle of a Sampler object";
+def : Handle {
+ let name = "ol_mem_handle_t";
+ let desc = "Handle of memory object which can either be buffer or image";
+def : Handle {
+ let name = "ol_physical_mem_handle_t";
+ let desc = "Handle of physical memory object";
+// Helper for building single-bit enumerator mask values.
+def : Macro {
+ let name = "OL_BIT( _i )";
+ let desc = "Generic macro for enumerator bit masks";
+ let value = "( 1 << _i )";
+// Return/error codes shared by every API function. Return<> values elsewhere
+// refer to these as PREFIX # "_RESULT_<etor name>".
+def : Enum {
+  let name = "ol_result_t";
+  let desc = "Defines Return/Error codes";
+  let etors =[
+    Etor<"SUCCESS", "Success">,
+    Etor<"ERROR_INVALID_OPERATION", "Invalid operation">,
+    Etor<"ERROR_INVALID_QUEUE_PROPERTIES", "Invalid queue properties">,
+    Etor<"ERROR_INVALID_QUEUE", "Invalid queue">,
+    Etor<"ERROR_INVALID_VALUE", "Invalid Value">,
+    Etor<"ERROR_INVALID_CONTEXT", "Invalid context">,
+    Etor<"ERROR_INVALID_PLATFORM", "Invalid platform">,
+    Etor<"ERROR_INVALID_BINARY", "Invalid binary">,
+    Etor<"ERROR_INVALID_PROGRAM", "Invalid program">,
+    Etor<"ERROR_INVALID_SAMPLER", "Invalid sampler">,
+    Etor<"ERROR_INVALID_BUFFER_SIZE", "Invalid buffer size">,
+    Etor<"ERROR_INVALID_MEM_OBJECT", "Invalid memory object">,
+    Etor<"ERROR_INVALID_EVENT", "Invalid event">,
+    Etor<"ERROR_INVALID_EVENT_WAIT_LIST", "Returned when the event wait list or the events in the wait list are invalid.">,
+    Etor<"ERROR_MISALIGNED_SUB_BUFFER_OFFSET", "Misaligned sub buffer offset">,
+    Etor<"ERROR_INVALID_WORK_GROUP_SIZE", "Invalid work group size">,
+    Etor<"ERROR_COMPILER_NOT_AVAILABLE", "Compiler not available">,
+    Etor<"ERROR_PROFILING_INFO_NOT_AVAILABLE", "Profiling info not available">,
+    Etor<"ERROR_DEVICE_NOT_FOUND", "Device not found">,
+    Etor<"ERROR_INVALID_DEVICE", "Invalid device">,
+    Etor<"ERROR_DEVICE_LOST", "Device hung, reset, was removed, or adapter update occurred">,
+    Etor<"ERROR_DEVICE_REQUIRES_RESET", "Device requires a reset">,
+    Etor<"ERROR_DEVICE_IN_LOW_POWER_STATE", "Device currently in low power state">,
+    Etor<"ERROR_DEVICE_PARTITION_FAILED", "Device partitioning failed">,
+    Etor<"ERROR_INVALID_WORK_ITEM_SIZE", "Invalid work item size">,
+    Etor<"ERROR_INVALID_WORK_DIMENSION", "Invalid work dimension">,
+    Etor<"ERROR_INVALID_KERNEL_ARGS", "Invalid kernel args">,
+    Etor<"ERROR_INVALID_KERNEL", "Invalid kernel">,
+    Etor<"ERROR_INVALID_KERNEL_NAME", "[Validation] kernel name is not found in the program">,
+    Etor<"ERROR_INVALID_KERNEL_ARGUMENT_INDEX", "[Validation] kernel argument index is not valid for kernel">,
+    Etor<"ERROR_INVALID_KERNEL_ARGUMENT_SIZE", "[Validation] kernel argument size does not match kernel">,
+    Etor<"ERROR_INVALID_KERNEL_ATTRIBUTE_VALUE", "[Validation] value of kernel attribute is not valid for the kernel or device">,
+    Etor<"ERROR_INVALID_IMAGE_SIZE", "Invalid image size">,
+    Etor<"ERROR_INVALID_IMAGE_FORMAT_DESCRIPTOR", "Invalid image format descriptor">,
+    Etor<"ERROR_IMAGE_FORMAT_NOT_SUPPORTED", "Image format not supported">,
+    Etor<"ERROR_MEM_OBJECT_ALLOCATION_FAILURE", "Memory object allocation failure">,
+    Etor<"ERROR_INVALID_PROGRAM_EXECUTABLE", "Program object parameter is invalid.">,
+    Etor<"ERROR_UNINITIALIZED", "[Validation] adapter is not initialized or specific entry-point is not implemented">,
+    Etor<"ERROR_OUT_OF_HOST_MEMORY", "Insufficient host memory to satisfy call">,
+    Etor<"ERROR_OUT_OF_DEVICE_MEMORY", "Insufficient device memory to satisfy call">,
+    Etor<"ERROR_OUT_OF_RESOURCES", "Out of resources">,
+    Etor<"ERROR_PROGRAM_BUILD_FAILURE", "Error occurred when building program, see build log for details">,
+    Etor<"ERROR_PROGRAM_LINK_FAILURE", "Error occurred when linking programs, see build log for details">,
+    Etor<"ERROR_UNSUPPORTED_VERSION", "[Validation] generic error code for unsupported versions">,
+    Etor<"ERROR_UNSUPPORTED_FEATURE", "[Validation] generic error code for unsupported features">,
+    Etor<"ERROR_INVALID_ARGUMENT", "[Validation] generic error code for invalid arguments">,
+    Etor<"ERROR_INVALID_NULL_HANDLE", "[Validation] handle argument is not valid">,
+    Etor<"ERROR_HANDLE_OBJECT_IN_USE", "[Validation] object pointed to by handle still in-use by device">,
+    Etor<"ERROR_INVALID_NULL_POINTER", "[Validation] pointer argument may not be nullptr">,
+    Etor<"ERROR_INVALID_SIZE", "[Validation] invalid size or dimensions (e.g., must not be zero, or is out of bounds)">,
+    Etor<"ERROR_UNSUPPORTED_SIZE", "[Validation] size argument is not supported by the device (e.g., too large)">,
+    Etor<"ERROR_UNSUPPORTED_ALIGNMENT", "[Validation] alignment argument is not supported by the device (e.g., too small)">,
+    Etor<"ERROR_INVALID_SYNCHRONIZATION_OBJECT", "[Validation] synchronization object in invalid state">,
+    Etor<"ERROR_INVALID_ENUMERATION", "[Validation] enumerator argument is not valid">,
+    Etor<"ERROR_UNSUPPORTED_ENUMERATION", "[Validation] enumerator argument is not supported by the device">,
+    Etor<"ERROR_UNSUPPORTED_IMAGE_FORMAT", "[Validation] image format is not supported by the device">,
+    Etor<"ERROR_INVALID_NATIVE_BINARY", "[Validation] native binary is not supported by the device">,
+    Etor<"ERROR_INVALID_GLOBAL_NAME", "[Validation] global variable is not found in the program">,
+    Etor<"ERROR_INVALID_FUNCTION_NAME", "[Validation] function name is not found in the program">,
+    Etor<"ERROR_INVALID_GROUP_SIZE_DIMENSION", "[Validation] group size dimension is not valid for the kernel or device">,
+    Etor<"ERROR_INVALID_GLOBAL_WIDTH_DIMENSION", "[Validation] global width dimension is not valid for the kernel or device">,
+    Etor<"ERROR_PROGRAM_UNLINKED", "[Validation] compiled program or program with imports needs to be linked before kernels can be created from it.">,
+    Etor<"ERROR_OVERLAPPING_REGIONS", "[Validation] copy operations do not support overlapping regions of memory">,
+    Etor<"ERROR_INVALID_HOST_PTR", "Invalid host pointer">,
+    Etor<"ERROR_INVALID_USM_SIZE", "Invalid USM size">,
+    Etor<"ERROR_OBJECT_ALLOCATION_FAILURE", "Object allocation failure">,
+    Etor<"ERROR_ADAPTER_SPECIFIC", "An adapter specific warning/error has been reported and can be retrieved via the olPlatformGetLastError entry point.">,
+    Etor<"ERROR_LAYER_NOT_PRESENT", "A requested layer was not found by the loader.">,
+    Etor<"ERROR_IN_EVENT_LIST_EXEC_STATUS", "An event in the provided wait list has OL_EVENT_STATUS_ERROR.">,
+    Etor<"ERROR_UNKNOWN", "Unknown or internal error">
+  ];
+}
+// Base structs for extensible property/descriptor chains. A Struct whose
+// base_class names one of these gets stype/pNext copied in via
+// Struct.all_members (APIDefs.td).
+// NOTE(review): ol_structure_type_t is not defined in this file.
+def : Struct {
+  let name = "ol_base_properties_t";
+  let desc = "Base for all properties types";
+  let members = [
+    StructMember<"ol_structure_type_t", "stype", "[in] type of this structure">,
+    StructMember<"void*", "pNext", "[in,out][optional] pointer to extension-specific structure">
+  ];
+}
+
+def : Struct {
+  let name = "ol_base_desc_t";
+  let desc = "Base for all descriptor types";
+  let members = [
+    StructMember<"ol_structure_type_t", "stype", "[in] type of this structure">,
+    StructMember<"const void*", "pNext", "[in][optional] pointer to extension-specific structure">
+  ];
+}
+
+// Offset and region arguments for rectangular buffer operations. The x/width
+// components are in bytes; the other components are scalar counts.
+def : Struct {
+  let name = "ol_rect_offset_t";
+  let desc = "3D offset argument passed to buffer rect operations";
+  let members = [
+    StructMember<"uint64_t", "x", "[in] x offset (bytes)">,
+    StructMember<"uint64_t", "y", "[in] y offset (scalar)">,
+    StructMember<"uint64_t", "z", "[in] z offset (scalar)">
+  ];
+}
+
+def : Struct {
+  let name = "ol_rect_region_t";
+  let desc = "3D region argument passed to buffer rect operations";
+  let members = [
+    StructMember<"uint64_t", "width", "[in] width (bytes)">,
+    StructMember<"uint64_t", "height", "[in] height (scalar)">,
+    StructMember<"uint64_t", "depth", "[in] depth (scalar)">
+  ];
+}
+// Queries accepted by olQueueGetInfo. is_typed = 1: each enumerator's desc
+// starts with "[<type>]", naming the type of the value returned for that query.
+def : Enum {
+ let name = "ol_queue_info_t";
+ let desc = "Query queue info";
+ let is_typed = 1;
+ let etors =[
+ Etor<"CONTEXT", "[ol_context_handle_t] context associated with this queue.">,
+ Etor<"DEVICE", "[ol_device_handle_t] device associated with this queue.">,
+ Etor<"DEVICE_DEFAULT", "[ol_queue_handle_t] the current default queue of the underlying device.">,
+ Etor<"FLAGS", "[ol_queue_flags_t] the properties associated with ol_queue_properties_t::flags.">,
+ Etor<"REFERENCE_COUNT", [{[uint32_t] Reference count of the queue object.
+The reference count returned should be considered immediately stale.
+It is unsuitable for general use in applications. This feature is provided for identifying memory leaks.}]>,
+ Etor<"SIZE", "[uint32_t] The size of the queue">,
+ Etor<"EMPTY", "[ol_bool_t] return true if the queue was empty at the time of the query">
+ ];
+// Queue creation flags (ol_queue_properties_t::flags).
+// NOTE(review): olQueueCreate tests these with `flags & OL_QUEUE_FLAG_...`,
+// i.e. as bit masks, but Etor carries no explicit value. Confirm the generator
+// assigns distinct bit values (e.g. via OL_BIT) rather than sequential
+// integers, and that the generated enumerator prefix matches OL_QUEUE_FLAG_.
+def : Enum {
+ let name = "ol_queue_flags_t";
+ let desc = "Queue property flags";
+ let etors =[
+ Etor<"OUT_OF_ORDER_EXEC_MODE_ENABLE", "Enable/disable out of order execution">,
+ Etor<"PROFILING_ENABLE", "Enable/disable profiling">,
+ Etor<"ON_DEVICE", "Is a device queue">,
+ Etor<"ON_DEVICE_DEFAULT", "Is the default queue for a device">,
+ Etor<"DISCARD_EVENTS", "Events will be discarded">,
+ Etor<"PRIORITY_LOW", "Low priority queue">,
+ Etor<"PRIORITY_HIGH", "High priority queue">,
+ Etor<"SUBMISSION_BATCHED", "Hint: enqueue and submit in a batch later. No change in queue semantics. Implementation chooses submission mode.">,
+ Etor<"SUBMISSION_IMMEDIATE", "Hint: enqueue and submit immediately. No change in queue semantics. Implementation chooses submission mode.">,
+ Etor<"USE_DEFAULT_STREAM", "Use the default stream. Only meaningful for CUDA. Other platforms may ignore this flag.">,
+ Etor<"SYNC_WITH_DEFAULT_STREAM", "Synchronize with the default stream. Only meaningful for CUDA. Other platforms may ignore this flag.">
+ ];
+// olQueueGetInfo. hQueue gets an automatic NULL-handle check. pPropValue and
+// pPropSizeRet are PARAM_OPTIONAL, so no automatic NULL-pointer checks are
+// generated for them; the explicit backticked conditions below cover the
+// invalid size/pointer combinations instead.
+def : Function<"olQueue"> {
+ let name = "GetInfo";
+ let desc = "Query information about a command queue";
+ let params = [
+ Param<"ol_queue_handle_t", "hQueue", "handle of the queue object", PARAM_IN>,
+ Param<"ol_queue_info_t", "propName", "name of the queue property to query", PARAM_IN>,
+ Param<"size_t", "propSize", "size in bytes of the queue property value provided", PARAM_IN>,
+ Param<"void*", "pPropValue", "[typename(propName, propSize)] value of the queue property", !or(PARAM_OUT, PARAM_OPTIONAL)>,
+ Param<"size_t*", "pPropSizeRet", "size in bytes returned in queue property value", !or(PARAM_OUT, PARAM_OPTIONAL)>
+ ];
+ let returns = [
+ "If `propName` is not supported by the adapter."
+ ]>,
+ "`propSize == 0 && pPropValue != NULL`",
+ "If `propSize` is less than the real number of bytes needed to return the info."
+ ]>,
+ "`propSize != 0 && pPropValue == NULL`",
+ "`pPropValue == NULL && pPropSizeRet == NULL`"
+ ]>,
+ ];
+// Properties passed to olQueueCreate via pProperties; stype/pNext are
+// inherited from ol_base_properties_t.
+def : Struct {
+ let name = "ol_queue_properties_t";
+ let desc = "Queue creation properties";
+ let base_class = "ol_base_properties_t";
+ let members = [
+ StructMember<"ol_queue_flags_t", "flags", "[in] Bitfield of queue creation flags">
+ ];
+def : Struct {
+ let name = "ol_queue_index_properties_t";
+ let desc = "Queue index creation properties";
+ let base_class = "ol_base_properties_t";
+ let members = [
+ StructMember<"uint32_t", "computeIndex", "[in] Specifies the compute index as described in the sycl_ext_intel_queue_index extension.">
+ ];
+// olQueueCreate. hContext and hDevice get automatic NULL-handle checks, and
+// phQueue (a non-optional pointer) gets an automatic NULL-pointer check.
+// pProperties is optional, so the explicit conditions only dereference it
+// behind `pProperties != NULL`; they reject mutually exclusive flag pairs.
+def : Function<"olQueue"> {
+ let name = "Create";
+ let desc = "Create a command queue for a device in a context";
+ let details = [
+ "See also ol_queue_index_properties_t."
+ ];
+ let params = [
+ Param<"ol_context_handle_t", "hContext", "handle of the context object", PARAM_IN>,
+ Param<"ol_device_handle_t", "hDevice", "handle of the device object", PARAM_IN>,
+ Param<"const ol_queue_properties_t*", "pProperties", "pointer to queue creation properties.", !or(PARAM_IN, PARAM_OPTIONAL)>,
+ Param<"ol_queue_handle_t*", "phQueue", "pointer to handle of queue object created", PARAM_OUT>
+ ];
+ let returns = [
+ "`pProperties != NULL && pProperties->flags & OL_QUEUE_FLAG_PRIORITY_HIGH && pProperties->flags & OL_QUEUE_FLAG_PRIORITY_LOW`",
+ "`pProperties != NULL && pProperties->flags & OL_QUEUE_FLAG_SUBMISSION_BATCHED && pProperties->flags & OL_QUEUE_FLAG_SUBMISSION_IMMEDIATE`"
+ ]>,
+ ];
+// olQueueRetain / olQueueRelease. `returns` is empty; all_returns still
+// contains DefaultReturns plus the automatic NULL-handle check for hQueue.
+def : Function<"olQueue"> {
+ let name = "Retain";
+ let desc = "Get a reference to the command queue handle. Increment the command queue's reference count";
+ let details = [
+ "Useful in library function to retain access to the command queue after the caller released the queue."
+ ];
+ let params = [
+ Param<"ol_queue_handle_t", "hQueue", "handle of the queue object to get access", PARAM_IN>
+ ];
+ let returns = [
+ ];
+def : Function<"olQueue"> {
+ let name = "Release";
+ let desc = "Decrement the command queue's reference count and delete the command queue if the reference count becomes zero.";
+ let details = [
+ "After the command queue reference count becomes zero and all queued commands in the queue have finished, the queue is deleted.",
+ "It also performs an implicit flush to issue all previously queued commands in the queue."
+ ];
+ let params = [
+ Param<"ol_queue_handle_t", "hQueue", "handle of the queue object to release", PARAM_IN>
+ ];
+ let returns = [
+ ];
+// Descriptor shared by the native-handle interop functions; stype/pNext are
+// inherited through base_class (see Struct.all_members in APIDefs.td).
+def : Struct {
+ let name = "ol_queue_native_desc_t";
+ let desc = "Descriptor for olQueueGetNativeHandle and olQueueCreateWithNativeHandle.";
+ let base_class = "ol_base_desc_t";
+ let members = [
+ StructMember<"void*", "pNativeData", "[in][optional] Adapter-specific metadata needed to create the handle.">
+ ];
+// olQueueGetNativeHandle. hQueue gets a NULL-handle check and phNativeQueue a
+// NULL-pointer check; pDesc is optional and unchecked.
+// NOTE(review): pDesc is PARAM_IN only but typed as a non-const pointer;
+// consider "const ol_queue_native_desc_t*".
+def : Function<"olQueue"> {
+ let name = "GetNativeHandle";
+ let desc = "Return queue native queue handle.";
+ let details = [
+ "Retrieved native handle can be used for direct interaction with the native platform driver.",
+ "Use interoperability queue extensions to convert native handle to native type.",
+ "The application may call this function from simultaneous threads for the same context.",
+ "The implementation of this function should be thread-safe."
+ ];
+ let params = [
+ Param<"ol_queue_handle_t", "hQueue", "handle of the queue.", PARAM_IN>,
+ Param<"ol_queue_native_desc_t*", "pDesc", "pointer to native descriptor", !or(PARAM_IN, PARAM_OPTIONAL)>,
+ Param<"ol_native_handle_t*", "phNativeQueue", "a pointer to the native handle of the queue.", PARAM_OUT>
+ ];
+ let returns = [
+ "If the adapter has no underlying equivalent handle."
+ ]>
+ ];
+// Optional properties for olQueueCreateWithNativeHandle. The flag uses
+// ol_bool_t rather than C++ `bool` so the generated header stays usable from C.
+def : Struct {
+  let name = "ol_queue_native_properties_t";
+  let desc = "Properties for olQueueCreateWithNativeHandle.";
+  let base_class = "ol_base_properties_t";
+  let members = [
+    StructMember<"ol_bool_t", "isNativeHandleOwned", [{[in] Indicates whether Offload owns the native handle or if it came from an interoperability
+operation in the application that asked to not transfer the ownership to
+Offload.}]>
+  ];
+}
+// olQueueCreateWithNativeHandle.
+// NOTE(review): hNativeQueue's desc says "[nocheck]", but ShouldCheckHandle only
+// looks at the type suffix and PARAM_OPTIONAL, so a `NULL == hNativeQueue`
+// condition is still generated. Mark it PARAM_OPTIONAL, or teach the checks to
+// honour "[nocheck]", if that is not intended.
+def : Function<"olQueue"> {
+ let name = "CreateWithNativeHandle";
+ let desc = "Create runtime queue object from native queue handle.";
+ let details = [
+ "Creates runtime queue handle from native driver queue handle.",
+ "The application may call this function from simultaneous threads for the same context.",
+ "The implementation of this function should be thread-safe."
+ ];
+ let params = [
+ Param<"ol_native_handle_t", "hNativeQueue", "[nocheck] the native handle of the queue.", PARAM_IN>,
+ Param<"ol_context_handle_t", "hContext", "handle of the context object", PARAM_IN>,
+ Param<"ol_device_handle_t", "hDevice", "handle of the device object", PARAM_IN>,
+ Param<"const ol_queue_native_properties_t*", "pProperties", "pointer to native queue properties struct", !or(PARAM_IN, PARAM_OPTIONAL)>,
+ Param<"ol_queue_handle_t*", "phQueue", "pointer to the handle of the queue object created.", PARAM_OUT>
+ ];
+ let returns = [
+ "If the adapter has no underlying equivalent handle."
+ ]>
+ ];
+// olQueueFinish / olQueueFlush: hQueue only, so all_returns is DefaultReturns
+// plus the automatic NULL-handle check.
+def : Function<"olQueue"> {
+ let name = "Finish";
+ let desc = "Blocks until all previously issued commands to the command queue are finished.";
+ let details = [
+ "Blocks until all previously issued commands to the command queue are issued and completed.",
+ "olQueueFinish does not return until all enqueued commands have been processed and finished.",
+ "olQueueFinish acts as a synchronization point."
+ ];
+ let params = [
+ Param<"ol_queue_handle_t", "hQueue", "handle of the queue to be finished.", PARAM_IN>
+ ];
+ let returns = [
+ ];
+def : Function<"olQueue"> {
+ let name = "Flush";
+ let desc = "Issues all previously enqueued commands in a command queue to the device.";
+ let details = [
+ "Guarantees that all enqueued commands will be issued to the appropriate device.",
+ "There is no guarantee that they will be completed after olQueueFlush returns."
+ ];
+ let params = [
+ Param<"ol_queue_handle_t", "hQueue", "handle of the queue to be flushed.", PARAM_IN>
+ ];
+ let returns = [
+ ];
diff --git a/offload/API/OffloadAPI.td b/offload/API/OffloadAPI.td
new file mode 100644
index 00000000000000..4bfaec29568951
--- /dev/null
+++ b/offload/API/OffloadAPI.td
@@ -0,0 +1,5 @@
+// Always include this file first
+include "APIDefs.td"
+// Add API definition files here
+include "Example.td"
diff --git a/offload/API/README.md b/offload/API/README.md
new file mode 100644
index 00000000000000..4bc69b4740e784
--- /dev/null
+++ b/offload/API/README.md
@@ -0,0 +1,103 @@
+# Offload API definitions
+**Note**: This is a work-in-progress. The intention is for this to serve as a
+starting off point for design discussion. It is loosely based on equivalent
+tooling in Unified Runtime.
+The Tablegen files in this directory are used to define the Offload API. They
+are used with the `offload-tblgen` tool to generate API headers and (stub)
+validation code. There are plans to add support for tracing, printing (e.g.
+adding `operator<<(std::ostream)` defs to API structs, enums, etc.), and test
+generation.
+The root file is `OffloadAPI.td` - additional `.td` files can be included in
+this file to add them to the API.
+## API Objects
+The API consists of a number of objects, which always have a *name* field and
+*description* field, and are one of the following types:
+### Function
+Represents an API entry point function. Has a list of returns and parameters.
+Also has fields for details (representing a bullet-point list of
+information about the function that would otherwise be too detailed for the
+description), and analogues (equivalent functions in other APIs).
+#### Parameter
+Represents a parameter to a function, has *type*, *name*, and *desc* fields.
+Also has a *flags* field containing flags representing whether the parameter is
+in, out, or optional.
+The *type* field is used to infer whether the parameter is a pointer or handle
+type: a type ending in `*` is a pointer, and a type ending in `_handle_t` is a
+handle.
+A *handle* type is a pointer to an opaque struct, used to abstract over
+plugin-specific implementation details.
+#### Return
+A return represents a possible return code from the function, and optionally a
+list of conditions in which this value may be returned. The conditions list is
+not expected to be exhaustive. A condition is considered free-form text, but
+if it is wrapped in \`backticks\` then it is treated as literal code
+representing an error condition (e.g. `someParam < 1`). These conditions are
+used to automatically create validation checks by the `offload-tblgen`
+validation generator.
+Returns are automatically generated for functions with pointer or handle
+parameters, so API authors do not need to exhaustively add null checks for
+these types of parameters. All functions also get a number of default return
+values automatically.
+### Struct
+Represents a struct. Contains a list of members, which each have a *type*,
+*name*, and *desc*.
+Also optionally takes a *base_class* field. If this is either of the special
+`ol_base_properties_t` or `ol_base_desc_t` structs, then the struct will inherit
+members from those structs. The generated struct does **not** use actual C++
+inheritance, but instead explicitly has those members copied in, which preserves
+compatibility with C.
+### Enum
+Represents a C-style enum. Contains a list of `etor` values.
+All enums automatically get a `<enum_name>_FORCE_UINT32 = 0x7fffffff` value,
+which forces the underlying type to be uint32.
+### Handle
+Represents a pointer to an opaque struct, as described in the Parameter section.
+It does not take any extra fields.
+### Typedef
+Represents a typedef, contains only a *value* field.
+### Macro
+Represents a C preprocessor `#define`. Contains a *value* field. Optionally
+takes a *condition* field, which allows the macro to be conditionally defined,
+and an *alt_value* field, which represents the value if the condition is false.
+Macro arguments are presented in the *name* field (e.g. name = `mymacro(arg)`).
+While there may seem to be little point in generating a macro from TableGen,
+doing so allows the entire source of the header file to be generated from the
+TableGen files, rather than requiring a mix of C source and TableGen.
+## Generation
+### API header
+```sh
+./offload-tblgen -I <path-to-llvm>/offload/API <path-to-llvm>/offload/API/OffloadAPI.td --gen-api
+```
+The comments in the generated header are in Doxygen format, although
+generating documentation from them hasn't been tested yet.
+### Validation functions
+```sh
+./offload-tblgen -I <path-to-llvm>/offload/API <path-to-llvm>/offload/API/OffloadAPI.td --gen-validation
+```
+The functions are partially stubbed and are designed to be used in conjunction
+with code that can track live handle references, etc. See the equivalent code
+in Unified Runtime for an idea of how this might work.
+### Future Tablegen backends
+`RecordTypes.hpp` contains wrappers for all of the API object types, which will
+allow more backends to be easily added in future.
diff --git a/offload/CMakeLists.txt b/offload/CMakeLists.txt
new file mode 100644
index 00000000000000..39a299d350cc85
--- /dev/null
+++ b/offload/CMakeLists.txt
@@ -0,0 +1,15 @@
+# TODO: #75125 isn't merged yet, so use some hacky CMake to get things building
+# in the meantime.
diff --git a/offload/tools/offload-tblgen/APIGen.cpp b/offload/tools/offload-tblgen/APIGen.cpp
new file mode 100644
index 00000000000000..7ee335887154c4
--- /dev/null
+++ b/offload/tools/offload-tblgen/APIGen.cpp
@@ -0,0 +1,156 @@
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+#include "GenCommon.hpp"
+#include "RecordTypes.hpp"
+using namespace llvm;
+using namespace offload::tblgen;
+// Produce a possibly multi-line Doxygen member comment (`///<`) from the input
+// string: one comment line per non-blank input line, with trailing whitespace
+// removed. Blank lines are skipped rather than ending the comment early. The
+// result always ends in a newline, so callers can append it directly after a
+// declaration even when the description is empty.
+static std::string MakeComment(StringRef in) {
+  std::string out;
+  StringRef Rest = in;
+  while (!Rest.empty()) {
+    auto [Line, Tail] = Rest.split('\n');
+    Rest = Tail;
+    // Skip blank lines (including the one left by a trailing newline) instead
+    // of emitting empty comment lines or truncating the description.
+    if (Line.trim().empty())
+      continue;
+    out += "\t///< ";
+    out += Line.rtrim().str();
+    out += "\n";
+  }
+  // An empty description must still terminate the declaration's line.
+  if (out.empty())
+    out = "\n";
+  return out;
+}
+/// Emit a documented handle type: `typedef struct <name>_ *<name>;`, i.e. the
+/// handle is a pointer to an opaque struct named after it with a `_` suffix.
+static void ProcessHandle(const HandleRec &H, raw_ostream &OS) {
+  const StringRef HandleName = H.getName();
+  OS << CommentsHeader << formatv("/// @brief {0}\n", H.getDesc())
+     << formatv("typedef struct {0}_ *{0};\n", HandleName);
+}
+/// Emit a documented `typedef <value> <name>;` declaration.
+static void ProcessTypedef(const TypedefRec &T, raw_ostream &OS) {
+  const StringRef Aliased = T.getValue();
+  const StringRef Alias = T.getName();
+  OS << CommentsHeader << formatv("/// @brief {0}\n", T.getDesc())
+     << formatv("typedef {0} {1};\n", Aliased, Alias);
+}
+/// Emit a documented `#define`, wrapped in an `#ifndef` guard so a user-provided
+/// definition wins. When the record has a condition, the definition is placed
+/// under `#if <condition>`, with the optional alternate value in the `#else`.
+static void ProcessMacro(const MacroRec &M, raw_ostream &OS) {
+  // Function-like macros carry their parameter list in the name (e.g.
+  // `mymacro(arg)`), but `#ifndef` only accepts the bare identifier.
+  const StringRef GuardName = M.getName().split('(').first;
+  OS << CommentsHeader;
+  OS << formatv("#ifndef {0}\n", GuardName);
+  if (auto Condition = M.getCondition()) {
+    OS << formatv("#if {0}\n", *Condition);
+  }
+  OS << "/// @brief " << M.getDesc() << "\n";
+  OS << formatv("#define {0} {1}\n", M.getName(), M.getValue());
+  if (auto AltValue = M.getAltValue()) {
+    OS << "#else\n";
+    OS << formatv("#define {0} {1}\n", M.getName(), *AltValue);
+  }
+  if (auto Condition = M.getCondition()) {
+    OS << formatv("#endif // {0}\n", *Condition);
+  }
+  OS << formatv("#endif // {0}\n", GuardName);
+}
+/// Emit the documented prototype of an API entry point. The comment contains
+/// the brief, any details, any analogues, and the documented return codes with
+/// their conditions. It is followed by the exported declaration, whose
+/// parameters each carry a trailing `///<` comment.
+static void ProcessFunction(const FunctionRec &F, raw_ostream &OS) {
+  OS << CommentsHeader;
+  OS << formatv("/// @brief {0}\n", F.getDesc());
+  OS << CommentsBreak;
+  // Only emit the details section when there is something to list, matching
+  // the handling of analogues below.
+  auto Details = F.getDetails();
+  if (!Details.empty()) {
+    OS << "/// @details\n";
+    for (auto &Detail : Details)
+      OS << formatv("/// - {0}\n", Detail);
+    OS << CommentsBreak;
+  }
+  // Emit analogue remarks
+  auto Analogues = F.getAnalogues();
+  if (!Analogues.empty()) {
+    OS << "/// @remarks\n/// _Analogues_\n";
+    for (auto &Analogue : Analogues)
+      OS << formatv("/// - **{0}**\n", Analogue);
+    OS << CommentsBreak;
+  }
+  OS << "/// @returns\n";
+  for (auto &Ret : F.getReturns()) {
+    OS << formatv("/// - ::{0}\n", Ret.getValue());
+    for (auto &RetCondition : Ret.getConditions())
+      OS << formatv("/// + {0}\n", RetCondition);
+  }
+  OS << formatv("{0}_APIEXPORT {1}_result_t {0}_APICALL ", PrefixUpper,
+                PrefixLower);
+  OS << F.getFullName();
+  OS << "(\n";
+  const auto &Params = F.getParams();
+  // In C, `f()` declares a function with unspecified parameters rather than
+  // one taking none, so spell out `void` for parameterless entry points.
+  if (Params.empty())
+    OS << " void\n";
+  for (size_t I = 0, E = Params.size(); I != E; ++I) {
+    const ParamRec &Param = Params[I];
+    OS << " " << Param.getType() << " " << Param.getName();
+    // Separate parameters with a comma; the last one gets a plain space.
+    OS << (I + 1 == E ? " " : ", ");
+    OS << MakeParamComment(Param) << "\n";
+  }
+  OS << ");\n\n";
+}
+/// Emit a documented C enum. Enumerators are numbered from 0 in declaration
+/// order and named `<prefix>_<ETOR>`. A final `<prefix>_FORCE_UINT32`
+/// enumerator, hidden from Doxygen by @cond/@endcond, is appended.
+static void ProcessEnum(const EnumRec &Enum, raw_ostream &OS) {
+  const std::string Prefix = Enum.getEnumValNamePrefix();
+  OS << CommentsHeader;
+  OS << formatv("/// @brief {0}\n", Enum.getDesc());
+  OS << formatv("typedef enum {0} {{\n", Enum.getName());
+  const auto &Etors = Enum.getValues();
+  for (size_t Index = 0; Index < Etors.size(); ++Index) {
+    const EnumValueRec &Etor = Etors[Index];
+    OS << formatv(" {0}_{1} = {2}, {3}", Prefix, Etor.getName(),
+                  static_cast<uint32_t>(Index), MakeComment(Etor.getDesc()));
+  }
+  // Sentinel enumerator appended to every enum.
+  OS << formatv(
+      " /// @cond\n {0}_FORCE_UINT32 = 0x7fffffff\n /// @endcond\n\n",
+      Prefix);
+  OS << formatv("} {0};\n", Enum.getName());
+}
+/// Emit a documented C struct typedef. Each member line ends with the comment
+/// text produced by MakeComment from the member's description.
+static void ProcessStruct(const StructRec &Struct, raw_ostream &OS) {
+  const StringRef StructName = Struct.getName();
+  OS << CommentsHeader << formatv("/// @brief {0}\n", Struct.getDesc())
+     << formatv("typedef struct {0} {{\n", StructName);
+  for (const StructMemberRec &Field : Struct.getMembers())
+    OS << formatv(" {0} {1}; {2}", Field.getType(), Field.getName(),
+                  MakeComment(Field.getDesc()));
+  OS << formatv("} {0};\n\n", StructName);
+}
+/// Emit the Offload API header contents: one declaration per `APIObject`
+/// record, in the order the records are returned. Each record is handled by
+/// the first matching kind; records of any other kind are skipped.
+void EmitOffloadAPI(RecordKeeper &Records, raw_ostream &OS) {
+  const auto Objects = Records.getAllDerivedDefinitions("APIObject");
+  for (auto *Obj : Objects) {
+    if (Obj->isSubClassOf("Macro")) {
+      ProcessMacro(MacroRec{Obj}, OS);
+      continue;
+    }
+    if (Obj->isSubClassOf("Typedef")) {
+      ProcessTypedef(TypedefRec{Obj}, OS);
+      continue;
+    }
+    if (Obj->isSubClassOf("Handle")) {
+      ProcessHandle(HandleRec{Obj}, OS);
+      continue;
+    }
+    if (Obj->isSubClassOf("Function")) {
+      ProcessFunction(FunctionRec{Obj}, OS);
+      continue;
+    }
+    if (Obj->isSubClassOf("Enum")) {
+      ProcessEnum(EnumRec{Obj}, OS);
+      continue;
+    }
+    if (Obj->isSubClassOf("Struct"))
+      ProcessStruct(StructRec{Obj}, OS);
+  }
+}
diff --git a/offload/tools/offload-tblgen/CMakeLists.txt b/offload/tools/offload-tblgen/CMakeLists.txt
new file mode 100644
index 00000000000000..61c733b8bc46ec
--- /dev/null
+++ b/offload/tools/offload-tblgen/CMakeLists.txt
@@ -0,0 +1,12 @@
+ Demangle
+ Support
+ TableGen
+add_tablegen(offload-tblgen Offload
+ EXPORT Offload
+ APIGen.cpp
+ offload-tblgen.cpp
+ ValidationGen.cpp
+ )
diff --git a/offload/tools/offload-tblgen/GenCommon.hpp b/offload/tools/offload-tblgen/GenCommon.hpp
new file mode 100644
index 00000000000000..c5634e1e99fc23
--- /dev/null
+++ b/offload/tools/offload-tblgen/GenCommon.hpp
@@ -0,0 +1,20 @@
+#pragma once
+#include "RecordTypes.hpp"
+#include "llvm/Support/FormatVariadic.h"
+constexpr auto CommentsHeader = R"(
+// Empty Doxygen line used to separate sections of a generated comment.
+constexpr const char *CommentsBreak = "///\n";
+// API name prefix in lower and upper case, used when naming generated code.
+constexpr const char *PrefixLower = "ol";
+constexpr const char *PrefixUpper = "OL";
+
+/// Build the trailing `///<` comment for a function parameter. The comment
+/// starts with the [in]/[out]/[optional] tags that apply, followed by the
+/// parameter's description.
+static std::string
+MakeParamComment(const llvm::offload::tblgen::ParamRec &Param) {
+  const char *InTag = Param.isIn() ? "[in]" : "";
+  const char *OutTag = Param.isOut() ? "[out]" : "";
+  const char *OptTag = Param.isOpt() ? "[optional]" : "";
+  return llvm::formatv("///< {0}{1}{2} {3}", InTag, OutTag, OptTag,
+                       Param.getDesc())
+      .str();
+}
diff --git a/offload/tools/offload-tblgen/Generators.hpp b/offload/tools/offload-tblgen/Generators.hpp
new file mode 100644
index 00000000000000..b3e568f43025fe
--- /dev/null
+++ b/offload/tools/offload-tblgen/Generators.hpp
@@ -0,0 +1,4 @@
+#pragma once
+#include "llvm/TableGen/Record.h"
+
+/// Emit the Offload API header contents for every `APIObject` record.
+void EmitOffloadAPI(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
+/// Emit a validation wrapper function for every `Function` record.
+void EmitOffloadValidation(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
diff --git a/offload/tools/offload-tblgen/RecordTypes.hpp b/offload/tools/offload-tblgen/RecordTypes.hpp
new file mode 100644
index 00000000000000..286737b1ec00a2
--- /dev/null
+++ b/offload/tools/offload-tblgen/RecordTypes.hpp
@@ -0,0 +1,186 @@
+#pragma once
+#include <string>
+#include "llvm/TableGen/Record.h"
+namespace llvm {
+namespace offload {
+namespace tblgen {
+
+/// Wrapper around a `Handle` record.
+class HandleRec {
+public:
+  explicit HandleRec(Record *rec) : rec(rec) {}
+  StringRef getName() const { return rec->getValueAsString("name"); }
+  StringRef getDesc() const { return rec->getValueAsString("desc"); }
+
+private:
+  Record *rec;
+};
+
+/// Wrapper around a `Macro` record; `condition` and `alt_value` are optional.
+class MacroRec {
+public:
+  explicit MacroRec(Record *rec) : rec(rec) {}
+  StringRef getName() const { return rec->getValueAsString("name"); }
+  StringRef getDesc() const { return rec->getValueAsString("desc"); }
+  std::optional<StringRef> getCondition() const {
+    return rec->getValueAsOptionalString("condition");
+  }
+  StringRef getValue() const { return rec->getValueAsString("value"); }
+  std::optional<StringRef> getAltValue() const {
+    return rec->getValueAsOptionalString("alt_value");
+  }
+
+private:
+  Record *rec;
+};
+
+/// Wrapper around a `Typedef` record.
+class TypedefRec {
+public:
+  explicit TypedefRec(Record *rec) : rec(rec) {}
+  StringRef getName() const { return rec->getValueAsString("name"); }
+  StringRef getDesc() const { return rec->getValueAsString("desc"); }
+  StringRef getValue() const { return rec->getValueAsString("value"); }
+
+private:
+  Record *rec;
+};
+
+/// Wrapper around one enumerator (`etor`) of an `Enum` record.
+class EnumValueRec {
+public:
+  explicit EnumValueRec(Record *rec) : rec(rec) {}
+  /// Enumerator name, upper-cased.
+  std::string getName() const { return rec->getValueAsString("name").upper(); }
+  StringRef getDesc() const { return rec->getValueAsString("desc"); }
+
+private:
+  Record *rec;
+};
+
+/// Wrapper around an `Enum` record and its enumerators (the `etors` list).
+class EnumRec {
+public:
+  explicit EnumRec(Record *rec) : rec(rec) {
+    for (Record *Val : rec->getValueAsListOfDefs("etors"))
+      vals.emplace_back(Val);
+  }
+  StringRef getName() const { return rec->getValueAsString("name"); }
+  StringRef getDesc() const { return rec->getValueAsString("desc"); }
+  const std::vector<EnumValueRec> &getValues() const { return vals; }
+  /// Prefix for enumerator names: the enum name with its last two characters
+  /// (presumably a `_t` suffix) removed, upper-cased. Names shorter than two
+  /// characters are used whole.
+  std::string getEnumValNamePrefix() const {
+    StringRef Name = getName();
+    return (Name.size() >= 2 ? Name.drop_back(2) : Name).upper();
+  }
+
+private:
+  Record *rec;
+  std::vector<EnumValueRec> vals;
+};
+
+/// Wrapper around one member of a `Struct` record.
+class StructMemberRec {
+public:
+  explicit StructMemberRec(Record *rec) : rec(rec) {}
+  StringRef getType() const { return rec->getValueAsString("type"); }
+  StringRef getName() const { return rec->getValueAsString("name"); }
+  StringRef getDesc() const { return rec->getValueAsString("desc"); }
+
+private:
+  Record *rec;
+};
+
+/// Wrapper around a `Struct` record. Members come from its `all_members`
+/// list (presumably including members copied in from a base struct).
+class StructRec {
+public:
+  explicit StructRec(Record *rec) : rec(rec) {
+    for (Record *Member : rec->getValueAsListOfDefs("all_members"))
+      members.emplace_back(Member);
+  }
+  StringRef getName() const { return rec->getValueAsString("name"); }
+  StringRef getDesc() const { return rec->getValueAsString("desc"); }
+  std::optional<StringRef> getBaseClass() const {
+    return rec->getValueAsOptionalString("base_class");
+  }
+  const std::vector<StructMemberRec> &getMembers() const { return members; }
+
+private:
+  Record *rec;
+  std::vector<StructMemberRec> members;
+};
+
+/// Wrapper around a function parameter record. Its `flags` bits are read as:
+/// bit 0 = [in], bit 1 = [out], bit 2 = [optional].
+class ParamRec {
+public:
+  explicit ParamRec(Record *rec)
+      : rec(rec), flags(rec->getValueAsBitsInit("flags")) {}
+  StringRef getName() const { return rec->getValueAsString("name"); }
+  StringRef getType() const { return rec->getValueAsString("type"); }
+  StringRef getDesc() const { return rec->getValueAsString("desc"); }
+  bool isIn() const { return getFlag(0); }
+  bool isOut() const { return getFlag(1); }
+  bool isOpt() const { return getFlag(2); }
+  Record *getRec() const { return rec; }
+  // Needed to check whether we're at the back of a vector of params
+  bool operator!=(const ParamRec &p) const { return rec != p.getRec(); }
+
+private:
+  // `cast` asserts if the bit is not a concrete BitInit (e.g. left unset),
+  // whereas the previous `dyn_cast` would have returned null and been
+  // dereferenced.
+  bool getFlag(unsigned Bit) const {
+    return cast<BitInit>(flags->getBit(Bit))->getValue();
+  }
+
+  Record *rec;
+  BitsInit *flags;
+};
+
+/// Wrapper around a documented return code and the conditions producing it.
+class ReturnRec {
+public:
+  explicit ReturnRec(Record *rec) : rec(rec) {}
+  StringRef getValue() const { return rec->getValueAsString("value"); }
+  std::vector<StringRef> getConditions() const {
+    return rec->getValueAsListOfStrings("conditions");
+  }
+
+private:
+  Record *rec;
+};
+
+/// Wrapper around a `Function` record with its params and `all_returns`.
+class FunctionRec {
+public:
+  explicit FunctionRec(Record *rec) : rec(rec) {
+    for (Record *Ret : rec->getValueAsListOfDefs("all_returns"))
+      rets.emplace_back(Ret);
+    for (Record *Param : rec->getValueAsListOfDefs("params"))
+      params.emplace_back(Param);
+  }
+  /// `api_class` followed by `name`.
+  std::string getFullName() const {
+    return rec->getValueAsString("api_class").str() +
+           rec->getValueAsString("name").str();
+  }
+  StringRef getName() const { return rec->getValueAsString("name"); }
+  StringRef getClass() const { return rec->getValueAsString("api_class"); }
+  const std::vector<ReturnRec> &getReturns() const { return rets; }
+  const std::vector<ParamRec> &getParams() const { return params; }
+  StringRef getDesc() const { return rec->getValueAsString("desc"); }
+  std::vector<StringRef> getDetails() const {
+    return rec->getValueAsListOfStrings("details");
+  }
+  std::vector<StringRef> getAnalogues() const {
+    return rec->getValueAsListOfStrings("analogues");
+  }
+  /// True for `Create`, `Retain` and `Release` functions, and for `Get` on
+  /// the `Adapter` class.
+  bool modifiesRefCount() const {
+    auto Name = rec->getValueAsString("name");
+    auto Class = rec->getValueAsString("api_class");
+    return (Name == "Create") || (Name == "Retain") || (Name == "Release") ||
+           (Name == "Get" && Class == "Adapter");
+  }
+
+private:
+  std::vector<ReturnRec> rets;
+  std::vector<ParamRec> params;
+  Record *rec;
+};
+
+} // namespace tblgen
+} // namespace offload
+} // namespace llvm
diff --git a/offload/tools/offload-tblgen/ValidationGen.cpp b/offload/tools/offload-tblgen/ValidationGen.cpp
new file mode 100644
index 00000000000000..23ab94e7ee1f08
--- /dev/null
+++ b/offload/tools/offload-tblgen/ValidationGen.cpp
@@ -0,0 +1,121 @@
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/TableGen/Record.h"
+#include "GenCommon.hpp"
+#include "RecordTypes.hpp"
+using namespace llvm;
+using namespace offload::tblgen;
+/// Emit `val_<FullName>`, a validation wrapper around an API entry point. The
+/// wrapper:
+///  - checks each backtick-quoted return condition;
+///  - emits (stubbed) lifetime checks for `*handle_t` params;
+///  - forwards the call;
+///  - for functions that modify a ref count, emits (stubbed) ref-count
+///    bookkeeping when the call succeeds.
+static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) {
+  OS << CommentsHeader;
+  OS << formatv("/// @brief Intercept function for {0}\n", F.getFullName());
+  // Emit preamble
+  OS << formatv("{0}_result_t {1}_APICALL val_{2}(\n", PrefixLower, PrefixUpper,
+                F.getFullName());
+  // Emit arguments, collecting their names for the forwarded call.
+  const auto &Params = F.getParams();
+  std::string ParamNameList;
+  for (size_t I = 0, E = Params.size(); I != E; ++I) {
+    const ParamRec &Param = Params[I];
+    const bool IsLast = I + 1 == E;
+    OS << " " << Param.getType() << " " << Param.getName();
+    OS << (IsLast ? " " : ", ");
+    OS << MakeParamComment(Param) << "\n";
+    ParamNameList += Param.getName().str();
+    if (!IsLast)
+      ParamNameList += ", ";
+  }
+  OS << ") {\n";
+  OS << " if (true /*enableParameterValidation*/) {\n";
+  // Emit validation checks. Only conditions written as a backtick-quoted
+  // expression become code; the rest are documentation only.
+  for (const auto &Return : F.getReturns()) {
+    for (StringRef Condition : Return.getConditions()) {
+      // Require a non-empty expression between the backticks so that "`" or
+      // "``" cannot produce an `if ()`.
+      if (Condition.size() > 2 && Condition.starts_with("`") &&
+          Condition.ends_with("`")) {
+        StringRef ConditionString = Condition.drop_front().drop_back();
+        OS << formatv(" if ({0}) {{\n", ConditionString);
+        OS << formatv(" return {0};\n", Return.getValue());
+        OS << " }\n\n";
+      }
+    }
+  }
+  OS << " }\n\n";
+  auto LifetimeTodoComment =
+      R"( // TODO: Implement. `refCountContext` is some global object that tracks known
+ // live handle objects, and logs related errors.
+ // In UR this is implemented as an unordered_map of handles to structs
+ // containing the reference count, amongst other details. In this case, a
+ // handle is invalid if it does not exist in the map.
+)";
+  bool EmittedTodo = false;
+  // Emit handle lifetime checks
+  for (const auto &Param : Params) {
+    if (Param.getType().ends_with("handle_t")) {
+      // Only add this comment once per function to keep the code size down
+      if (!EmittedTodo) {
+        OS << LifetimeTodoComment;
+        EmittedTodo = true;
+      }
+      OS << formatv(" if (true /* enableLifeTimeValidation && "
+                    "!refCountContext.isReferenceValid({0}) */) {{\n",
+                    Param.getName());
+      OS << formatv(" // refCountContext.logInvalidReference({0});\n",
+                    Param.getName());
+      OS << " }\n\n";
+    }
+  }
+  // Perform actual function call
+  OS << formatv(" {0}_result_t result = {1}({2});\n\n", PrefixLower,
+                F.getFullName(), ParamNameList);
+  // Handle reference counting for cases where the function modifies the ref
+  // count of a handle
+  // * `Create` - initialize a reference count
+  // * `Get` (Adapter only) - create or increment a reference count
+  // * `Retain` - increment a reference count
+  // * `Release` - decrement a reference count
+  if (F.modifiesRefCount()) {
+    // `{{` is required: a bare `{` starts a formatv replacement field and an
+    // unterminated one asserts in debug builds.
+    OS << formatv(" if ( /*context.enableLeakChecking &&*/ result == "
+                  "{0}_RESULT_SUCCESS) {{\n",
+                  PrefixUpper);
+    // The refcount context optionally takes a bool specifying whether the
+    // handle being tracked is an adapter handle, as they are counted
+    // differently.
+    // TODO: This behavior is lifted from UR. Offload will likely be different.
+    auto AdapterHandleArg = (F.getClass() == "Adapter") ? "true" : "false";
+    const StringRef Name = F.getName();
+    if (Name == "Create") {
+      // We only expect one handle output for these types of functions, but loop
+      // over all params just in case
+      for (const auto &Param : Params) {
+        if (Param.isOut()) {
+          OS << formatv(" // refCountContext.createRefCount(*{0});\n",
+                        Param.getName());
+        }
+      }
+    } else if (Name == "Get") {
+      // modifiesRefCount() only admits `Get` for the Adapter class. It hands
+      // out handles, so it must not fall into the decrement path.
+      // NOTE(review): modelled on UR's createOrIncrementRefCount; confirm the
+      // intended Offload semantics.
+      for (const auto &Param : Params) {
+        if (Param.isOut()) {
+          OS << formatv(
+              " // refCountContext.createOrIncrementRefCount(*{0}, {1});\n",
+              Param.getName(), AdapterHandleArg);
+        }
+      }
+      // Retain and release functions only have 1 parameter
+    } else if (Name == "Retain") {
+      OS << formatv(" // refCountContext.incrementRefCount({0}, {1});\n",
+                    Params.at(0).getName(), AdapterHandleArg);
+    } else if (Name == "Release") {
+      OS << formatv(" // refCountContext.decrementRefCount({0}, {1});\n",
+                    Params.at(0).getName(), AdapterHandleArg);
+    }
+    OS << " }\n";
+  }
+  OS << " return result;\n";
+  OS << "}\n";
+}
+/// Emit a validation wrapper for each `Function` record, in record order.
+void EmitOffloadValidation(RecordKeeper &Records, raw_ostream &OS) {
+  const auto Functions = Records.getAllDerivedDefinitions("Function");
+  for (auto *Fn : Functions)
+    EmitValidationFunc(FunctionRec{Fn}, OS);
+}
diff --git a/offload/tools/offload-tblgen/offload-tblgen.cpp b/offload/tools/offload-tblgen/offload-tblgen.cpp
new file mode 100644
index 00000000000000..02fd6fffa04e54
--- /dev/null
+++ b/offload/tools/offload-tblgen/offload-tblgen.cpp
@@ -0,0 +1,61 @@
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/TableGen/Main.h"
+#include "llvm/TableGen/Record.h"
+#include "Generators.hpp"
+namespace llvm {
+namespace offload {
+namespace tblgen {
+
+/// Output actions selectable on the command line.
+enum ActionType { PrintRecords, DumpJSON, GenAPI, GenValidation };
+
+namespace {
+// The default is made explicit to match the "(default)" in the help text.
+cl::opt<ActionType> Action(
+    cl::desc("Action to perform:"), cl::init(PrintRecords),
+    cl::values(
+        clEnumValN(PrintRecords, "print-records",
+                   "Print all records to stdout (default)"),
+        clEnumValN(DumpJSON, "dump-json",
+                   "Dump all records as machine-readable JSON"),
+        clEnumValN(GenAPI, "gen-api", "Generate Offload API header contents"),
+        clEnumValN(GenValidation, "gen-validation",
+                   "Generate Offload entry point validation functions")));
+} // namespace
+
+/// Callback passed to TableGenMain: run the selected action over the parsed
+/// records, writing to OS. Always returns false (no error reported).
+static bool OffloadTableGenMain(raw_ostream &OS, RecordKeeper &Records) {
+  // No `default:` label, so -Wswitch flags any ActionType left unhandled.
+  switch (Action) {
+  case PrintRecords:
+    OS << Records;
+    break;
+  case DumpJSON:
+    EmitJSON(Records, OS);
+    break;
+  case GenAPI:
+    EmitOffloadAPI(Records, OS);
+    break;
+  case GenValidation:
+    EmitOffloadValidation(Records, OS);
+    break;
+  }
+  return false;
+}
+
+/// Tool entry point: parse the command line, then hand off to TableGenMain.
+int OffloadTblgenMain(int argc, char **argv) {
+  InitLLVM y(argc, argv);
+  cl::ParseCommandLineOptions(argc, argv);
+  return TableGenMain(argv[0], &OffloadTableGenMain);
+}
+
+} // namespace tblgen
+} // namespace offload
+} // namespace llvm
+/// Process entry point; all option parsing and dispatch happen in
+/// OffloadTblgenMain.
+int main(int argc, char **argv) {
+  return llvm::offload::tblgen::OffloadTblgenMain(argc, argv);
+}
More information about the llvm-commits
mailing list