[Mlir-commits] [mlir] [mlir][spirv] Add mgpu* wrappers for Vulkan runtime, migrate some tests (PR #123114)
Jakub Kuderski
llvmlistbot at llvm.org
Thu Jan 16 08:38:07 PST 2025
================
@@ -91,6 +124,91 @@ void bindMemRef(void *vkRuntimeManager, DescriptorSetIndex setIndex,
}
extern "C" {
+
+//===----------------------------------------------------------------------===//
+//
+// New wrappers, intended for mlir-cpu-runner. Calls to these are generated by
+// GPUToLLVMConversionPass.
+//
+//===----------------------------------------------------------------------===//
+
+/// Allocates a VulkanRuntimeManager and hands it back as an opaque "stream"
+/// handle. Ownership passes to the caller; release it with mgpuStreamDestroy.
+VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuStreamCreate() {
+  auto *runtimeManager = new VulkanRuntimeManager();
+  return runtimeManager;
+}
+
+/// Frees a VulkanRuntimeManager handle (expected to come from
+/// mgpuStreamCreate). Passing null is harmless because deleting a null pointer
+/// is a no-op.
+VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamDestroy(void *vkRuntimeManager) {
+  auto *runtimeManager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
+  delete runtimeManager;
+}
+
+/// Stream synchronization entry point; the stream handle is ignored.
+VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamSynchronize(void * /*stream*/) {
+  // Intentionally empty: the other wrapper operations are synchronous, so
+  // there is no outstanding work to wait on.
+}
+
+/// Wraps the `gpuBlobSize` bytes starting at `data` in a newly allocated
+/// VulkanModule and returns it as an opaque handle. The blob is presumably the
+/// serialized SPIR-V module (TODO: confirm). Free the handle with
+/// mgpuModuleUnload.
+VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleLoad(const void *data,
+                                                  size_t gpuBlobSize) {
+  const auto *blobBytes = static_cast<const uint8_t *>(data);
+  return new VulkanModule(blobBytes, gpuBlobSize);
+}
+
+/// Frees a VulkanModule handle (expected to come from mgpuModuleLoad).
+VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuModuleUnload(void *vkModule) {
+  auto *vulkanModule = static_cast<VulkanModule *>(vkModule);
+  delete vulkanModule;
+}
+
+/// Looks up the function `name` in the given VulkanModule handle and returns
+/// the result as an opaque pointer. Aborts the process if the module handle
+/// is null.
+VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleGetFunction(void *vkModule,
+                                                         const char *name) {
+  auto *vulkanModule = static_cast<VulkanModule *>(vkModule);
+  // A null handle cannot be recovered from in this C ABI; terminate.
+  if (vulkanModule == nullptr)
+    abort();
+  return vulkanModule->getFunction(name);
+}
+
+VULKAN_WRAPPER_SYMBOL_EXPORT void
+mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ,
+ size_t /*blockX*/, size_t /*blockY*/, size_t /*blockZ*/,
+ size_t /*smem*/, void *vkRuntimeManager, void **params,
+ void ** /*extra*/, size_t paramsCount) {
+ auto manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
+
+ // The non-bare-pointer memref ABI interacts badly with mgpuLaunchKernel's
+ // signature:
+ // - The memref descriptor struct gets split into several elements, each
+ // passed as their own "param".
+ // - No metadata is provided as to the rank or element type/size of a memref.
+ // Here we assume that all MemRefs have rank 1 and an element size of
+ // 4 bytes. This means each descriptor struct will have five members.
+ // TODO(https://github.com/llvm/llvm-project/issues/73457): Refactor the
+ // ABI/API of mgpuLaunchKernel to use a different ABI for memrefs, so
+ // that other memref types can also be used. This will allow migrating
+ // the remaining tests and removal of mlir-vulkan-runner.
+ const size_t paramsPerMemRef = 5;
+ if (paramsCount % paramsPerMemRef) {
+ abort();
----------------
kuhar wrote:
```suggestion
if (paramsCount % paramsPerMemRef != 0) {
abort();
```
Should we also print an error message before aborting?
https://github.com/llvm/llvm-project/pull/123114
More information about the Mlir-commits
mailing list