[Mlir-commits] [mlir] 1090a83 - [mlir][vulkan-runner] Update mlir-vulkan-runner execution driver.

Lei Zhang llvmlistbot at llvm.org
Tue Mar 10 12:58:51 PDT 2020


Author: Denis Khalikov
Date: 2020-03-10T15:58:31-04:00
New Revision: 1090a830692a863ccb091e3fad8cc1a287417493

URL: https://github.com/llvm/llvm-project/commit/1090a830692a863ccb091e3fad8cc1a287417493
DIFF: https://github.com/llvm/llvm-project/commit/1090a830692a863ccb091e3fad8cc1a287417493.diff

LOG: [mlir][vulkan-runner] Update mlir-vulkan-runner execution driver.

* Adds GpuLaunchFuncToVulkanLaunchFunc conversion pass.
* Moves a serialization of the `spirv::Module` from LaunchFuncToVulkanCalls pass to newly created pass.
* Updates LaunchFuncToVulkanCalls instrumentation pass, adds `initVulkan` and `deinitVulkan` runtime calls.
* Adds `bindResource` call to bind a specific resource by the given descriptor set and descriptor binding.
* Eliminates static construction and destruction of `VulkanRuntimeManager`.

Differential Revision: https://reviews.llvm.org/D75192

Added: 
    mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
    mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir

Modified: 
    mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h
    mlir/include/mlir/InitAllPasses.h
    mlir/lib/Conversion/GPUToVulkan/CMakeLists.txt
    mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
    mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
    mlir/test/mlir-vulkan-runner/addf.mlir
    mlir/tools/mlir-vulkan-runner/VulkanRuntime.h
    mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
    mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h b/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h
index af2c0629c49a..9a02860bfc1a 100644
--- a/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h
+++ b/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h
@@ -24,7 +24,10 @@ class ModuleOp;
 template <typename T> class OpPassBase;
 
 std::unique_ptr<OpPassBase<ModuleOp>>
-createConvertGpuLaunchFuncToVulkanCallsPass();
+createConvertVulkanLaunchFuncToVulkanCallsPass();
+
+std::unique_ptr<OpPassBase<mlir::ModuleOp>>
+createConvertGpuLaunchFuncToVulkanLaunchFuncPass();
 
 } // namespace mlir
 #endif // MLIR_CONVERSION_GPUTOVULKAN_CONVERTGPUTOVULKANPASS_H

diff  --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 2b21ccb613ac..4df27e1bc1cd 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -128,7 +128,8 @@ inline void registerAllPasses() {
   createLinalgToSPIRVPass();
 
   // Vulkan
-  createConvertGpuLaunchFuncToVulkanCallsPass();
+  createConvertGpuLaunchFuncToVulkanLaunchFuncPass();
+  createConvertVulkanLaunchFuncToVulkanCallsPass();
 }
 
 } // namespace mlir

diff  --git a/mlir/lib/Conversion/GPUToVulkan/CMakeLists.txt b/mlir/lib/Conversion/GPUToVulkan/CMakeLists.txt
index eeafe5f37b97..847c9c5031e9 100644
--- a/mlir/lib/Conversion/GPUToVulkan/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToVulkan/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_conversion_library(MLIRGPUtoVulkanTransforms
   ConvertLaunchFuncToVulkanCalls.cpp
+  ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
   )
 
 target_link_libraries(MLIRGPUtoVulkanTransforms

diff  --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
new file mode 100644
index 000000000000..fcfae4563778
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
@@ -0,0 +1,173 @@
+//===- ConvertGPULaunchFuncToVulkanLaunchFunc.cpp - MLIR conversion pass --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert gpu launch function into a vulkan
+// launch function. Creates a SPIR-V binary shader from the `spirv::ModuleOp`
+// using `spirv::serialize` function, attaches binary data and entry point name
+// as attributes to the vulkan launch call op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Serialization.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
+static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
+static constexpr const char *kVulkanLaunch = "vulkanLaunch";
+
+namespace {
+
+// A pass to convert gpu launch op to vulkan launch call op, by creating a
+// SPIR-V binary shader from `spirv::ModuleOp` using `spirv::serialize`
+// function and attaching binary data and entry point name as attributes to
+// the created vulkan launch call op.
+class ConvertGpuLaunchFuncToVulkanLaunchFunc
+    : public ModulePass<ConvertGpuLaunchFuncToVulkanLaunchFunc> {
+public:
+  void runOnModule() override;
+
+private:
+  /// Creates a SPIR-V binary shader from the given `module` using
+  /// `spirv::serialize` function.
+  LogicalResult createBinaryShader(ModuleOp module,
+                                   std::vector<char> &binaryShader);
+
+  /// Converts the given `launchOp` to vulkan launch call.
+  void convertGpuLaunchFunc(gpu::LaunchFuncOp launchOp);
+
+  /// Checks whether the given type is supported by Vulkan runtime.
+  bool isSupportedType(Type type) {
+    // TODO(denis0x0D): Handle other types; only ranked 1-D memrefs pass now.
+    if (auto memRefType = type.dyn_cast_or_null<MemRefType>())
+      return memRefType.hasRank() && memRefType.getRank() == 1;
+    return false;
+  }
+
+  /// Declares the vulkan launch function. Returns an error if the type of any
+  /// operand is unsupported by Vulkan runtime.
+  LogicalResult declareVulkanLaunchFunc(Location loc,
+                                        gpu::LaunchFuncOp launchOp);
+
+};
+
+} // anonymous namespace
+
+void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnModule() {
+  bool done = false;
+  getModule().walk([this, &done](gpu::LaunchFuncOp op) {
+    if (done) {
+      op.emitError("should only contain one 'gpu::LaunchFuncOp' op");
+      return signalPassFailure();
+    }
+    done = true;
+    convertGpuLaunchFunc(op);
+  });
+
+  // Erase `gpu::GPUModuleOp` and `spirv::Module` operations.
+  for (auto gpuModule :
+       llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>()))
+    gpuModule.erase();
+
+  for (auto spirvModule :
+       llvm::make_early_inc_range(getModule().getOps<spirv::ModuleOp>()))
+    spirvModule.erase();
+}
+
+LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc(
+    Location loc, gpu::LaunchFuncOp launchOp) {
+  OpBuilder builder(getModule().getBody()->getTerminator());
+  // TODO: Workgroup size is written into the kernel. So to properly model the
+  // vulkan launch, we cannot have the local workgroup size configuration here.
+  SmallVector<Type, 8> vulkanLaunchTypes{launchOp.getOperandTypes()};
+
+  // Check that all operands have supported types except those for the launch
+  // configuration (the first 6 operands).
+  for (auto type : llvm::drop_begin(vulkanLaunchTypes, 6)) {
+    if (!isSupportedType(type))
+      return launchOp.emitError() << type << " is unsupported to run on Vulkan";
+  }
+
+  // Declare the vulkan launch function: launch operand types in, no results.
+  builder.create<FuncOp>(
+      loc, kVulkanLaunch,
+      FunctionType::get(vulkanLaunchTypes, ArrayRef<Type>{}, loc->getContext()),
+      ArrayRef<NamedAttribute>{});
+
+  return success();
+}
+
+LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::createBinaryShader(
+    ModuleOp module, std::vector<char> &binaryShader) {
+  bool done = false;
+  SmallVector<uint32_t, 0> binary;
+  for (auto spirvModule : module.getOps<spirv::ModuleOp>()) {
+    if (done)
+      return spirvModule.emitError("should only contain one 'spv.module' op");
+    done = true;
+
+    if (failed(spirv::serialize(spirvModule, binary)))
+      return failure();
+  }
+  binaryShader.resize(binary.size() * sizeof(uint32_t));
+  std::memcpy(binaryShader.data(), reinterpret_cast<char *>(binary.data()),
+              binaryShader.size());
+  return success();
+}
+
+void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
+    gpu::LaunchFuncOp launchOp) {
+  ModuleOp module = getModule();
+  OpBuilder builder(launchOp);
+  Location loc = launchOp.getLoc();
+
+  // Serialize `spirv::Module` into binary form.
+  std::vector<char> binary;
+  if (failed(createBinaryShader(module, binary)))
+    return signalPassFailure();
+
+  // Declare vulkan launch function.
+  if (failed(declareVulkanLaunchFunc(loc, launchOp)))
+    return signalPassFailure();
+
+  // Create vulkan launch call op.
+  auto vulkanLaunchCallOp = builder.create<CallOp>(
+      loc, ArrayRef<Type>{}, builder.getSymbolRefAttr(kVulkanLaunch),
+      launchOp.getOperands());
+
+  // Set SPIR-V binary shader data as an attribute.
+  vulkanLaunchCallOp.setAttr(
+      kSPIRVBlobAttrName,
+      StringAttr::get({binary.data(), binary.size()}, loc->getContext()));
+
+  // Set entry point name as an attribute.
+  vulkanLaunchCallOp.setAttr(
+      kSPIRVEntryPointAttrName,
+      StringAttr::get(launchOp.kernel(), loc->getContext()));
+
+  launchOp.erase();
+}
+
+std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
+mlir::createConvertGpuLaunchFuncToVulkanLaunchFuncPass() {
+  return std::make_unique<ConvertGpuLaunchFuncToVulkanLaunchFunc>();
+}
+
+static PassRegistration<ConvertGpuLaunchFuncToVulkanLaunchFunc>
+    pass("convert-gpu-launch-to-vulkan-launch",
+         "Convert gpu.launch_func to vulkanLaunch external call");

diff  --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index 03cf1c7229af..b1bdc3036076 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements a pass to convert gpu.launch_func op into a sequence of
+// This file implements a pass to convert vulkan launch call into a sequence of
 // Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we
 // don't expose separate external functions in IR for each of them, instead we
 // expose a few external functions to wrapper libraries which manages Vulkan
@@ -15,40 +15,44 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
-#include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/SPIRV/SPIRVOps.h"
-#include "mlir/Dialect/SPIRV/Serialization.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Module.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/Pass/Pass.h"
 
 #include "llvm/ADT/SmallString.h"
 
 using namespace mlir;
 
+static constexpr const char *kBindResource = "bindResource";
+static constexpr const char *kDeinitVulkan = "deinitVulkan";
+static constexpr const char *kRunOnVulkan = "runOnVulkan";
+static constexpr const char *kInitVulkan = "initVulkan";
 static constexpr const char *kSetBinaryShader = "setBinaryShader";
 static constexpr const char *kSetEntryPoint = "setEntryPoint";
 static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
-static constexpr const char *kRunOnVulkan = "runOnVulkan";
 static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
+static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
+static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
+static constexpr const char *kVulkanLaunch = "vulkanLaunch";
 
 namespace {
 
-/// A pass to convert gpu.launch_func operation into a sequence of Vulkan
-/// runtime calls.
+/// A pass to convert vulkan launch func into a sequence of Vulkan
+/// runtime calls in the following order:
 ///
+/// * initVulkan           -- initializes vulkan runtime
+/// * bindResource         -- binds resource
 /// * setBinaryShader      -- sets the binary shader data
 /// * setEntryPoint        -- sets the entry point name
 /// * setNumWorkGroups     -- sets the number of a local workgroups
 /// * runOnVulkan          -- runs vulkan runtime
+/// * deinitVulkan         -- deinitializes vulkan runtime
 ///
-class GpuLaunchFuncToVulkanCalssPass
-    : public ModulePass<GpuLaunchFuncToVulkanCalssPass> {
+class VulkanLaunchFuncToVulkanCallsPass
+    : public ModulePass<VulkanLaunchFuncToVulkanCallsPass> {
 private:
   LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
 
@@ -58,72 +62,145 @@ class GpuLaunchFuncToVulkanCalssPass
 
   void initializeCachedTypes() {
     llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
+    llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
     llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
+    llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
   }
 
+  LLVM::LLVMType getFloatType() { return llvmFloatType; }
   LLVM::LLVMType getVoidType() { return llvmVoidType; }
   LLVM::LLVMType getPointerType() { return llvmPointerType; }
   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
-
-  /// Creates a SPIR-V binary shader from the given `module` using
-  /// `spirv::serialize` function.
-  LogicalResult createBinaryShader(ModuleOp module,
-                                   std::vector<char> &binaryShader);
+  LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
 
   /// Creates a LLVM global for the given `name`.
   Value createEntryPointNameConstant(StringRef name, Location loc,
                                      OpBuilder &builder);
 
-  /// Creates a LLVM constant for each dimension of local workgroup and
-  /// populates the given `numWorkGroups`.
-  LogicalResult createNumWorkGroups(Location loc, OpBuilder &builder,
-                                    mlir::gpu::LaunchFuncOp launchOp,
-                                    SmallVectorImpl<Value> &numWorkGroups);
-
   /// Declares all needed runtime functions.
   void declareVulkanFunctions(Location loc);
 
-  /// Translates the given `launcOp` op to the sequence of Vulkan runtime calls
-  void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp);
+  /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
+  bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
+    return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
+            callOp.getNumOperands() >= 6);
+  }
+
+  /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
+  /// runtime calls.
+  void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
+
+  /// Creates call to `bindResource` for each resource operand.
+  void createBindResourceCalls(LLVM::CallOp vulkanLaunchCallOp,
+                               Value vulkanRuntiem);
 
 public:
   void runOnModule() override;
 
 private:
   LLVM::LLVMDialect *llvmDialect;
+  LLVM::LLVMType llvmFloatType;
   LLVM::LLVMType llvmVoidType;
   LLVM::LLVMType llvmPointerType;
   LLVM::LLVMType llvmInt32Type;
+  LLVM::LLVMType llvmInt64Type;
 };
 
-} // anonymous namespace
+/// Represents operand adaptor for vulkan launch call operation, to simplify an
+/// access to the lowered memref.
+// TODO: We should use 'emit-c-wrappers' option to lower memref type:
+// https://mlir.llvm.org/docs/ConversionToLLVMDialect/#c-compatible-wrapper-emission.
+struct VulkanLaunchOpOperandAdaptor {
+  VulkanLaunchOpOperandAdaptor(ArrayRef<Value> values) { operands = values; }
+  VulkanLaunchOpOperandAdaptor(const VulkanLaunchOpOperandAdaptor &) = delete;
+  VulkanLaunchOpOperandAdaptor
+  operator=(const VulkanLaunchOpOperandAdaptor &) = delete;
+
+  /// Returns a tuple with a pointer to the memory and the size for the index-th
+  /// resource.
+  std::tuple<Value, Value> getResourceDescriptor1D(uint32_t index) {
+    assert(index < getResourceCount1D());
+    // 1D memref calling convention according to "ConversionToLLVMDialect.md":
+    // 0. Allocated pointer.
+    // 1. Aligned pointer.
+    // 2. Offset.
+    // 3. Size in dim 0.
+    // 4. Stride in dim 0.
+    return {operands[numConfigOps + index * loweredMemRefNumOps1D],
+            operands[numConfigOps + index * loweredMemRefNumOps1D + 3]};
+  }
 
-void GpuLaunchFuncToVulkanCalssPass::runOnModule() {
-  initializeCachedTypes();
+  /// Returns the number of resources assuming all operands lowered from
+  /// 1D memref.
+  uint32_t getResourceCount1D() {
+    return (operands.size() - numConfigOps) / loweredMemRefNumOps1D;
+  }
 
-  getModule().walk(
-      [this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); });
+private:
+  /// The number of operands of lowered 1D memref.
+  static constexpr const uint32_t loweredMemRefNumOps1D = 5;
+  /// The number of the first config operands.
+  static constexpr const uint32_t numConfigOps = 6;
+  ArrayRef<Value> operands;
+};
 
-  // Erase `gpu::GPUModuleOp` and `spirv::Module` operations.
-  for (auto gpuModule :
-       llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>()))
-    gpuModule.erase();
+} // anonymous namespace
 
-  for (auto spirvModule :
-       llvm::make_early_inc_range(getModule().getOps<spirv::ModuleOp>()))
-    spirvModule.erase();
+void VulkanLaunchFuncToVulkanCallsPass::runOnModule() {
+  initializeCachedTypes();
+  getModule().walk([this](LLVM::CallOp op) {
+    if (isVulkanLaunchCallOp(op))
+      translateVulkanLaunchCall(op);
+  });
 }
 
-void GpuLaunchFuncToVulkanCalssPass::declareVulkanFunctions(Location loc) {
+void VulkanLaunchFuncToVulkanCallsPass::createBindResourceCalls(
+    LLVM::CallOp vulkanLaunchCallOp, Value vulkanRuntime) {
+  if (vulkanLaunchCallOp.getNumOperands() == 6)
+    return;
+  OpBuilder builder(vulkanLaunchCallOp);
+  Location loc = vulkanLaunchCallOp.getLoc();
+
+  // Create LLVM constant for the descriptor set index.
+  // Bind all resources to the `0` descriptor set, the same way as `GPUToSPIRV`
+  // pass does.
+  Value descriptorSet = builder.create<LLVM::ConstantOp>(
+      loc, getInt32Type(), builder.getI32IntegerAttr(0));
+
+  auto operands = SmallVector<Value, 32>{vulkanLaunchCallOp.getOperands()};
+  VulkanLaunchOpOperandAdaptor vkLaunchOperandAdaptor(operands);
+
+  for (auto resourceIdx :
+       llvm::seq<uint32_t>(0, vkLaunchOperandAdaptor.getResourceCount1D())) {
+    // Create LLVM constant for the descriptor binding index.
+    Value descriptorBinding = builder.create<LLVM::ConstantOp>(
+        loc, getInt32Type(), builder.getI32IntegerAttr(resourceIdx));
+    // Get a pointer to the memory and size of that memory.
+    auto resourceDescriptor =
+        vkLaunchOperandAdaptor.getResourceDescriptor1D(resourceIdx);
+    // Create call to `bindResource`.
+    builder.create<LLVM::CallOp>(
+        loc, ArrayRef<Type>{getVoidType()},
+        builder.getSymbolRefAttr(kBindResource),
+        ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding,
+                        // Pointer to the memory.
+                        std::get<0>(resourceDescriptor),
+                        // Size of the memory.
+                        std::get<1>(resourceDescriptor)});
+  }
+}
+
+void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
   ModuleOp module = getModule();
   OpBuilder builder(module.getBody()->getTerminator());
 
   if (!module.lookupSymbol(kSetEntryPoint)) {
     builder.create<LLVM::LLVMFuncOp>(
         loc, kSetEntryPoint,
-        LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
+        LLVM::LLVMType::getFunctionTy(getVoidType(),
+                                      {getPointerType(), getPointerType()},
                                       /*isVarArg=*/false));
   }
 
@@ -131,27 +208,52 @@ void GpuLaunchFuncToVulkanCalssPass::declareVulkanFunctions(Location loc) {
     builder.create<LLVM::LLVMFuncOp>(
         loc, kSetNumWorkGroups,
         LLVM::LLVMType::getFunctionTy(
-            getVoidType(), {getInt32Type(), getInt32Type(), getInt32Type()},
+            getVoidType(),
+            {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()},
             /*isVarArg=*/false));
   }
 
   if (!module.lookupSymbol(kSetBinaryShader)) {
     builder.create<LLVM::LLVMFuncOp>(
         loc, kSetBinaryShader,
-        LLVM::LLVMType::getFunctionTy(getVoidType(),
-                                      {getPointerType(), getInt32Type()},
-                                      /*isVarArg=*/false));
+        LLVM::LLVMType::getFunctionTy(
+            getVoidType(), {getPointerType(), getPointerType(), getInt32Type()},
+            /*isVarArg=*/false));
   }
 
   if (!module.lookupSymbol(kRunOnVulkan)) {
     builder.create<LLVM::LLVMFuncOp>(
         loc, kRunOnVulkan,
-        LLVM::LLVMType::getFunctionTy(getVoidType(), {},
+        LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
+                                      /*isVarArg=*/false));
+  }
+
+  if (!module.lookupSymbol(kBindResource)) {
+    builder.create<LLVM::LLVMFuncOp>(
+        loc, kBindResource,
+        LLVM::LLVMType::getFunctionTy(
+            getVoidType(),
+            {getPointerType(), getInt32Type(), getInt32Type(),
+             getFloatType().getPointerTo(), getInt64Type()},
+            /*isVarArg=*/false));
+  }
+
+  if (!module.lookupSymbol(kInitVulkan)) {
+    builder.create<LLVM::LLVMFuncOp>(
+        loc, kInitVulkan,
+        LLVM::LLVMType::getFunctionTy(getPointerType(), {},
+                                      /*isVarArg=*/false));
+  }
+
+  if (!module.lookupSymbol(kDeinitVulkan)) {
+    builder.create<LLVM::LLVMFuncOp>(
+        loc, kDeinitVulkan,
+        LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
                                       /*isVarArg=*/false));
   }
 }
 
-Value GpuLaunchFuncToVulkanCalssPass::createEntryPointNameConstant(
+Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
     StringRef name, Location loc, OpBuilder &builder) {
   SmallString<16> shaderName(name.begin(), name.end());
   // Append `\0` to follow C style string given that LLVM::createGlobalString()
@@ -164,107 +266,95 @@ Value GpuLaunchFuncToVulkanCalssPass::createEntryPointNameConstant(
                                   getLLVMDialect());
 }
 
-LogicalResult GpuLaunchFuncToVulkanCalssPass::createBinaryShader(
-    ModuleOp module, std::vector<char> &binaryShader) {
-  bool done = false;
-  SmallVector<uint32_t, 0> binary;
-  for (auto spirvModule : module.getOps<spirv::ModuleOp>()) {
-    if (done)
-      return spirvModule.emitError("should only contain one 'spv.module' op");
-    done = true;
-
-    if (failed(spirv::serialize(spirvModule, binary)))
-      return failure();
+void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
+    LLVM::CallOp vulkanLaunchCallOp) {
+  OpBuilder builder(vulkanLaunchCallOp);
+  Location loc = vulkanLaunchCallOp.getLoc();
+
+  // Check that `kSPIRVBlobAttrName` and `kSPIRVEntryPointAttrName` attributes
+  // are present on the given vulkan launch call.
+  auto spirvBlobAttr =
+      vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
+  if (!spirvBlobAttr) {
+    vulkanLaunchCallOp.emitError()
+        << "missing " << kSPIRVBlobAttrName << " attribute";
+    return signalPassFailure();
   }
 
-  binaryShader.resize(binary.size() * sizeof(uint32_t));
-  std::memcpy(binaryShader.data(), reinterpret_cast<char *>(binary.data()),
-              binaryShader.size());
-  return success();
-}
-
-LogicalResult GpuLaunchFuncToVulkanCalssPass::createNumWorkGroups(
-    Location loc, OpBuilder &builder, mlir::gpu::LaunchFuncOp launchOp,
-    SmallVectorImpl<Value> &numWorkGroups) {
-  for (auto index : llvm::seq(0, 3)) {
-    auto numWorkGroupDimConstant = dyn_cast_or_null<ConstantOp>(
-        launchOp.getOperand(index).getDefiningOp());
-
-    if (!numWorkGroupDimConstant)
-      return failure();
-
-    auto numWorkGroupDimValue =
-        numWorkGroupDimConstant.getValue().cast<IntegerAttr>().getInt();
-    numWorkGroups.push_back(builder.create<LLVM::ConstantOp>(
-        loc, getInt32Type(), builder.getI32IntegerAttr(numWorkGroupDimValue)));
+  auto entryPointNameAttr =
+      vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
+  if (!entryPointNameAttr) {
+    vulkanLaunchCallOp.emitError()
+        << "missing " << kSPIRVEntryPointAttrName << " attribute";
+    return signalPassFailure();
   }
 
-  return success();
-}
-
-void GpuLaunchFuncToVulkanCalssPass::translateGpuLaunchCalls(
-    mlir::gpu::LaunchFuncOp launchOp) {
-  ModuleOp module = getModule();
-  OpBuilder builder(launchOp);
-  Location loc = launchOp.getLoc();
-
-  // Serialize `spirv::Module` into binary form.
-  std::vector<char> binary;
-  if (failed(
-          GpuLaunchFuncToVulkanCalssPass::createBinaryShader(module, binary)))
-    return signalPassFailure();
+  // Create call to `initVulkan`.
+  auto initVulkanCall = builder.create<LLVM::CallOp>(
+      loc, ArrayRef<Type>{getPointerType()},
+      builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{});
+  // The result of `initVulkan` function is a pointer to Vulkan runtime, we
+  // need to pass that pointer to each Vulkan runtime call.
+  auto vulkanRuntime = initVulkanCall.getResult(0);
 
   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
   // that data to runtime call.
   Value ptrToSPIRVBinary = LLVM::createGlobalString(
-      loc, builder, kSPIRVBinary, StringRef(binary.data(), binary.size()),
+      loc, builder, kSPIRVBinary, spirvBlobAttr.getValue(),
       LLVM::Linkage::Internal, getLLVMDialect());
+
   // Create LLVM constant for the size of SPIR-V binary shader.
   Value binarySize = builder.create<LLVM::ConstantOp>(
-      loc, getInt32Type(), builder.getI32IntegerAttr(binary.size()));
+      loc, getInt32Type(),
+      builder.getI32IntegerAttr(spirvBlobAttr.getValue().size()));
+
+  // Create call to `bindResource` for each resource operand.
+  createBindResourceCalls(vulkanLaunchCallOp, vulkanRuntime);
+
   // Create call to `setBinaryShader` runtime function with the given pointer to
   // SPIR-V binary and binary size.
-  builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
-                               builder.getSymbolRefAttr(kSetBinaryShader),
-                               ArrayRef<Value>{ptrToSPIRVBinary, binarySize});
-
+  builder.create<LLVM::CallOp>(
+      loc, ArrayRef<Type>{getVoidType()},
+      builder.getSymbolRefAttr(kSetBinaryShader),
+      ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize});
   // Create LLVM global with entry point name.
   Value entryPointName =
-      createEntryPointNameConstant(launchOp.kernel(), loc, builder);
+      createEntryPointNameConstant(entryPointNameAttr.getValue(), loc, builder);
   // Create call to `setEntryPoint` runtime function with the given pointer to
   // entry point name.
   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
                                builder.getSymbolRefAttr(kSetEntryPoint),
-                               ArrayRef<Value>{entryPointName});
+                               ArrayRef<Value>{vulkanRuntime, entryPointName});
 
   // Create number of local workgroup for each dimension.
-  SmallVector<Value, 3> numWorkGroups;
-  if (failed(createNumWorkGroups(loc, builder, launchOp, numWorkGroups)))
-    return signalPassFailure();
-
-  // Create call `setNumWorkGroups` runtime function with the given numbers of
-  // local workgroup.
   builder.create<LLVM::CallOp>(
       loc, ArrayRef<Type>{getVoidType()},
       builder.getSymbolRefAttr(kSetNumWorkGroups),
-      ArrayRef<Value>{numWorkGroups[0], numWorkGroups[1], numWorkGroups[2]});
+      ArrayRef<Value>{vulkanRuntime, vulkanLaunchCallOp.getOperand(0),
+                      vulkanLaunchCallOp.getOperand(1),
+                      vulkanLaunchCallOp.getOperand(2)});
 
   // Create call to `runOnVulkan` runtime function.
   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
                                builder.getSymbolRefAttr(kRunOnVulkan),
-                               ArrayRef<Value>{});
+                               ArrayRef<Value>{vulkanRuntime});
+
+  // Create call to 'deinitVulkan' runtime function.
+  builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
+                               builder.getSymbolRefAttr(kDeinitVulkan),
+                               ArrayRef<Value>{vulkanRuntime});
 
   // Declare runtime functions.
   declareVulkanFunctions(loc);
 
-  launchOp.erase();
+  vulkanLaunchCallOp.erase();
 }
 
 std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
-mlir::createConvertGpuLaunchFuncToVulkanCallsPass() {
-  return std::make_unique<GpuLaunchFuncToVulkanCalssPass>();
+mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
+  return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
 }
 
-static PassRegistration<GpuLaunchFuncToVulkanCalssPass>
+static PassRegistration<VulkanLaunchFuncToVulkanCallsPass>
     pass("launch-func-to-vulkan",
-         "Convert gpu.launch_func op to Vulkan runtime calls");
+         "Convert vulkanLaunch external call to Vulkan runtime external calls");

diff  --git a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir b/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
index 580c13364a23..060e2b3c93db 100644
--- a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
+++ b/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
@@ -2,44 +2,47 @@
 
 // CHECK: llvm.mlir.global internal constant @kernel_spv_entry_point_name
 // CHECK: llvm.mlir.global internal constant @SPIRV_BIN
+// CHECK: %[[Vulkan_Runtime_ptr:.*]] = llvm.call @initVulkan() : () -> !llvm<"i8*">
 // CHECK: %[[addressof_SPIRV_BIN:.*]] = llvm.mlir.addressof @SPIRV_BIN
 // CHECK: %[[SPIRV_BIN_ptr:.*]] = llvm.getelementptr %[[addressof_SPIRV_BIN]]
 // CHECK: %[[SPIRV_BIN_size:.*]] = llvm.mlir.constant
-// CHECK: llvm.call @setBinaryShader(%[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm<"i8*">, !llvm.i32) -> !llvm.void
+// CHECK: llvm.call @bindResource(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"i8*">, !llvm.i32, !llvm.i32, !llvm<"float*">, !llvm.i64) -> !llvm.void
+// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32) -> !llvm.void
 // CHECK: %[[addressof_entry_point:.*]] = llvm.mlir.addressof @kernel_spv_entry_point_name
 // CHECK: %[[entry_point_ptr:.*]] = llvm.getelementptr %[[addressof_entry_point]]
-// CHECK: llvm.call @setEntryPoint(%[[entry_point_ptr]]) : (!llvm<"i8*">) -> !llvm.void
-// CHECK: %[[Workgroup_X:.*]] = llvm.mlir.constant
-// CHECK: %[[Workgroup_Y:.*]] = llvm.mlir.constant
-// CHECK: %[[Workgroup_Z:.*]] = llvm.mlir.constant
-// CHECK: llvm.call @setNumWorkGroups(%[[Workgroup_X]], %[[Workgroup_Y]], %[[Workgroup_Z]]) : (!llvm.i32, !llvm.i32, !llvm.i32) -> !llvm.void
-// CHECK: llvm.call @runOnVulkan() : () -> !llvm.void
+// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm<"i8*">, !llvm<"i8*">) -> !llvm.void
+// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"i8*">, !llvm.i64, !llvm.i64, !llvm.i64) -> !llvm.void
+// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm<"i8*">) -> !llvm.void
+// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm<"i8*">) -> !llvm.void
 
 module attributes {gpu.container_module} {
-  spv.module "Logical" "GLSL450" {
-    spv.globalVariable @kernel_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer>
-    spv.globalVariable @kernel_arg_1 bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
-    spv.func @kernel() "None" attributes {workgroup_attributions = 0 : i64} {
-      %0 = spv._address_of @kernel_arg_1 : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
-      %1 = spv._address_of @kernel_arg_0 : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer>
-      %2 = spv.constant 0 : i32
-      %3 = spv.AccessChain %1[%2] : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer>
-      %4 = spv.Load "StorageBuffer" %3 : f32
-      spv.Return
-    }
-    spv.EntryPoint "GLCompute" @kernel
-    spv.ExecutionMode @kernel "LocalSize", 1, 1, 1
-  } attributes {capabilities = ["Shader"], extensions = ["SPV_KHR_storage_buffer_storage_class"]}
-  gpu.module @kernels {
-    gpu.func @kernel(%arg0: f32, %arg1: memref<12xf32>) kernel {
-      gpu.return
-    }
-  }
-  func @foo() {
-    %0 = "op"() : () -> f32
-    %1 = "op"() : () -> memref<12xf32>
-    %c1 = constant 1 : index
-    "gpu.launch_func"(%c1, %c1, %c1, %c1, %c1, %c1, %0, %1) {kernel = "kernel", kernel_module = @kernels} : (index, index, index, index, index, index, f32, memref<12xf32>) -> ()
-    return
+  llvm.func @malloc(!llvm.i64) -> !llvm<"i8*">
+  llvm.func @foo() {
+    %0 = llvm.mlir.constant(12 : index) : !llvm.i64
+    %1 = llvm.mlir.null : !llvm<"float*">
+    %2 = llvm.mlir.constant(1 : index) : !llvm.i64
+    %3 = llvm.getelementptr %1[%2] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+    %4 = llvm.ptrtoint %3 : !llvm<"float*"> to !llvm.i64
+    %5 = llvm.mul %0, %4 : !llvm.i64
+    %6 = llvm.call @malloc(%5) : (!llvm.i64) -> !llvm<"i8*">
+    %7 = llvm.bitcast %6 : !llvm<"i8*"> to !llvm<"float*">
+    %8 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+    %9 = llvm.insertvalue %7, %8[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+    %10 = llvm.insertvalue %7, %9[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+    %11 = llvm.mlir.constant(0 : index) : !llvm.i64
+    %12 = llvm.insertvalue %11, %10[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+    %13 = llvm.mlir.constant(1 : index) : !llvm.i64
+    %14 = llvm.insertvalue %0, %12[3, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+    %15 = llvm.insertvalue %13, %14[4, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+    %16 = llvm.mlir.constant(1 : index) : !llvm.i64
+    %17 = llvm.extractvalue %15[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+    %18 = llvm.extractvalue %15[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+    %19 = llvm.extractvalue %15[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+    %20 = llvm.extractvalue %15[3, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+    %21 = llvm.extractvalue %15[4, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+    llvm.call @vulkanLaunch(%16, %16, %16, %16, %16, %16, %17, %18, %19, %20, %21) {spirv_blob = "\03\02#\07\00", spirv_entry_point = "kernel"}
+    : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64) -> ()
+    llvm.return
   }
+  llvm.func @vulkanLaunch(!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64)
 }

diff  --git a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir
new file mode 100644
index 000000000000..aa3daa04734e
--- /dev/null
+++ b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -convert-gpu-launch-to-vulkan-launch | FileCheck %s
+
+// CHECK: %[[resource:.*]] = alloc() : memref<12xf32>
+// CHECK: %[[index:.*]] = constant 1 : index
+// CHECK: call @vulkanLaunch(%[[index]], %[[index]], %[[index]], %[[index]], %[[index]], %[[index]], %[[resource]]) {spirv_blob = "{{.*}}", spirv_entry_point = "kernel"}
+
+module attributes {gpu.container_module} {
+  spv.module "Logical" "GLSL450" {
+    spv.globalVariable @kernel_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
+    spv.func @kernel() "None" attributes {workgroup_attributions = 0 : i64} {
+      %0 = spv._address_of @kernel_arg_0 : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
+      %2 = spv.constant 0 : i32
+      %3 = spv._address_of @kernel_arg_0 : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
+      %4 = spv.AccessChain %0[%2, %2] : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
+      %5 = spv.Load "StorageBuffer" %4 : f32
+      spv.Return
+    }
+    spv.EntryPoint "GLCompute" @kernel
+    spv.ExecutionMode @kernel "LocalSize", 1, 1, 1
+  } attributes {capabilities = ["Shader"], extensions = ["SPV_KHR_storage_buffer_storage_class"]}
+  gpu.module @kernels {
+    gpu.func @kernel(%arg0: memref<12xf32>) kernel {
+      gpu.return
+    }
+  }
+  func @foo() {
+    %0 = alloc() : memref<12xf32>
+    %c1 = constant 1 : index
+    "gpu.launch_func"(%c1, %c1, %c1, %c1, %c1, %c1, %0) {kernel = "kernel", kernel_module = @kernels} : (index, index, index, index, index, index, memref<12xf32>) -> ()
+    return
+  }
+}

diff  --git a/mlir/test/mlir-vulkan-runner/addf.mlir b/mlir/test/mlir-vulkan-runner/addf.mlir
index 17b2a91943e5..21f5c8cdd1e5 100644
--- a/mlir/test/mlir-vulkan-runner/addf.mlir
+++ b/mlir/test/mlir-vulkan-runner/addf.mlir
@@ -27,9 +27,9 @@ module attributes {gpu.container_module} {
     %arg3 = memref_cast %arg0 : memref<8xf32> to memref<?xf32>
     %arg4 = memref_cast %arg1 : memref<8xf32> to memref<?xf32>
     %arg5 = memref_cast %arg2 : memref<8xf32> to memref<?xf32>
-    call @setResourceData(%0, %0, %arg3, %value1) : (i32, i32, memref<?xf32>, f32) -> ()
-    call @setResourceData(%0, %1, %arg4, %value2) : (i32, i32, memref<?xf32>, f32) -> ()
-    call @setResourceData(%0, %2, %arg5, %value0) : (i32, i32, memref<?xf32>, f32) -> ()
+    call @fillResource1DFloat(%arg3, %value1) : (memref<?xf32>, f32) -> ()
+    call @fillResource1DFloat(%arg4, %value2) : (memref<?xf32>, f32) -> ()
+    call @fillResource1DFloat(%arg5, %value0) : (memref<?xf32>, f32) -> ()
 
     %cst1 = constant 1 : index
     %cst8 = constant 8 : index
@@ -39,7 +39,7 @@ module attributes {gpu.container_module} {
     call @print_memref_f32(%arg6) : (memref<*xf32>) -> ()
     return
   }
-  func @setResourceData(%0 : i32, %1 : i32, %2 : memref<?xf32>, %4 : f32)
+  func @fillResource1DFloat(%0 : memref<?xf32>, %1 : f32)
   func @print_memref_f32(%ptr : memref<*xf32>)
 }
 

diff  --git a/mlir/tools/mlir-vulkan-runner/VulkanRuntime.h b/mlir/tools/mlir-vulkan-runner/VulkanRuntime.h
index 9c63714306b9..91f234007f74 100644
--- a/mlir/tools/mlir-vulkan-runner/VulkanRuntime.h
+++ b/mlir/tools/mlir-vulkan-runner/VulkanRuntime.h
@@ -22,7 +22,7 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/ToolOutputFile.h"
 
-#include <vulkan/vulkan.h>
+#include <vulkan/vulkan.h> // NOLINT
 
 using namespace mlir;
 

diff  --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
index 6bb51f0f8d12..33f6472df4d2 100644
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -38,8 +38,9 @@ static LogicalResult runMLIRPasses(ModuleOp module) {
   passManager.addPass(createConvertGPUToSPIRVPass());
   OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
   modulePM.addPass(spirv::createLowerABIAttributesPass());
-  passManager.addPass(createConvertGpuLaunchFuncToVulkanCallsPass());
+  passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
   passManager.addPass(createLowerToLLVMPass());
+  passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass());
   return passManager.run(module);
 }
 

diff  --git a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
index 136076c64926..eb9a682da300 100644
--- a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
@@ -14,84 +14,95 @@
 #include <numeric>
 
 #include "VulkanRuntime.h"
-#include "llvm/Support/ManagedStatic.h"
 #include "llvm/Support/raw_ostream.h"
 
 namespace {
 
-// TODO(denis0x0D): This static machinery should be replaced by `initVulkan` and
-// `deinitVulkan` to be more explicit and to avoid static initialization and
-// destruction.
-class VulkanRuntimeManager;
-static llvm::ManagedStatic<VulkanRuntimeManager> vkRuntimeManager;
-
 class VulkanRuntimeManager {
-  public:
-    VulkanRuntimeManager() = default;
-    VulkanRuntimeManager(const VulkanRuntimeManager &) = delete;
-    VulkanRuntimeManager operator=(const VulkanRuntimeManager &) = delete;
-    ~VulkanRuntimeManager() = default;
-
-    void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex,
-                         const VulkanHostMemoryBuffer &memBuffer) {
-      std::lock_guard<std::mutex> lock(mutex);
-      vulkanRuntime.setResourceData(setIndex, bindIndex, memBuffer);
-    }
-
-    void setEntryPoint(const char *entryPoint) {
-      std::lock_guard<std::mutex> lock(mutex);
-      vulkanRuntime.setEntryPoint(entryPoint);
-    }
-
-    void setNumWorkGroups(NumWorkGroups numWorkGroups) {
-      std::lock_guard<std::mutex> lock(mutex);
-      vulkanRuntime.setNumWorkGroups(numWorkGroups);
-    }
-
-    void setShaderModule(uint8_t *shader, uint32_t size) {
-      std::lock_guard<std::mutex> lock(mutex);
-      vulkanRuntime.setShaderModule(shader, size);
-    }
-
-    void runOnVulkan() {
-      std::lock_guard<std::mutex> lock(mutex);
-      if (failed(vulkanRuntime.initRuntime()) || failed(vulkanRuntime.run()) ||
-          failed(vulkanRuntime.updateHostMemoryBuffers()) ||
-          failed(vulkanRuntime.destroy())) {
-        llvm::errs() << "runOnVulkan failed";
-      }
+public:
+  VulkanRuntimeManager() = default;
+  VulkanRuntimeManager(const VulkanRuntimeManager &) = delete;
+  VulkanRuntimeManager operator=(const VulkanRuntimeManager &) = delete;
+  ~VulkanRuntimeManager() = default;
+
+  void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex,
+                       const VulkanHostMemoryBuffer &memBuffer) {
+    std::lock_guard<std::mutex> lock(mutex);
+    vulkanRuntime.setResourceData(setIndex, bindIndex, memBuffer);
+  }
+
+  void setEntryPoint(const char *entryPoint) {
+    std::lock_guard<std::mutex> lock(mutex);
+    vulkanRuntime.setEntryPoint(entryPoint);
+  }
+
+  void setNumWorkGroups(NumWorkGroups numWorkGroups) {
+    std::lock_guard<std::mutex> lock(mutex);
+    vulkanRuntime.setNumWorkGroups(numWorkGroups);
+  }
+
+  void setShaderModule(uint8_t *shader, uint32_t size) {
+    std::lock_guard<std::mutex> lock(mutex);
+    vulkanRuntime.setShaderModule(shader, size);
+  }
+
+  void runOnVulkan() {
+    std::lock_guard<std::mutex> lock(mutex);
+    if (failed(vulkanRuntime.initRuntime()) || failed(vulkanRuntime.run()) ||
+        failed(vulkanRuntime.updateHostMemoryBuffers()) ||
+        failed(vulkanRuntime.destroy())) {
+      llvm::errs() << "runOnVulkan failed";
     }
+  }
 
-  private:
-    VulkanRuntime vulkanRuntime;
-    std::mutex mutex;
+private:
+  VulkanRuntime vulkanRuntime;
+  std::mutex mutex;
 };
 
 } // namespace
 
 extern "C" {
-/// Fills the given memref with the given value.
+// Initializes `VulkanRuntimeManager` and returns a pointer to it.
+void *initVulkan() { return new VulkanRuntimeManager(); }
+
+// Destroys the `VulkanRuntimeManager` referenced by the given pointer.
+void deinitVulkan(void *vkRuntimeManager) {
+  delete reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager);
+}
+
 /// Binds the given memref to the given descriptor set and descriptor index.
-void setResourceData(const DescriptorSetIndex setIndex, BindingIndex bindIndex,
-                     float *allocated, float *aligned, int64_t offset,
-                     int64_t size, int64_t stride, float value) {
-  std::fill_n(allocated, size, value);
-  VulkanHostMemoryBuffer memBuffer{allocated,
+void bindResource(void *vkRuntimeManager, DescriptorSetIndex setIndex,
+                  BindingIndex bindIndex, float *ptr, int64_t size) {
+  VulkanHostMemoryBuffer memBuffer{ptr,
                                    static_cast<uint32_t>(size * sizeof(float))};
-  vkRuntimeManager->setResourceData(setIndex, bindIndex, memBuffer);
+  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+      ->setResourceData(setIndex, bindIndex, memBuffer);
 }
 
-void setEntryPoint(const char *entryPoint) {
-  vkRuntimeManager->setEntryPoint(entryPoint);
+void runOnVulkan(void *vkRuntimeManager) {
+  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)->runOnVulkan();
 }
 
-void setNumWorkGroups(uint32_t x, uint32_t y, uint32_t z) {
-  vkRuntimeManager->setNumWorkGroups({x, y, z});
+/// Fills the given 1D float memref with the given float value.
+void fillResource1DFloat(float *allocated, float *aligned, int64_t offset,
+                         int64_t size, int64_t stride, float value) {
+  std::fill_n(allocated, size, value);
+}
+
+void setEntryPoint(void *vkRuntimeManager, const char *entryPoint) {
+  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+      ->setEntryPoint(entryPoint);
 }
 
-void setBinaryShader(uint8_t *shader, uint32_t size) {
-  vkRuntimeManager->setShaderModule(shader, size);
+void setNumWorkGroups(void *vkRuntimeManager, uint32_t x, uint32_t y,
+                      uint32_t z) {
+  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+      ->setNumWorkGroups({x, y, z});
 }
 
-void runOnVulkan() { vkRuntimeManager->runOnVulkan(); }
+void setBinaryShader(void *vkRuntimeManager, uint8_t *shader, uint32_t size) {
+  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+      ->setShaderModule(shader, size);
+}
 }


        


More information about the Mlir-commits mailing list