[Mlir-commits] [mlir] 1090a83 - [mlir][vulkan-runner] Update mlir-vulkan-runner execution driver.
Lei Zhang
llvmlistbot at llvm.org
Tue Mar 10 12:58:51 PDT 2020
Author: Denis Khalikov
Date: 2020-03-10T15:58:31-04:00
New Revision: 1090a830692a863ccb091e3fad8cc1a287417493
URL: https://github.com/llvm/llvm-project/commit/1090a830692a863ccb091e3fad8cc1a287417493
DIFF: https://github.com/llvm/llvm-project/commit/1090a830692a863ccb091e3fad8cc1a287417493.diff
LOG: [mlir][vulkan-runner] Update mlir-vulkan-runner execution driver.
* Adds GpuLaunchFuncToVulkanLaunchFunc conversion pass.
* Moves serialization of the `spirv::Module` from the LaunchFuncToVulkanCalls pass to the newly created pass.
* Updates LaunchFuncToVulkanCalls instrumentation pass, adds `initVulkan` and `deinitVulkan` runtime calls.
* Adds a `bindResource` call to bind a specific resource to the given descriptor set and descriptor binding.
* Eliminates static construction and destruction of `VulkanRuntimeManager`.
Differential Revision: https://reviews.llvm.org/D75192
Added:
mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir
Modified:
mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h
mlir/include/mlir/InitAllPasses.h
mlir/lib/Conversion/GPUToVulkan/CMakeLists.txt
mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
mlir/test/mlir-vulkan-runner/addf.mlir
mlir/tools/mlir-vulkan-runner/VulkanRuntime.h
mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h b/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h
index af2c0629c49a..9a02860bfc1a 100644
--- a/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h
+++ b/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h
@@ -24,7 +24,10 @@ class ModuleOp;
template <typename T> class OpPassBase;
std::unique_ptr<OpPassBase<ModuleOp>>
-createConvertGpuLaunchFuncToVulkanCallsPass();
+createConvertVulkanLaunchFuncToVulkanCallsPass();
+
+std::unique_ptr<OpPassBase<mlir::ModuleOp>>
+createConvertGpuLaunchFuncToVulkanLaunchFuncPass();
} // namespace mlir
#endif // MLIR_CONVERSION_GPUTOVULKAN_CONVERTGPUTOVULKANPASS_H
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 2b21ccb613ac..4df27e1bc1cd 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -128,7 +128,8 @@ inline void registerAllPasses() {
createLinalgToSPIRVPass();
// Vulkan
- createConvertGpuLaunchFuncToVulkanCallsPass();
+ createConvertGpuLaunchFuncToVulkanLaunchFuncPass();
+ createConvertVulkanLaunchFuncToVulkanCallsPass();
}
} // namespace mlir
diff --git a/mlir/lib/Conversion/GPUToVulkan/CMakeLists.txt b/mlir/lib/Conversion/GPUToVulkan/CMakeLists.txt
index eeafe5f37b97..847c9c5031e9 100644
--- a/mlir/lib/Conversion/GPUToVulkan/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToVulkan/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_conversion_library(MLIRGPUtoVulkanTransforms
ConvertLaunchFuncToVulkanCalls.cpp
+ ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
)
target_link_libraries(MLIRGPUtoVulkanTransforms
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
new file mode 100644
index 000000000000..fcfae4563778
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
@@ -0,0 +1,173 @@
+//===- ConvertGPULaunchFuncToVulkanLaunchFunc.cpp - MLIR conversion pass --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert gpu launch function into a vulkan
+// launch function. Creates a SPIR-V binary shader from the `spirv::ModuleOp`
+// using `spirv::serialize` function, attaches binary data and entry point name
+// as attributes to the vulkan launch call op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Serialization.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
+static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
+static constexpr const char *kVulkanLaunch = "vulkanLaunch";
+
+namespace {
+
+// A pass to convert a gpu launch op to a vulkan launch call op by creating a
+// SPIR-V binary shader from `spirv::ModuleOp` using the `spirv::serialize`
+// function and attaching the binary data and entry point name as attributes
+// to the created vulkan launch call op.
+class ConvertGpuLaunchFuncToVulkanLaunchFunc
+ : public ModulePass<ConvertGpuLaunchFuncToVulkanLaunchFunc> {
+public:
+ void runOnModule() override;
+
+private:
+ /// Creates a SPIR-V binary shader from the given `module` using the
+ /// `spirv::serialize` function. Fails if `module` holds more than one `spv.module`; NOTE(review): succeeds with an empty blob if it holds none.
+ LogicalResult createBinaryShader(ModuleOp module,
+ std::vector<char> &binaryShader);
+
+ /// Converts the given `launchOp` to a vulkan launch call op; signals pass failure on error.
+ void convertGpuLaunchFunc(gpu::LaunchFuncOp launchOp);
+
+ /// Checks whether the given type is supported by the Vulkan runtime (currently only ranked 1-D memrefs).
+ bool isSupportedType(Type type) {
+ // TODO(denis0x0D): Handle other types.
+ if (auto memRefType = type.dyn_cast_or_null<MemRefType>())
+ return memRefType.hasRank() && memRefType.getRank() == 1;
+ return false;
+ }
+
+ /// Declares the vulkan launch function. Returns failure if the type of any
+ /// non-configuration operand (after the first 6) is unsupported by the Vulkan runtime.
+ LogicalResult declareVulkanLaunchFunc(Location loc,
+ gpu::LaunchFuncOp launchOp);
+
+};
+
+} // anonymous namespace
+
+void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnModule() {
+ bool done = false;
+ getModule().walk([this, &done](gpu::LaunchFuncOp op) {
+ if (done) {
+ op.emitError("should only contain one 'gpu::LaunchFuncOp' op");
+ return signalPassFailure();
+ }
+ done = true;
+ convertGpuLaunchFunc(op);
+ });
+
+ // Erase `gpu::GPUModuleOp` and `spirv::ModuleOp` operations; the kernel is carried as the `spirv_blob` attribute of the vulkanLaunch call from here on.
+ for (auto gpuModule :
+ llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>()))
+ gpuModule.erase();
+
+ for (auto spirvModule :
+ llvm::make_early_inc_range(getModule().getOps<spirv::ModuleOp>()))
+ spirvModule.erase();
+}
+
+LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc(
+ Location loc, gpu::LaunchFuncOp launchOp) {
+ OpBuilder builder(getModule().getBody()->getTerminator());
+ // TODO: Workgroup size is written into the kernel, so to properly model the
+ // vulkan launch we should not carry the local workgroup size configuration here.
+ SmallVector<Type, 8> vulkanLaunchTypes{launchOp.getOperandTypes()};
+
+ // Check that all operands have supported types except those for the launch
+ // configuration (the first 6 operands).
+ for (auto type : llvm::drop_begin(vulkanLaunchTypes, 6)) {
+ if (!isSupportedType(type))
+ return launchOp.emitError() << type << " is unsupported to run on Vulkan";
+ }
+
+ // Declare the vulkan launch function at the end of the module body.
+ builder.create<FuncOp>(
+ loc, kVulkanLaunch,
+ FunctionType::get(vulkanLaunchTypes, ArrayRef<Type>{}, loc->getContext()),
+ ArrayRef<NamedAttribute>{});
+
+ return success();
+}
+
+LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::createBinaryShader(
+ ModuleOp module, std::vector<char> &binaryShader) {
+ bool done = false;
+ SmallVector<uint32_t, 0> binary;
+ for (auto spirvModule : module.getOps<spirv::ModuleOp>()) {
+ if (done)
+ return spirvModule.emitError("should only contain one 'spv.module' op");
+ done = true;
+
+ if (failed(spirv::serialize(spirvModule, binary)))
+ return failure();
+ }
+ binaryShader.resize(binary.size() * sizeof(uint32_t));
+ std::memcpy(binaryShader.data(), reinterpret_cast<char *>(binary.data()),
+ binaryShader.size());
+ return success();
+}
+
+void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
+ gpu::LaunchFuncOp launchOp) {
+ ModuleOp module = getModule();
+ OpBuilder builder(launchOp);
+ Location loc = launchOp.getLoc();
+
+ // Serialize the `spirv::ModuleOp` into binary form.
+ std::vector<char> binary;
+ if (failed(createBinaryShader(module, binary)))
+ return signalPassFailure();
+
+ // Declare the vulkan launch function.
+ if (failed(declareVulkanLaunchFunc(loc, launchOp)))
+ return signalPassFailure();
+
+ // Create the vulkan launch call op, forwarding all gpu.launch_func operands.
+ auto vulkanLaunchCallOp = builder.create<CallOp>(
+ loc, ArrayRef<Type>{}, builder.getSymbolRefAttr(kVulkanLaunch),
+ launchOp.getOperands());
+
+ // Set the SPIR-V binary shader data as the `spirv_blob` attribute.
+ vulkanLaunchCallOp.setAttr(
+ kSPIRVBlobAttrName,
+ StringAttr::get({binary.data(), binary.size()}, loc->getContext()));
+
+ // Set the kernel name as the `spirv_entry_point` attribute.
+ vulkanLaunchCallOp.setAttr(
+ kSPIRVEntryPointAttrName,
+ StringAttr::get(launchOp.kernel(), loc->getContext()));
+
+ launchOp.erase();
+}
+
+std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
+mlir::createConvertGpuLaunchFuncToVulkanLaunchFuncPass() {
+ return std::make_unique<ConvertGpuLaunchFuncToVulkanLaunchFunc>();
+}
+
+static PassRegistration<ConvertGpuLaunchFuncToVulkanLaunchFunc>
+ pass("convert-gpu-launch-to-vulkan-launch",
+ "Convert gpu.launch_func to vulkanLaunch external call");
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index 03cf1c7229af..b1bdc3036076 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
//
-// This file implements a pass to convert gpu.launch_func op into a sequence of
+// This file implements a pass to convert vulkan launch call into a sequence of
// Vulkan runtime calls. The Vulkan runtime API surface is huge so currently we
// don't expose separate external functions in IR for each of them, instead we
// expose a few external functions to wrapper libraries which manages Vulkan
@@ -15,40 +15,44 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
-#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/SPIRV/SPIRVOps.h"
-#include "mlir/Dialect/SPIRV/Serialization.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
-#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallString.h"
using namespace mlir;
+static constexpr const char *kBindResource = "bindResource";
+static constexpr const char *kDeinitVulkan = "deinitVulkan";
+static constexpr const char *kRunOnVulkan = "runOnVulkan";
+static constexpr const char *kInitVulkan = "initVulkan";
static constexpr const char *kSetBinaryShader = "setBinaryShader";
static constexpr const char *kSetEntryPoint = "setEntryPoint";
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
-static constexpr const char *kRunOnVulkan = "runOnVulkan";
static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
+static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
+static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
+static constexpr const char *kVulkanLaunch = "vulkanLaunch";
namespace {
-/// A pass to convert gpu.launch_func operation into a sequence of Vulkan
-/// runtime calls.
+/// A pass to convert vulkan launch func into a sequence of Vulkan
+/// runtime calls in the following order:
///
+/// * initVulkan -- initializes vulkan runtime
+/// * bindResource -- binds resource
/// * setBinaryShader -- sets the binary shader data
/// * setEntryPoint -- sets the entry point name
/// * setNumWorkGroups -- sets the number of a local workgroups
/// * runOnVulkan -- runs vulkan runtime
+/// * deinitVulkan -- deinitializes vulkan runtime
///
-class GpuLaunchFuncToVulkanCalssPass
- : public ModulePass<GpuLaunchFuncToVulkanCalssPass> {
+class VulkanLaunchFuncToVulkanCallsPass
+ : public ModulePass<VulkanLaunchFuncToVulkanCallsPass> {
private:
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
@@ -58,72 +62,145 @@ class GpuLaunchFuncToVulkanCalssPass
void initializeCachedTypes() {
llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
+ llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);
llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
+ llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
}
+ LLVM::LLVMType getFloatType() { return llvmFloatType; }
LLVM::LLVMType getVoidType() { return llvmVoidType; }
LLVM::LLVMType getPointerType() { return llvmPointerType; }
LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
-
- /// Creates a SPIR-V binary shader from the given `module` using
- /// `spirv::serialize` function.
- LogicalResult createBinaryShader(ModuleOp module,
- std::vector<char> &binaryShader);
+ LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
/// Creates a LLVM global for the given `name`.
Value createEntryPointNameConstant(StringRef name, Location loc,
OpBuilder &builder);
- /// Creates a LLVM constant for each dimension of local workgroup and
- /// populates the given `numWorkGroups`.
- LogicalResult createNumWorkGroups(Location loc, OpBuilder &builder,
- mlir::gpu::LaunchFuncOp launchOp,
- SmallVectorImpl<Value> &numWorkGroups);
-
/// Declares all needed runtime functions.
void declareVulkanFunctions(Location loc);
- /// Translates the given `launcOp` op to the sequence of Vulkan runtime calls
- void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp);
+ /// Checks whether the given LLVM::CallOp is a vulkan launch call op.
+ bool isVulkanLaunchCallOp(LLVM::CallOp callOp) {
+ return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch &&
+ callOp.getNumOperands() >= 6);
+ }
+
+ /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan
+ /// runtime calls.
+ void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp);
+
+ /// Creates call to `bindResource` for each resource operand.
+ void createBindResourceCalls(LLVM::CallOp vulkanLaunchCallOp,
+ Value vulkanRuntiem);
public:
void runOnModule() override;
private:
LLVM::LLVMDialect *llvmDialect;
+ LLVM::LLVMType llvmFloatType;
LLVM::LLVMType llvmVoidType;
LLVM::LLVMType llvmPointerType;
LLVM::LLVMType llvmInt32Type;
+ LLVM::LLVMType llvmInt64Type;
};
-} // anonymous namespace
+/// Operand adaptor for the vulkan launch call operation, simplifying access
+/// to the operands of the lowered memrefs.
+// TODO: We should use 'emit-c-wrappers' option to lower memref type:
+// https://mlir.llvm.org/docs/ConversionToLLVMDialect/#c-compatible-wrapper-emission.
+struct VulkanLaunchOpOperandAdaptor {
+ VulkanLaunchOpOperandAdaptor(ArrayRef<Value> values) { operands = values; }
+ VulkanLaunchOpOperandAdaptor(const VulkanLaunchOpOperandAdaptor &) = delete;
+ VulkanLaunchOpOperandAdaptor
+ operator=(const VulkanLaunchOpOperandAdaptor &) = delete;
+
+ /// Returns a tuple with a pointer to the memory and the size for the index-th
+ /// resource. NOTE(review): the pointer is the allocated pointer (field 0), not the aligned pointer (field 1).
+ std::tuple<Value, Value> getResourceDescriptor1D(uint32_t index) {
+ assert(index < getResourceCount1D());
+ // 1D memref calling convention according to "ConversionToLLVMDialect.md":
+ // 0. Allocated pointer.
+ // 1. Aligned pointer.
+ // 2. Offset.
+ // 3. Size in dim 0.
+ // 4. Stride in dim 0.
+ return {operands[numConfigOps + index * loweredMemRefNumOps1D],
+ operands[numConfigOps + index * loweredMemRefNumOps1D + 3]};
+ }
-void GpuLaunchFuncToVulkanCalssPass::runOnModule() {
- initializeCachedTypes();
+ /// Returns the number of resources, assuming all non-config operands come from
+ /// lowered 1D memrefs (5 operands each); a trailing partial group is ignored.
+ uint32_t getResourceCount1D() {
+ return (operands.size() - numConfigOps) / loweredMemRefNumOps1D;
+ }
- getModule().walk(
- [this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); });
+private:
+ /// The number of operands of lowered 1D memref.
+ static constexpr const uint32_t loweredMemRefNumOps1D = 5;
+ /// The number of leading launch configuration operands.
+ static constexpr const uint32_t numConfigOps = 6;
+ ArrayRef<Value> operands;
+};
- // Erase `gpu::GPUModuleOp` and `spirv::Module` operations.
- for (auto gpuModule :
- llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>()))
- gpuModule.erase();
+} // anonymous namespace
- for (auto spirvModule :
- llvm::make_early_inc_range(getModule().getOps<spirv::ModuleOp>()))
- spirvModule.erase();
+void VulkanLaunchFuncToVulkanCallsPass::runOnModule() {
+ initializeCachedTypes();
+ getModule().walk([this](LLVM::CallOp op) {
+ if (isVulkanLaunchCallOp(op))
+ translateVulkanLaunchCall(op);
+ });
}
-void GpuLaunchFuncToVulkanCalssPass::declareVulkanFunctions(Location loc) {
+void VulkanLaunchFuncToVulkanCallsPass::createBindResourceCalls(
+ LLVM::CallOp vulkanLaunchCallOp, Value vulkanRuntime) {
+ if (vulkanLaunchCallOp.getNumOperands() == 6)
+ return;
+ OpBuilder builder(vulkanLaunchCallOp);
+ Location loc = vulkanLaunchCallOp.getLoc();
+
+ // Create LLVM constant for the descriptor set index.
+ // Bind all resources to the `0` descriptor set, the same way as `GPUToSPIRV`
+ // pass does.
+ Value descriptorSet = builder.create<LLVM::ConstantOp>(
+ loc, getInt32Type(), builder.getI32IntegerAttr(0));
+
+ auto operands = SmallVector<Value, 32>{vulkanLaunchCallOp.getOperands()};
+ VulkanLaunchOpOperandAdaptor vkLaunchOperandAdaptor(operands);
+
+ for (auto resourceIdx :
+ llvm::seq<uint32_t>(0, vkLaunchOperandAdaptor.getResourceCount1D())) {
+ // Create LLVM constant for the descriptor binding index.
+ Value descriptorBinding = builder.create<LLVM::ConstantOp>(
+ loc, getInt32Type(), builder.getI32IntegerAttr(resourceIdx));
+ // Get a pointer to the memory and size of that memory.
+ auto resourceDescriptor =
+ vkLaunchOperandAdaptor.getResourceDescriptor1D(resourceIdx);
+ // Create call to `bindResource`.
+ builder.create<LLVM::CallOp>(
+ loc, ArrayRef<Type>{getVoidType()},
+ builder.getSymbolRefAttr(kBindResource),
+ ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding,
+ // Pointer to the memory.
+ std::get<0>(resourceDescriptor),
+ // Size of the memory.
+ std::get<1>(resourceDescriptor)});
+ }
+}
+
+void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
ModuleOp module = getModule();
OpBuilder builder(module.getBody()->getTerminator());
if (!module.lookupSymbol(kSetEntryPoint)) {
builder.create<LLVM::LLVMFuncOp>(
loc, kSetEntryPoint,
- LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
+ LLVM::LLVMType::getFunctionTy(getVoidType(),
+ {getPointerType(), getPointerType()},
/*isVarArg=*/false));
}
@@ -131,27 +208,52 @@ void GpuLaunchFuncToVulkanCalssPass::declareVulkanFunctions(Location loc) {
builder.create<LLVM::LLVMFuncOp>(
loc, kSetNumWorkGroups,
LLVM::LLVMType::getFunctionTy(
- getVoidType(), {getInt32Type(), getInt32Type(), getInt32Type()},
+ getVoidType(),
+ {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()},
/*isVarArg=*/false));
}
if (!module.lookupSymbol(kSetBinaryShader)) {
builder.create<LLVM::LLVMFuncOp>(
loc, kSetBinaryShader,
- LLVM::LLVMType::getFunctionTy(getVoidType(),
- {getPointerType(), getInt32Type()},
- /*isVarArg=*/false));
+ LLVM::LLVMType::getFunctionTy(
+ getVoidType(), {getPointerType(), getPointerType(), getInt32Type()},
+ /*isVarArg=*/false));
}
if (!module.lookupSymbol(kRunOnVulkan)) {
builder.create<LLVM::LLVMFuncOp>(
loc, kRunOnVulkan,
- LLVM::LLVMType::getFunctionTy(getVoidType(), {},
+ LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
+ /*isVarArg=*/false));
+ }
+
+ if (!module.lookupSymbol(kBindResource)) {
+ builder.create<LLVM::LLVMFuncOp>(
+ loc, kBindResource,
+ LLVM::LLVMType::getFunctionTy(
+ getVoidType(),
+ {getPointerType(), getInt32Type(), getInt32Type(),
+ getFloatType().getPointerTo(), getInt64Type()},
+ /*isVarArg=*/false));
+ }
+
+ if (!module.lookupSymbol(kInitVulkan)) {
+ builder.create<LLVM::LLVMFuncOp>(
+ loc, kInitVulkan,
+ LLVM::LLVMType::getFunctionTy(getPointerType(), {},
+ /*isVarArg=*/false));
+ }
+
+ if (!module.lookupSymbol(kDeinitVulkan)) {
+ builder.create<LLVM::LLVMFuncOp>(
+ loc, kDeinitVulkan,
+ LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()},
/*isVarArg=*/false));
}
}
-Value GpuLaunchFuncToVulkanCalssPass::createEntryPointNameConstant(
+Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
StringRef name, Location loc, OpBuilder &builder) {
SmallString<16> shaderName(name.begin(), name.end());
// Append `\0` to follow C style string given that LLVM::createGlobalString()
@@ -164,107 +266,95 @@ Value GpuLaunchFuncToVulkanCalssPass::createEntryPointNameConstant(
getLLVMDialect());
}
-LogicalResult GpuLaunchFuncToVulkanCalssPass::createBinaryShader(
- ModuleOp module, std::vector<char> &binaryShader) {
- bool done = false;
- SmallVector<uint32_t, 0> binary;
- for (auto spirvModule : module.getOps<spirv::ModuleOp>()) {
- if (done)
- return spirvModule.emitError("should only contain one 'spv.module' op");
- done = true;
-
- if (failed(spirv::serialize(spirvModule, binary)))
- return failure();
+void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
+ LLVM::CallOp vulkanLaunchCallOp) {
+ OpBuilder builder(vulkanLaunchCallOp);
+ Location loc = vulkanLaunchCallOp.getLoc();
+
+ // Check that the `kSPIRVBlobAttrName` and `kSPIRVEntryPointAttrName` attributes
+ // are present on the given vulkan launch call.
+ auto spirvBlobAttr =
+ vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName);
+ if (!spirvBlobAttr) {
+ vulkanLaunchCallOp.emitError()
+ << "missing " << kSPIRVBlobAttrName << " attribute";
+ return signalPassFailure();
}
- binaryShader.resize(binary.size() * sizeof(uint32_t));
- std::memcpy(binaryShader.data(), reinterpret_cast<char *>(binary.data()),
- binaryShader.size());
- return success();
-}
-
-LogicalResult GpuLaunchFuncToVulkanCalssPass::createNumWorkGroups(
- Location loc, OpBuilder &builder, mlir::gpu::LaunchFuncOp launchOp,
- SmallVectorImpl<Value> &numWorkGroups) {
- for (auto index : llvm::seq(0, 3)) {
- auto numWorkGroupDimConstant = dyn_cast_or_null<ConstantOp>(
- launchOp.getOperand(index).getDefiningOp());
-
- if (!numWorkGroupDimConstant)
- return failure();
-
- auto numWorkGroupDimValue =
- numWorkGroupDimConstant.getValue().cast<IntegerAttr>().getInt();
- numWorkGroups.push_back(builder.create<LLVM::ConstantOp>(
- loc, getInt32Type(), builder.getI32IntegerAttr(numWorkGroupDimValue)));
+ auto entryPointNameAttr =
+ vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName);
+ if (!entryPointNameAttr) {
+ vulkanLaunchCallOp.emitError()
+ << "missing " << kSPIRVEntryPointAttrName << " attribute";
+ return signalPassFailure();
}
- return success();
-}
-
-void GpuLaunchFuncToVulkanCalssPass::translateGpuLaunchCalls(
- mlir::gpu::LaunchFuncOp launchOp) {
- ModuleOp module = getModule();
- OpBuilder builder(launchOp);
- Location loc = launchOp.getLoc();
-
- // Serialize `spirv::Module` into binary form.
- std::vector<char> binary;
- if (failed(
- GpuLaunchFuncToVulkanCalssPass::createBinaryShader(module, binary)))
- return signalPassFailure();
+ // Create call to `initVulkan`.
+ auto initVulkanCall = builder.create<LLVM::CallOp>(
+ loc, ArrayRef<Type>{getPointerType()},
+ builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{});
+ // The result of the `initVulkan` function is a pointer to the Vulkan runtime;
+ // that pointer is passed to each subsequent Vulkan runtime call.
+ auto vulkanRuntime = initVulkanCall.getResult(0);
// Create LLVM global with SPIR-V binary data, so we can pass a pointer with
// that data to runtime call.
Value ptrToSPIRVBinary = LLVM::createGlobalString(
- loc, builder, kSPIRVBinary, StringRef(binary.data(), binary.size()),
+ loc, builder, kSPIRVBinary, spirvBlobAttr.getValue(),
LLVM::Linkage::Internal, getLLVMDialect());
+
// Create LLVM constant for the size of SPIR-V binary shader.
Value binarySize = builder.create<LLVM::ConstantOp>(
- loc, getInt32Type(), builder.getI32IntegerAttr(binary.size()));
+ loc, getInt32Type(),
+ builder.getI32IntegerAttr(spirvBlobAttr.getValue().size()));
+
+ // Create call to `bindResource` for each resource operand.
+ createBindResourceCalls(vulkanLaunchCallOp, vulkanRuntime);
+
// Create call to `setBinaryShader` runtime function with the given pointer to
// SPIR-V binary and binary size.
- builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
- builder.getSymbolRefAttr(kSetBinaryShader),
- ArrayRef<Value>{ptrToSPIRVBinary, binarySize});
-
+ builder.create<LLVM::CallOp>(
+ loc, ArrayRef<Type>{getVoidType()},
+ builder.getSymbolRefAttr(kSetBinaryShader),
+ ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize});
// Create LLVM global with entry point name.
Value entryPointName =
- createEntryPointNameConstant(launchOp.kernel(), loc, builder);
+ createEntryPointNameConstant(entryPointNameAttr.getValue(), loc, builder);
// Create call to `setEntryPoint` runtime function with the given pointer to
// entry point name.
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
builder.getSymbolRefAttr(kSetEntryPoint),
- ArrayRef<Value>{entryPointName});
+ ArrayRef<Value>{vulkanRuntime, entryPointName});
// Create number of local workgroup for each dimension.
- SmallVector<Value, 3> numWorkGroups;
- if (failed(createNumWorkGroups(loc, builder, launchOp, numWorkGroups)))
- return signalPassFailure();
-
- // Create call `setNumWorkGroups` runtime function with the given numbers of
- // local workgroup.
builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getVoidType()},
builder.getSymbolRefAttr(kSetNumWorkGroups),
- ArrayRef<Value>{numWorkGroups[0], numWorkGroups[1], numWorkGroups[2]});
+ ArrayRef<Value>{vulkanRuntime, vulkanLaunchCallOp.getOperand(0),
+ vulkanLaunchCallOp.getOperand(1),
+ vulkanLaunchCallOp.getOperand(2)});
// Create call to `runOnVulkan` runtime function.
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
builder.getSymbolRefAttr(kRunOnVulkan),
- ArrayRef<Value>{});
+ ArrayRef<Value>{vulkanRuntime});
+
+ // Create call to `deinitVulkan` runtime function.
+ builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()},
+ builder.getSymbolRefAttr(kDeinitVulkan),
+ ArrayRef<Value>{vulkanRuntime});
// Declare runtime functions.
declareVulkanFunctions(loc);
- launchOp.erase();
+ vulkanLaunchCallOp.erase();
}
std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>>
-mlir::createConvertGpuLaunchFuncToVulkanCallsPass() {
- return std::make_unique<GpuLaunchFuncToVulkanCalssPass>();
+mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
+ return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
}
-static PassRegistration<GpuLaunchFuncToVulkanCalssPass>
+static PassRegistration<VulkanLaunchFuncToVulkanCallsPass>
pass("launch-func-to-vulkan",
- "Convert gpu.launch_func op to Vulkan runtime calls");
+ "Convert vulkanLaunch external call to Vulkan runtime external calls");
diff --git a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir b/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
index 580c13364a23..060e2b3c93db 100644
--- a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
+++ b/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
@@ -2,44 +2,47 @@
// CHECK: llvm.mlir.global internal constant @kernel_spv_entry_point_name
// CHECK: llvm.mlir.global internal constant @SPIRV_BIN
+// CHECK: %[[Vulkan_Runtime_ptr:.*]] = llvm.call @initVulkan() : () -> !llvm<"i8*">
// CHECK: %[[addressof_SPIRV_BIN:.*]] = llvm.mlir.addressof @SPIRV_BIN
// CHECK: %[[SPIRV_BIN_ptr:.*]] = llvm.getelementptr %[[addressof_SPIRV_BIN]]
// CHECK: %[[SPIRV_BIN_size:.*]] = llvm.mlir.constant
-// CHECK: llvm.call @setBinaryShader(%[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm<"i8*">, !llvm.i32) -> !llvm.void
+// CHECK: llvm.call @bindResource(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"i8*">, !llvm.i32, !llvm.i32, !llvm<"float*">, !llvm.i64) -> !llvm.void
+// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32) -> !llvm.void
// CHECK: %[[addressof_entry_point:.*]] = llvm.mlir.addressof @kernel_spv_entry_point_name
// CHECK: %[[entry_point_ptr:.*]] = llvm.getelementptr %[[addressof_entry_point]]
-// CHECK: llvm.call @setEntryPoint(%[[entry_point_ptr]]) : (!llvm<"i8*">) -> !llvm.void
-// CHECK: %[[Workgroup_X:.*]] = llvm.mlir.constant
-// CHECK: %[[Workgroup_Y:.*]] = llvm.mlir.constant
-// CHECK: %[[Workgroup_Z:.*]] = llvm.mlir.constant
-// CHECK: llvm.call @setNumWorkGroups(%[[Workgroup_X]], %[[Workgroup_Y]], %[[Workgroup_Z]]) : (!llvm.i32, !llvm.i32, !llvm.i32) -> !llvm.void
-// CHECK: llvm.call @runOnVulkan() : () -> !llvm.void
+// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm<"i8*">, !llvm<"i8*">) -> !llvm.void
+// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"i8*">, !llvm.i64, !llvm.i64, !llvm.i64) -> !llvm.void
+// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm<"i8*">) -> !llvm.void
+// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm<"i8*">) -> !llvm.void
module attributes {gpu.container_module} {
- spv.module "Logical" "GLSL450" {
- spv.globalVariable @kernel_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer>
- spv.globalVariable @kernel_arg_1 bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
- spv.func @kernel() "None" attributes {workgroup_attributions = 0 : i64} {
- %0 = spv._address_of @kernel_arg_1 : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
- %1 = spv._address_of @kernel_arg_0 : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer>
- %2 = spv.constant 0 : i32
- %3 = spv.AccessChain %1[%2] : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer>
- %4 = spv.Load "StorageBuffer" %3 : f32
- spv.Return
- }
- spv.EntryPoint "GLCompute" @kernel
- spv.ExecutionMode @kernel "LocalSize", 1, 1, 1
- } attributes {capabilities = ["Shader"], extensions = ["SPV_KHR_storage_buffer_storage_class"]}
- gpu.module @kernels {
- gpu.func @kernel(%arg0: f32, %arg1: memref<12xf32>) kernel {
- gpu.return
- }
- }
- func @foo() {
- %0 = "op"() : () -> f32
- %1 = "op"() : () -> memref<12xf32>
- %c1 = constant 1 : index
- "gpu.launch_func"(%c1, %c1, %c1, %c1, %c1, %c1, %0, %1) {kernel = "kernel", kernel_module = @kernels} : (index, index, index, index, index, index, f32, memref<12xf32>) -> ()
- return
+ llvm.func @malloc(!llvm.i64) -> !llvm<"i8*">
+ llvm.func @foo() {
+ %0 = llvm.mlir.constant(12 : index) : !llvm.i64
+ %1 = llvm.mlir.null : !llvm<"float*">
+ %2 = llvm.mlir.constant(1 : index) : !llvm.i64
+ %3 = llvm.getelementptr %1[%2] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+ %4 = llvm.ptrtoint %3 : !llvm<"float*"> to !llvm.i64
+ %5 = llvm.mul %0, %4 : !llvm.i64
+ %6 = llvm.call @malloc(%5) : (!llvm.i64) -> !llvm<"i8*">
+ %7 = llvm.bitcast %6 : !llvm<"i8*"> to !llvm<"float*">
+ %8 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %9 = llvm.insertvalue %7, %8[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %10 = llvm.insertvalue %7, %9[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %11 = llvm.mlir.constant(0 : index) : !llvm.i64
+ %12 = llvm.insertvalue %11, %10[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %13 = llvm.mlir.constant(1 : index) : !llvm.i64
+ %14 = llvm.insertvalue %0, %12[3, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %15 = llvm.insertvalue %13, %14[4, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %16 = llvm.mlir.constant(1 : index) : !llvm.i64
+ %17 = llvm.extractvalue %15[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %18 = llvm.extractvalue %15[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %19 = llvm.extractvalue %15[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %20 = llvm.extractvalue %15[3, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ %21 = llvm.extractvalue %15[4, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }">
+ llvm.call @vulkanLaunch(%16, %16, %16, %16, %16, %16, %17, %18, %19, %20, %21) {spirv_blob = "\03\02#\07\00", spirv_entry_point = "kernel"}
+ : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64) -> ()
+ llvm.return
}
+ llvm.func @vulkanLaunch(!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64)
}
diff --git a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir
new file mode 100644
index 000000000000..aa3daa04734e
--- /dev/null
+++ b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -convert-gpu-launch-to-vulkan-launch | FileCheck %s
+// Verifies gpu.launch_func becomes a @vulkanLaunch call with spirv_blob/spirv_entry_point attrs.
+// CHECK: %[[resource:.*]] = alloc() : memref<12xf32>
+// CHECK: %[[index:.*]] = constant 1 : index
+// CHECK: call @vulkanLaunch(%[[index]], %[[index]], %[[index]], %[[index]], %[[index]], %[[index]], %[[resource]]) {spirv_blob = "{{.*}}", spirv_entry_point = "kernel"}
+
+module attributes {gpu.container_module} {
+ spv.module "Logical" "GLSL450" {
+ spv.globalVariable @kernel_arg_0 bind(0, 0) : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer> // descriptor set 0, binding 0
+ spv.func @kernel() "None" attributes {workgroup_attributions = 0 : i64} {
+ %0 = spv._address_of @kernel_arg_0 : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
+ %2 = spv.constant 0 : i32
+ %3 = spv._address_of @kernel_arg_0 : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer> // NOTE(review): unused; only %0 feeds the access chain
+ %4 = spv.AccessChain %0[%2, %2] : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
+ %5 = spv.Load "StorageBuffer" %4 : f32
+ spv.Return
+ }
+ spv.EntryPoint "GLCompute" @kernel
+ spv.ExecutionMode @kernel "LocalSize", 1, 1, 1
+ } attributes {capabilities = ["Shader"], extensions = ["SPV_KHR_storage_buffer_storage_class"]}
+ gpu.module @kernels {
+ gpu.func @kernel(%arg0: memref<12xf32>) kernel {
+ gpu.return
+ }
+ }
+ func @foo() {
+ %0 = alloc() : memref<12xf32>
+ %c1 = constant 1 : index
+ "gpu.launch_func"(%c1, %c1, %c1, %c1, %c1, %c1, %0) {kernel = "kernel", kernel_module = @kernels} : (index, index, index, index, index, index, memref<12xf32>) -> () // rewritten into the @vulkanLaunch call matched above
+ return
+ }
+}
diff --git a/mlir/test/mlir-vulkan-runner/addf.mlir b/mlir/test/mlir-vulkan-runner/addf.mlir
index 17b2a91943e5..21f5c8cdd1e5 100644
--- a/mlir/test/mlir-vulkan-runner/addf.mlir
+++ b/mlir/test/mlir-vulkan-runner/addf.mlir
@@ -27,9 +27,9 @@ module attributes {gpu.container_module} {
%arg3 = memref_cast %arg0 : memref<8xf32> to memref<?xf32>
%arg4 = memref_cast %arg1 : memref<8xf32> to memref<?xf32>
%arg5 = memref_cast %arg2 : memref<8xf32> to memref<?xf32>
- call @setResourceData(%0, %0, %arg3, %value1) : (i32, i32, memref<?xf32>, f32) -> ()
- call @setResourceData(%0, %1, %arg4, %value2) : (i32, i32, memref<?xf32>, f32) -> ()
- call @setResourceData(%0, %2, %arg5, %value0) : (i32, i32, memref<?xf32>, f32) -> ()
+ call @fillResource1DFloat(%arg3, %value1) : (memref<?xf32>, f32) -> ()
+ call @fillResource1DFloat(%arg4, %value2) : (memref<?xf32>, f32) -> ()
+ call @fillResource1DFloat(%arg5, %value0) : (memref<?xf32>, f32) -> ()
%cst1 = constant 1 : index
%cst8 = constant 8 : index
@@ -39,7 +39,7 @@ module attributes {gpu.container_module} {
call @print_memref_f32(%arg6) : (memref<*xf32>) -> ()
return
}
- func @setResourceData(%0 : i32, %1 : i32, %2 : memref<?xf32>, %4 : f32)
+ func @fillResource1DFloat(%0 : memref<?xf32>, %1 : f32)
func @print_memref_f32(%ptr : memref<*xf32>)
}
diff --git a/mlir/tools/mlir-vulkan-runner/VulkanRuntime.h b/mlir/tools/mlir-vulkan-runner/VulkanRuntime.h
index 9c63714306b9..91f234007f74 100644
--- a/mlir/tools/mlir-vulkan-runner/VulkanRuntime.h
+++ b/mlir/tools/mlir-vulkan-runner/VulkanRuntime.h
@@ -22,7 +22,7 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/ToolOutputFile.h"
-#include <vulkan/vulkan.h>
+#include <vulkan/vulkan.h> // NOLINT
using namespace mlir;
diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
index 6bb51f0f8d12..33f6472df4d2 100644
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -38,8 +38,9 @@ static LogicalResult runMLIRPasses(ModuleOp module) {
passManager.addPass(createConvertGPUToSPIRVPass());
OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
modulePM.addPass(spirv::createLowerABIAttributesPass());
- passManager.addPass(createConvertGpuLaunchFuncToVulkanCallsPass());
+ passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
passManager.addPass(createLowerToLLVMPass());
+ passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass());
return passManager.run(module);
}
diff --git a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
index 136076c64926..eb9a682da300 100644
--- a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
@@ -14,84 +14,114 @@
#include <numeric>
#include "VulkanRuntime.h"
-#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/raw_ostream.h"
namespace {
-// TODO(denis0x0D): This static machinery should be replaced by `initVulkan` and
-// `deinitVulkan` to be more explicit and to avoid static initialization and
-// destruction.
-class VulkanRuntimeManager;
-static llvm::ManagedStatic<VulkanRuntimeManager> vkRuntimeManager;
-
+/// Owns a `VulkanRuntime` and takes `mutex` in every public method. Instances
+/// are created by `initVulkan` and freed by `deinitVulkan`.
class VulkanRuntimeManager {
- public:
- VulkanRuntimeManager() = default;
- VulkanRuntimeManager(const VulkanRuntimeManager &) = delete;
- VulkanRuntimeManager operator=(const VulkanRuntimeManager &) = delete;
- ~VulkanRuntimeManager() = default;
-
- void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex,
- const VulkanHostMemoryBuffer &memBuffer) {
- std::lock_guard<std::mutex> lock(mutex);
- vulkanRuntime.setResourceData(setIndex, bindIndex, memBuffer);
- }
-
- void setEntryPoint(const char *entryPoint) {
- std::lock_guard<std::mutex> lock(mutex);
- vulkanRuntime.setEntryPoint(entryPoint);
- }
-
- void setNumWorkGroups(NumWorkGroups numWorkGroups) {
- std::lock_guard<std::mutex> lock(mutex);
- vulkanRuntime.setNumWorkGroups(numWorkGroups);
- }
-
- void setShaderModule(uint8_t *shader, uint32_t size) {
- std::lock_guard<std::mutex> lock(mutex);
- vulkanRuntime.setShaderModule(shader, size);
- }
-
- void runOnVulkan() {
- std::lock_guard<std::mutex> lock(mutex);
- if (failed(vulkanRuntime.initRuntime()) || failed(vulkanRuntime.run()) ||
- failed(vulkanRuntime.updateHostMemoryBuffers()) ||
- failed(vulkanRuntime.destroy())) {
- llvm::errs() << "runOnVulkan failed";
- }
+public:
+ VulkanRuntimeManager() = default;
+ VulkanRuntimeManager(const VulkanRuntimeManager &) = delete;
+ VulkanRuntimeManager &operator=(const VulkanRuntimeManager &) = delete;
+ ~VulkanRuntimeManager() = default;
+
+ /// Forwards `memBuffer` to the runtime for the given set/binding pair.
+ void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex,
+ const VulkanHostMemoryBuffer &memBuffer) {
+ std::lock_guard<std::mutex> lock(mutex);
+ vulkanRuntime.setResourceData(setIndex, bindIndex, memBuffer);
+ }
+
+ /// Forwards the shader entry point name to the runtime.
+ void setEntryPoint(const char *entryPoint) {
+ std::lock_guard<std::mutex> lock(mutex);
+ vulkanRuntime.setEntryPoint(entryPoint);
+ }
+
+ /// Forwards the workgroup counts to the runtime.
+ void setNumWorkGroups(NumWorkGroups numWorkGroups) {
+ std::lock_guard<std::mutex> lock(mutex);
+ vulkanRuntime.setNumWorkGroups(numWorkGroups);
+ }
+
+ /// Forwards the shader module binary (`shader`, `size`) to the runtime.
+ void setShaderModule(uint8_t *shader, uint32_t size) {
+ std::lock_guard<std::mutex> lock(mutex);
+ vulkanRuntime.setShaderModule(shader, size);
+ }
+
+ /// Initializes the runtime, runs it, updates the host memory buffers and
+ /// destroys the runtime, printing a diagnostic if any step fails.
+ /// NOTE(review): `||` short-circuits, so `destroy()` is skipped when an
+ /// earlier step fails; confirm whether that leaks Vulkan resources.
+ void runOnVulkan() {
+ std::lock_guard<std::mutex> lock(mutex);
+ if (failed(vulkanRuntime.initRuntime()) || failed(vulkanRuntime.run()) ||
+ failed(vulkanRuntime.updateHostMemoryBuffers()) ||
+ failed(vulkanRuntime.destroy())) {
+ llvm::errs() << "runOnVulkan failed\n";
}
+ }
- private:
- VulkanRuntime vulkanRuntime;
- std::mutex mutex;
+private:
+ VulkanRuntime vulkanRuntime;
+ std::mutex mutex;
};
} // namespace
extern "C" {
-/// Fills the given memref with the given value.
+/// Initializes `VulkanRuntimeManager` and returns a pointer to it.
+/// The caller owns the result and must release it with `deinitVulkan`.
+void *initVulkan() { return new VulkanRuntimeManager(); }
+
+/// Destroys a `VulkanRuntimeManager` previously returned by `initVulkan`.
+void deinitVulkan(void *vkRuntimeManager) {
+ delete static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
+}
+
-/// Binds the given memref to the given descriptor set and descriptor index.
-void setResourceData(const DescriptorSetIndex setIndex, BindingIndex bindIndex,
- float *allocated, float *aligned, int64_t offset,
- int64_t size, int64_t stride, float value) {
- std::fill_n(allocated, size, value);
- VulkanHostMemoryBuffer memBuffer{allocated,
+/// Binds `size` floats starting at `ptr` to the given descriptor set and
+/// binding index of the runtime owned by `vkRuntimeManager`.
+void bindResource(void *vkRuntimeManager, DescriptorSetIndex setIndex,
+ BindingIndex bindIndex, float *ptr, int64_t size) {
+ VulkanHostMemoryBuffer memBuffer{ptr,
static_cast<uint32_t>(size * sizeof(float))};
- vkRuntimeManager->setResourceData(setIndex, bindIndex, memBuffer);
+ static_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+ ->setResourceData(setIndex, bindIndex, memBuffer);
}
-void setEntryPoint(const char *entryPoint) {
- vkRuntimeManager->setEntryPoint(entryPoint);
+/// Runs the shader configured on `vkRuntimeManager`.
+void runOnVulkan(void *vkRuntimeManager) {
+ static_cast<VulkanRuntimeManager *>(vkRuntimeManager)->runOnVulkan();
}
-void setNumWorkGroups(uint32_t x, uint32_t y, uint32_t z) {
- vkRuntimeManager->setNumWorkGroups({x, y, z});
+/// Fills the given 1D float memref with the given float value.
+/// The arguments are the expanded memref descriptor: element `i` lives at
+/// `aligned[offset + i * stride]` for `i` in `[0, size)`.
+void fillResource1DFloat(float *allocated, float *aligned, int64_t offset,
+ int64_t size, int64_t stride, float value) {
+ for (int64_t i = 0; i < size; ++i)
+ aligned[offset + i * stride] = value;
+}
+
+/// Sets the shader entry point name on `vkRuntimeManager`.
+void setEntryPoint(void *vkRuntimeManager, const char *entryPoint) {
+ static_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+ ->setEntryPoint(entryPoint);
}
-void setBinaryShader(uint8_t *shader, uint32_t size) {
- vkRuntimeManager->setShaderModule(shader, size);
+/// Sets the workgroup counts on `vkRuntimeManager`.
+void setNumWorkGroups(void *vkRuntimeManager, uint32_t x, uint32_t y,
+ uint32_t z) {
+ static_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+ ->setNumWorkGroups({x, y, z});
}
-void runOnVulkan() { vkRuntimeManager->runOnVulkan(); }
+/// Sets the shader module binary on `vkRuntimeManager`.
+void setBinaryShader(void *vkRuntimeManager, uint8_t *shader, uint32_t size) {
+ static_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+ ->setShaderModule(shader, size);
+}
}
More information about the Mlir-commits
mailing list