[Mlir-commits] [mlir] [mlir][spirv] Add mgpu* wrappers for Vulkan runtime, migrate some tests (PR #123114)

Andrea Faulds llvmlistbot at llvm.org
Thu Jan 16 05:57:53 PST 2025


https://github.com/andfau-amd updated https://github.com/llvm/llvm-project/pull/123114

>From ba17c0b2242aab6bae16c7db07432ef68b691acd Mon Sep 17 00:00:00 2001
From: Andrea Faulds <andrea.faulds at amd.com>
Date: Thu, 16 Jan 2025 14:57:32 +0100
Subject: [PATCH] [mlir][spirv] Add mgpu* wrappers for Vulkan runtime, migrate
 some tests

This commit adds new wrappers around the MLIR Vulkan runtime which
implement the mgpu* APIs (as generated by GPUToLLVMConversionPass),
adds an optional LLVM lowering to the Vulkan runner mlir-opt pipeline
based on GPUToLLVMConversionPass, and migrates several of the
mlir-vulkan-runner tests to use mlir-cpu-runner instead, together with
the new pipeline and wrappers.

This is a further incremental step towards eliminating
mlir-vulkan-runner and its associated pipeline, passes and wrappers
(#73457). This commit does not migrate all of the tests to the new
system, because changes to the mgpuLaunchKernel ABI will be necessary
to support the tests that use multi-dimensional memref arguments.
---
 .../mlir/Conversion/GPUCommon/GPUCommonPass.h |   3 +-
 mlir/include/mlir/Conversion/Passes.td        |   5 +
 .../GPUCommon/GPUToLLVMConversion.cpp         |  37 +++++-
 .../lib/Pass/TestVulkanRunnerPipeline.cpp     |  22 +++-
 mlir/test/mlir-vulkan-runner/addf.mlir        |   4 +-
 mlir/test/mlir-vulkan-runner/addf_if.mlir     |   4 +-
 .../mlir-vulkan-runner/addui_extended.mlir    |   8 +-
 .../mlir-vulkan-runner/smul_extended.mlir     |   8 +-
 mlir/test/mlir-vulkan-runner/time.mlir        |   4 +-
 .../mlir-vulkan-runner/umul_extended.mlir     |   8 +-
 .../vector-deinterleave.mlir                  |   4 +-
 .../mlir-vulkan-runner/vector-interleave.mlir |   4 +-
 .../mlir-vulkan-runner/vector-shuffle.mlir    |   4 +-
 .../vulkan-runtime-wrappers.cpp               | 115 ++++++++++++++++++
 14 files changed, 198 insertions(+), 32 deletions(-)

diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
index 094360e75ab617..cf0c96f0eba000 100644
--- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
+++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
@@ -64,7 +64,8 @@ struct FunctionCallBuilder {
 /// populate converter for gpu types.
 void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                          RewritePatternSet &patterns,
-                                         bool kernelBarePtrCallConv = false);
+                                         bool kernelBarePtrCallConv = false,
+                                         bool typeCheckKernelArgs = false);
 
 /// A function that maps a MemorySpace enum to a target-specific integer value.
 using MemorySpaceMapping = std::function<unsigned(gpu::AddressSpace)>;
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index afeed370ce3475..6e37ad3f7a017e 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -517,6 +517,11 @@ def GpuToLLVMConversionPass : Pass<"gpu-to-llvm", "ModuleOp"> {
            /*default=*/"false",
              "Use bare pointers to pass memref arguments to kernels. "
              "The kernel must use the same setting for this option."
+          >,
+    Option<"typeCheckKernelArgs", "type-check-kernel-args", "bool",
+           /*default=*/"false",
+             "Require all kernel arguments to be memrefs of rank 1 and with a "
+             "32-bit element size."
           >
   ];
 
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 83208e0c42da2f..81c3063173d032 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -427,9 +427,11 @@ class LegalizeLaunchFuncOpPattern
     : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
 public:
   LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
-                              bool kernelBarePtrCallConv)
+                              bool kernelBarePtrCallConv,
+                              bool typeCheckKernelArgs)
       : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
-        kernelBarePtrCallConv(kernelBarePtrCallConv) {}
+        kernelBarePtrCallConv(kernelBarePtrCallConv),
+        typeCheckKernelArgs(typeCheckKernelArgs) {}
 
 private:
   LogicalResult
@@ -437,6 +439,7 @@ class LegalizeLaunchFuncOpPattern
                   ConversionPatternRewriter &rewriter) const override;
 
   bool kernelBarePtrCallConv;
+  bool typeCheckKernelArgs;
 };
 
 /// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
@@ -563,8 +566,8 @@ void GpuToLLVMConversionPass::runOnOperation() {
   populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
   populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                     target);
-  populateGpuToLLVMConversionPatterns(converter, patterns,
-                                      kernelBarePtrCallConv);
+  populateGpuToLLVMConversionPatterns(
+      converter, patterns, kernelBarePtrCallConv, typeCheckKernelArgs);
 
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
@@ -966,6 +969,26 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
   // stream must be created to pass to subsequent operations.
   else if (launchOp.getAsyncToken())
     stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
+
+  if (typeCheckKernelArgs) {
+    // The current non-bare-pointer ABI is a bad fit for `mgpuLaunchKernel`,
+    // which takes an untyped list of arguments. The type check here prevents
+    // accidentally violating the assumption made in vulkan-runtime-wrappers.cpp
+    // and creating an unchecked runtime ABI mismatch.
+    // TODO: Change the ABI here to remove the need for this type check.
+    for (Value arg : launchOp.getKernelOperands()) {
+      if (auto t = dyn_cast<MemRefType>(arg.getType())) {
+        if (t.getRank() != 1 || t.getElementTypeBitWidth() != 32) {
+          return rewriter.notifyMatchFailure(
+              launchOp, "Operand to launch op is not a rank-1 memref with "
+                        "32-bit element type.");
+        }
+      } else {
+        return rewriter.notifyMatchFailure(
+            launchOp, "Operand to launch op is not a memref.");
+      }
+    }
+  }
   // Lower the kernel operands to match kernel parameters.
   // Note: If `useBarePtrCallConv` is set in the type converter's options,
   // the value of `kernelBarePtrCallConv` will be ignored.
@@ -1737,7 +1760,8 @@ LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
 
 void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns,
-                                               bool kernelBarePtrCallConv) {
+                                               bool kernelBarePtrCallConv,
+                                               bool typeCheckKernelArgs) {
   addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
   addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
   addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
@@ -1774,7 +1798,8 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
                ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
                ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
-  patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv);
+  patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
+                                            typeCheckKernelArgs);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp b/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp
index 789c4d76cee0d5..a3624eb31e26e5 100644
--- a/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp
+++ b/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp
@@ -11,9 +11,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
+#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -29,6 +33,9 @@ struct VulkanRunnerPipelineOptions
   Option<bool> spirvWebGPUPrepare{
       *this, "spirv-webgpu-prepare",
       llvm::cl::desc("Run MLIR transforms used when targetting WebGPU")};
+  Option<bool> toLlvm{*this, "to-llvm",
+                      llvm::cl::desc("Run MLIR transforms to lower host code "
+                                     "to LLVM, intended for mlir-cpu-runner")};
 };
 
 void buildTestVulkanRunnerPipeline(OpPassManager &passManager,
@@ -56,6 +63,19 @@ void buildTestVulkanRunnerPipeline(OpPassManager &passManager,
     spirvModulePM.addPass(spirv::createSPIRVWebGPUPreparePass());
 
   passManager.addPass(createGpuModuleToBinaryPass());
+
+  if (options.toLlvm) {
+    passManager.addPass(createFinalizeMemRefToLLVMConversionPass());
+    passManager.nest<func::FuncOp>().addPass(
+        LLVM::createRequestCWrappersPass());
+    // vulkan-runtime-wrappers.cpp uses the non-bare-pointer calling convention,
+    // and the type check is needed to prevent accidental ABI mismatches.
+    GpuToLLVMConversionPassOptions opt;
+    opt.hostBarePtrCallConv = false;
+    opt.kernelBarePtrCallConv = false;
+    opt.typeCheckKernelArgs = true;
+    passManager.addPass(createGpuToLLVMConversionPass(opt));
+  }
 }
 
 } // namespace
@@ -65,7 +85,7 @@ void registerTestVulkanRunnerPipeline() {
   PassPipelineRegistration<VulkanRunnerPipelineOptions>(
       "test-vulkan-runner-pipeline",
       "Runs a series of passes for lowering GPU-dialect MLIR to "
-      "SPIR-V-dialect MLIR intended for mlir-vulkan-runner.",
+      "SPIR-V-dialect MLIR intended for mlir-vulkan-runner or mlir-cpu-runner.",
       buildTestVulkanRunnerPipeline);
 }
 } // namespace mlir::test
diff --git a/mlir/test/mlir-vulkan-runner/addf.mlir b/mlir/test/mlir-vulkan-runner/addf.mlir
index d435f75a288052..71f87a8b0d5c82 100644
--- a/mlir/test/mlir-vulkan-runner/addf.mlir
+++ b/mlir/test/mlir-vulkan-runner/addf.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
 
 // CHECK: [3.3,  3.3,  3.3,  3.3,  3.3,  3.3,  3.3,  3.3]
 module attributes {
diff --git a/mlir/test/mlir-vulkan-runner/addf_if.mlir b/mlir/test/mlir-vulkan-runner/addf_if.mlir
index 8ae995c65e7e84..6fe51a83482dca 100644
--- a/mlir/test/mlir-vulkan-runner/addf_if.mlir
+++ b/mlir/test/mlir-vulkan-runner/addf_if.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
 
 // CHECK: [3.3,  3.3,  3.3,  3.3,  0,  0,  0,  0]
 module attributes {
diff --git a/mlir/test/mlir-vulkan-runner/addui_extended.mlir b/mlir/test/mlir-vulkan-runner/addui_extended.mlir
index b8db4514214591..0894bc301f2e3c 100644
--- a/mlir/test/mlir-vulkan-runner/addui_extended.mlir
+++ b/mlir/test/mlir-vulkan-runner/addui_extended.mlir
@@ -1,13 +1,13 @@
 // Make sure that addition with carry produces expected results
 // with and without expansion to primitive add/cmp ops for WebGPU.
 
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - \
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - \
 // RUN:     --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
 // RUN:     --entry-point-result=void | FileCheck %s
 
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline=spirv-webgpu-prepare \
-// RUN:   | mlir-vulkan-runner - \
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline="spirv-webgpu-prepare to-llvm" \
+// RUN:   | mlir-cpu-runner - \
 // RUN:     --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
 // RUN:     --entry-point-result=void | FileCheck %s
 
diff --git a/mlir/test/mlir-vulkan-runner/smul_extended.mlir b/mlir/test/mlir-vulkan-runner/smul_extended.mlir
index 334aec843e1977..0ef86f46562e86 100644
--- a/mlir/test/mlir-vulkan-runner/smul_extended.mlir
+++ b/mlir/test/mlir-vulkan-runner/smul_extended.mlir
@@ -1,13 +1,13 @@
 // Make sure that signed extended multiplication produces expected results
 // with and without expansion to primitive mul/add ops for WebGPU.
 
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - \
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - \
 // RUN:     --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
 // RUN:     --entry-point-result=void | FileCheck %s
 
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline=spirv-webgpu-prepare \
-// RUN:   | mlir-vulkan-runner - \
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline="spirv-webgpu-prepare to-llvm" \
+// RUN:   | mlir-cpu-runner - \
 // RUN:     --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
 // RUN:     --entry-point-result=void | FileCheck %s
 
diff --git a/mlir/test/mlir-vulkan-runner/time.mlir b/mlir/test/mlir-vulkan-runner/time.mlir
index 6a0bfef36793b6..f6284478742389 100644
--- a/mlir/test/mlir-vulkan-runner/time.mlir
+++ b/mlir/test/mlir-vulkan-runner/time.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
 
 // CHECK: Compute shader execution time
 // CHECK: Command buffer submit time
diff --git a/mlir/test/mlir-vulkan-runner/umul_extended.mlir b/mlir/test/mlir-vulkan-runner/umul_extended.mlir
index 803b8c3d336d33..5936c808435c19 100644
--- a/mlir/test/mlir-vulkan-runner/umul_extended.mlir
+++ b/mlir/test/mlir-vulkan-runner/umul_extended.mlir
@@ -1,13 +1,13 @@
 // Make sure that unsigned extended multiplication produces expected results
 // with and without expansion to primitive mul/add ops for WebGPU.
 
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - \
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - \
 // RUN:     --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
 // RUN:     --entry-point-result=void | FileCheck %s
 
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline=spirv-webgpu-prepare \
-// RUN:   | mlir-vulkan-runner - \
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline="spirv-webgpu-prepare to-llvm" \
+// RUN:   | mlir-cpu-runner - \
 // RUN:     --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
 // RUN:     --entry-point-result=void | FileCheck %s
 
diff --git a/mlir/test/mlir-vulkan-runner/vector-deinterleave.mlir b/mlir/test/mlir-vulkan-runner/vector-deinterleave.mlir
index 097f3905949d82..ebeb19cd6bcc5f 100644
--- a/mlir/test/mlir-vulkan-runner/vector-deinterleave.mlir
+++ b/mlir/test/mlir-vulkan-runner/vector-deinterleave.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - \
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - \
 // RUN:     --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
 // RUN:     --entry-point-result=void | FileCheck %s
 
diff --git a/mlir/test/mlir-vulkan-runner/vector-interleave.mlir b/mlir/test/mlir-vulkan-runner/vector-interleave.mlir
index 5dd4abbd1fb19e..9314baf9b39c7e 100644
--- a/mlir/test/mlir-vulkan-runner/vector-interleave.mlir
+++ b/mlir/test/mlir-vulkan-runner/vector-interleave.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - \
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - \
 // RUN:     --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
 // RUN:     --entry-point-result=void | FileCheck %s
 
diff --git a/mlir/test/mlir-vulkan-runner/vector-shuffle.mlir b/mlir/test/mlir-vulkan-runner/vector-shuffle.mlir
index be97b48b1812ea..cf3e2c569426b7 100644
--- a/mlir/test/mlir-vulkan-runner/vector-shuffle.mlir
+++ b/mlir/test/mlir-vulkan-runner/vector-shuffle.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - \
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - \
 // RUN:     --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils \
 // RUN:     --entry-point-result=void | FileCheck %s
 
diff --git a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
index f1ed571734459e..aa604eac4b6406 100644
--- a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
@@ -13,6 +13,8 @@
 #include <iostream>
 #include <mutex>
 #include <numeric>
+#include <string>
+#include <vector>
 
 #include "VulkanRuntime.h"
 
@@ -26,6 +28,37 @@
 
 namespace {
 
+class VulkanModule;
+
+// Handle returned from `mgpuModuleGetFunction`: a name plus its owning module.
+struct VulkanFunction {
+  VulkanModule *module;
+  std::string name;
+
+  VulkanFunction(VulkanModule *module, const char *name)
+      : module(module), name(name) {}
+};
+
+// Class to own a copy of the SPIR-V provided to `mgpuModuleLoad` and to manage
+// allocation of pointers returned from `mgpuModuleGetFunction`.
+class VulkanModule {
+public:
+  VulkanModule(const char *ptr, size_t size) : blob(ptr, ptr + size) {}
+  ~VulkanModule() = default;
+
+  VulkanFunction *getFunction(const char *name) {
+    return functions.emplace_back(std::make_unique<VulkanFunction>(this, name))
+        .get();
+  }
+
+  uint8_t *blobData() { return reinterpret_cast<uint8_t *>(blob.data()); }
+  uint32_t blobSize() { return static_cast<uint32_t>(blob.size()); }
+
+private:
+  std::vector<char> blob;
+  std::vector<std::unique_ptr<VulkanFunction>> functions;
+};
+
 class VulkanRuntimeManager {
 public:
   VulkanRuntimeManager() = default;
@@ -91,6 +124,88 @@ void bindMemRef(void *vkRuntimeManager, DescriptorSetIndex setIndex,
 }
 
 extern "C" {
+
+//===----------------------------------------------------------------------===//
+//
+// New wrappers, intended for mlir-cpu-runner. Calls to these are generated by
+// GPUToLLVMConversionPass.
+//
+//===----------------------------------------------------------------------===//
+
+VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuStreamCreate() {
+  return new VulkanRuntimeManager();
+}
+
+VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamDestroy(void *vkRuntimeManager) {
+  delete static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
+}
+
+VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamSynchronize(void *) {
+  // Currently a no-op as the other operations are synchronous.
+}
+
+VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleLoad(const void *data,
+                                                  size_t gpuBlobSize) {
+  return new VulkanModule(static_cast<const char *>(data), gpuBlobSize);
+}
+
+VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuModuleUnload(void *vkModule) {
+  delete static_cast<VulkanModule *>(vkModule);
+}
+
+VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleGetFunction(void *vkModule,
+                                                         const char *name) {
+  return static_cast<VulkanModule *>(vkModule)->getFunction(name);
+}
+
+VULKAN_WRAPPER_SYMBOL_EXPORT void
+mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ,
+                 size_t /*blockX*/, size_t /*blockY*/, size_t /*blockZ*/,
+                 size_t /*smem*/, void *vkRuntimeManager, void **params,
+                 void ** /*extra*/, size_t paramsCount) {
+  auto manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
+
+  // The non-bare-pointer memref ABI interacts badly with mgpuLaunchKernel's
+  // signature:
+  // - The memref descriptor struct gets split into several elements, each
+  //   passed as its own "param".
+  // - No metadata is provided as to the rank or element type/size of a memref.
+  //   Here we assume that all MemRefs have rank 1 and an element size of
+  //   4 bytes. This means each descriptor struct will have five members.
+  // TODO: Refactor the ABI/API of mgpuLaunchKernel to use a different ABI for
+  //       memrefs, so that other memref types can also be used. This will allow
+  //       migrating the remaining tests and removal of mlir-vulkan-runner.
+  const size_t paramsPerMemRef = 5;
+  if (paramsCount % paramsPerMemRef) {
+    abort();
+  }
+  const DescriptorSetIndex setIndex = 0;
+  BindingIndex bindIndex = 0;
+  for (size_t i = 0; i < paramsCount; i += paramsPerMemRef) {
+    auto memref = static_cast<MemRefDescriptor<uint32_t, 1> *>(params[i]);
+    bindMemRef<uint32_t, 1>(manager, setIndex, bindIndex, memref);
+    bindIndex++;
+  }
+
+  manager->setNumWorkGroups(NumWorkGroups{static_cast<uint32_t>(gridX),
+                                          static_cast<uint32_t>(gridY),
+                                          static_cast<uint32_t>(gridZ)});
+
+  auto function = static_cast<VulkanFunction *>(vkKernel);
+  manager->setShaderModule(function->module->blobData(),
+                           function->module->blobSize());
+  manager->setEntryPoint(function->name.c_str());
+
+  manager->runOnVulkan();
+}
+
+//===----------------------------------------------------------------------===//
+//
+// Old wrappers, intended for mlir-vulkan-runner. Calls to these are generated
+// by LaunchFuncToVulkanCallsPass.
+//
+//===----------------------------------------------------------------------===//
+
 /// Initializes `VulkanRuntimeManager` and returns a pointer to it.
 VULKAN_WRAPPER_SYMBOL_EXPORT void *initVulkan() {
   return new VulkanRuntimeManager();



More information about the Mlir-commits mailing list