[Mlir-commits] [mlir] afd43a7 - [mlir][vulkan-runner] add support for memref of i8, i16 types in vulkan runner

Thomas Raoux llvmlistbot at llvm.org
Thu Jun 18 13:26:06 PDT 2020


Author: Thomas Raoux
Date: 2020-06-18T13:24:51-07:00
New Revision: afd43a7a7878ea448079feab4ec922109c0eb6cf

URL: https://github.com/llvm/llvm-project/commit/afd43a7a7878ea448079feab4ec922109c0eb6cf
DIFF: https://github.com/llvm/llvm-project/commit/afd43a7a7878ea448079feab4ec922109c0eb6cf.diff

LOG: [mlir][vulkan-runner] add support for memref of i8, i16 types in vulkan runner

This extends the set of types supported as kernel arguments when using the Vulkan runner.

Differential Revision: https://reviews.llvm.org/D82068

Added: 
    mlir/test/mlir-vulkan-runner/addi8.mlir

Modified: 
    mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
    mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index bc13d177a62e..10394b795daa 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -27,12 +27,6 @@
 
 using namespace mlir;
 
-static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat";
-static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat";
-static constexpr const char *kBindMemRef3DFloat = "bindMemRef3DFloat";
-static constexpr const char *kBindMemRef1DInt = "bindMemRef1DInt";
-static constexpr const char *kBindMemRef2DInt = "bindMemRef2DInt";
-static constexpr const char *kBindMemRef3DInt = "bindMemRef3DInt";
 static constexpr const char *kCInterfaceVulkanLaunch =
     "_mlir_ciface_vulkanLaunch";
 static constexpr const char *kDeinitVulkan = "deinitVulkan";
@@ -76,12 +70,6 @@ class VulkanLaunchFuncToVulkanCallsPass
     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
     llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
-    llvmMemRef1DFloat = getMemRefType(1, llvmFloatType);
-    llvmMemRef2DFloat = getMemRefType(2, llvmFloatType);
-    llvmMemRef3DFloat = getMemRefType(3, llvmFloatType);
-    llvmMemRef1DInt = getMemRefType(1, llvmInt32Type);
-    llvmMemRef2DInt = getMemRefType(2, llvmInt32Type);
-    llvmMemRef3DInt = getMemRefType(3, llvmInt32Type);
   }
 
   LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) {
@@ -108,17 +96,10 @@ class VulkanLaunchFuncToVulkanCallsPass
          llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
   }
 
-  LLVM::LLVMType getFloatType() { return llvmFloatType; }
   LLVM::LLVMType getVoidType() { return llvmVoidType; }
   LLVM::LLVMType getPointerType() { return llvmPointerType; }
   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
   LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
-  LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; }
-  LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; }
-  LLVM::LLVMType getMemRef3DFloat() { return llvmMemRef3DFloat; }
-  LLVM::LLVMType getMemRef1DInt() { return llvmMemRef1DInt; }
-  LLVM::LLVMType getMemRef2DInt() { return llvmMemRef2DInt; }
-  LLVM::LLVMType getMemRef3DInt() { return llvmMemRef3DInt; }
 
   /// Creates a LLVM global for the given `name`.
   Value createEntryPointNameConstant(StringRef name, Location loc,
@@ -160,8 +141,14 @@ class VulkanLaunchFuncToVulkanCallsPass
   StringRef stringifyType(LLVM::LLVMType type) {
     if (type.isFloatTy())
       return "Float";
-    if (type.isIntegerTy())
-      return "Int";
+    if (type.isHalfTy())
+      return "Half";
+    if (type.isIntegerTy(32))
+      return "Int32";
+    if (type.isIntegerTy(16))
+      return "Int16";
+    if (type.isIntegerTy(8))
+      return "Int8";
 
     llvm_unreachable("unsupported type");
   }
@@ -176,12 +163,6 @@ class VulkanLaunchFuncToVulkanCallsPass
   LLVM::LLVMType llvmPointerType;
   LLVM::LLVMType llvmInt32Type;
   LLVM::LLVMType llvmInt64Type;
-  LLVM::LLVMType llvmMemRef1DFloat;
-  LLVM::LLVMType llvmMemRef2DFloat;
-  LLVM::LLVMType llvmMemRef3DFloat;
-  LLVM::LLVMType llvmMemRef1DInt;
-  LLVM::LLVMType llvmMemRef2DInt;
-  LLVM::LLVMType llvmMemRef3DInt;
 
   // TODO: Use an associative array to support multiple vulkan launch calls.
   std::pair<StringAttr, StringAttr> spirvAttributes;
@@ -264,6 +245,14 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
 
     auto symbolName =
         llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
+    // Special case for fp16 type. Since it is not a supported type in C we use
+    // int16_t and bitcast the descriptor.
+    if (type.isHalfTy()) {
+      auto memRefTy =
+          getMemRefType(rank, LLVM::LLVMType::getInt16Ty(llvmDialect));
+      ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
+          loc, memRefTy.getPointerTo(), ptrToMemRefDescriptor);
+    }
     // Create call to `bindMemRef`.
     builder.create<LLVM::CallOp>(
         loc, ArrayRef<Type>{getVoidType()},
@@ -338,24 +327,27 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
                                       /*isVarArg=*/false));
   }
 
-#define CREATE_VULKAN_BIND_FUNC(MemRefType)                                    \
-  if (!module.lookupSymbol(kBind##MemRefType)) {                               \
-    builder.create<LLVM::LLVMFuncOp>(                                          \
-        loc, kBind##MemRefType,                                                \
-        LLVM::LLVMType::getFunctionTy(getVoidType(),                           \
-                                      {getPointerType(), getInt32Type(),       \
-                                       getInt32Type(),                         \
-                                       get##MemRefType().getPointerTo()},      \
-                                      /*isVarArg=*/false));                    \
+  for (unsigned i = 1; i <= 3; i++) {
+    for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(llvmDialect),
+                                LLVM::LLVMType::getInt32Ty(llvmDialect),
+                                LLVM::LLVMType::getInt16Ty(llvmDialect),
+                                LLVM::LLVMType::getInt8Ty(llvmDialect),
+                                LLVM::LLVMType::getHalfTy(llvmDialect)}) {
+      std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
+                           std::string(stringifyType(type));
+      if (type.isHalfTy())
+        type = getMemRefType(i, LLVM::LLVMType::getInt16Ty(llvmDialect));
+      if (!module.lookupSymbol(fnName)) {
+        auto fnType = LLVM::LLVMType::getFunctionTy(
+            getVoidType(),
+            {getPointerType(), getInt32Type(), getInt32Type(),
+             getMemRefType(i, type).getPointerTo()},
+            /*isVarArg=*/false);
+        builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
+      }
+    }
   }
 
-  CREATE_VULKAN_BIND_FUNC(MemRef1DFloat);
-  CREATE_VULKAN_BIND_FUNC(MemRef2DFloat);
-  CREATE_VULKAN_BIND_FUNC(MemRef3DFloat);
-  CREATE_VULKAN_BIND_FUNC(MemRef1DInt);
-  CREATE_VULKAN_BIND_FUNC(MemRef2DInt);
-  CREATE_VULKAN_BIND_FUNC(MemRef3DInt);
-
   if (!module.lookupSymbol(kInitVulkan)) {
     builder.create<LLVM::LLVMFuncOp>(
         loc, kInitVulkan,

diff  --git a/mlir/test/mlir-vulkan-runner/addi8.mlir b/mlir/test/mlir-vulkan-runner/addi8.mlir
new file mode 100644
index 000000000000..094186d5731d
--- /dev/null
+++ b/mlir/test/mlir-vulkan-runner/addi8.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-vulkan-runner %s --shared-libs=%vulkan_wrapper_library_dir/libvulkan-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+// CHECK-COUNT-64: [3, 3, 3, 3, 3, 3, 3, 3]
+module attributes {
+  gpu.container_module,
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_8bit_storage]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+  gpu.module @kernels {
+    gpu.func @kernel_addi(%arg0 : memref<8xi8>, %arg1 : memref<8x8xi8>, %arg2 : memref<8x8x8xi32>)
+      kernel attributes { spv.entry_point_abi = {local_size = dense<[1, 1, 1]>: vector<3xi32>}} {
+      %x = "gpu.block_id"() {dimension = "x"} : () -> index
+      %y = "gpu.block_id"() {dimension = "y"} : () -> index
+      %z = "gpu.block_id"() {dimension = "z"} : () -> index
+      %0 = load %arg0[%x] : memref<8xi8>
+      %1 = load %arg1[%y, %x] : memref<8x8xi8>
+      %2 = addi %0, %1 : i8
+      %3 = zexti %2 : i8 to i32
+      store %3, %arg2[%z, %y, %x] : memref<8x8x8xi32>
+      gpu.return
+    }
+  }
+
+  func @main() {
+    %arg0 = alloc() : memref<8xi8>
+    %arg1 = alloc() : memref<8x8xi8>
+    %arg2 = alloc() : memref<8x8x8xi32>
+    %value0 = constant 0 : i32
+    %value1 = constant 1 : i8
+    %value2 = constant 2 : i8
+    %arg3 = memref_cast %arg0 : memref<8xi8> to memref<?xi8>
+    %arg4 = memref_cast %arg1 : memref<8x8xi8> to memref<?x?xi8>
+    %arg5 = memref_cast %arg2 : memref<8x8x8xi32> to memref<?x?x?xi32>
+    call @fillResource1DInt8(%arg3, %value1) : (memref<?xi8>, i8) -> ()
+    call @fillResource2DInt8(%arg4, %value2) : (memref<?x?xi8>, i8) -> ()
+    call @fillResource3DInt(%arg5, %value0) : (memref<?x?x?xi32>, i32) -> ()
+
+    %cst1 = constant 1 : index
+    %cst8 = constant 8 : index
+    "gpu.launch_func"(%cst8, %cst8, %cst8, %cst1, %cst1, %cst1, %arg0, %arg1, %arg2) { kernel = @kernels::@kernel_addi }
+        : (index, index, index, index, index, index, memref<8xi8>, memref<8x8xi8>, memref<8x8x8xi32>) -> ()
+    %arg6 = memref_cast %arg5 : memref<?x?x?xi32> to memref<*xi32>
+    call @print_memref_i32(%arg6) : (memref<*xi32>) -> ()
+    return
+  }
+  func @fillResource1DInt8(%0 : memref<?xi8>, %1 : i8)
+  func @fillResource2DInt8(%0 : memref<?x?xi8>, %1 : i8)
+  func @fillResource3DInt(%0 : memref<?x?x?xi32>, %1 : i32)
+  func @print_memref_i32(%ptr : memref<*xi32>)
+}

diff  --git a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
index 11b5203c7b18..5742750e13c2 100644
--- a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
@@ -71,6 +71,17 @@ struct MemRefDescriptor {
   int64_t strides[N];
 };
 
+template <typename T, uint32_t S>
+void bindMemRef(void *vkRuntimeManager, DescriptorSetIndex setIndex,
+                BindingIndex bindIndex, MemRefDescriptor<T, S> *ptr) {
+  uint32_t size = sizeof(T);
+  for (unsigned i = 0; i < S; i++)
+    size *= ptr->sizes[i];
+  VulkanHostMemoryBuffer memBuffer{ptr->allocated, size};
+  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+      ->setResourceData(setIndex, bindIndex, memBuffer);
+}
+
 extern "C" {
 /// Initializes `VulkanRuntimeManager` and returns a pointer to it.
 void *initVulkan() { return new VulkanRuntimeManager(); }
@@ -100,75 +111,30 @@ void setBinaryShader(void *vkRuntimeManager, uint8_t *shader, uint32_t size) {
       ->setShaderModule(shader, size);
 }
 
-/// Binds the given 1D float memref to the given descriptor set and descriptor
-/// index.
-void bindMemRef1DFloat(void *vkRuntimeManager, DescriptorSetIndex setIndex,
-                       BindingIndex bindIndex,
-                       MemRefDescriptor<float, 1> *ptr) {
-  VulkanHostMemoryBuffer memBuffer{
-      ptr->allocated, static_cast<uint32_t>(ptr->sizes[0] * sizeof(float))};
-  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
-      ->setResourceData(setIndex, bindIndex, memBuffer);
-}
-
-/// Binds the given 2D float memref to the given descriptor set and descriptor
-/// index.
-void bindMemRef2DFloat(void *vkRuntimeManager, DescriptorSetIndex setIndex,
-                       BindingIndex bindIndex,
-                       MemRefDescriptor<float, 2> *ptr) {
-  VulkanHostMemoryBuffer memBuffer{
-      ptr->allocated,
-      static_cast<uint32_t>(ptr->sizes[0] * ptr->sizes[1] * sizeof(float))};
-  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
-      ->setResourceData(setIndex, bindIndex, memBuffer);
-}
-
-/// Binds the given 3D float memref to the given descriptor set and descriptor
-/// index.
-void bindMemRef3DFloat(void *vkRuntimeManager, DescriptorSetIndex setIndex,
-                       BindingIndex bindIndex,
-                       MemRefDescriptor<float, 3> *ptr) {
-  VulkanHostMemoryBuffer memBuffer{
-      ptr->allocated, static_cast<uint32_t>(ptr->sizes[0] * ptr->sizes[1] *
-                                            ptr->sizes[2] * sizeof(float))};
-  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
-      ->setResourceData(setIndex, bindIndex, memBuffer);
-}
-
-/// Binds the given 1D int memref to the given descriptor set and descriptor
-/// index.
-void bindMemRef1DInt(void *vkRuntimeManager, DescriptorSetIndex setIndex,
-                     BindingIndex bindIndex,
-                     MemRefDescriptor<int32_t, 1> *ptr) {
-  VulkanHostMemoryBuffer memBuffer{
-      ptr->allocated, static_cast<uint32_t>(ptr->sizes[0] * sizeof(int32_t))};
-  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
-      ->setResourceData(setIndex, bindIndex, memBuffer);
-}
-
-/// Binds the given 2D int memref to the given descriptor set and descriptor
+/// Binds the given memref to the given descriptor set and descriptor
 /// index.
-void bindMemRef2DInt(void *vkRuntimeManager, DescriptorSetIndex setIndex,
-                     BindingIndex bindIndex,
-                     MemRefDescriptor<int32_t, 2> *ptr) {
-  VulkanHostMemoryBuffer memBuffer{
-      ptr->allocated,
-      static_cast<uint32_t>(ptr->sizes[0] * ptr->sizes[1] * sizeof(int32_t))};
-  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
-      ->setResourceData(setIndex, bindIndex, memBuffer);
-}
+#define DECLARE_BIND_MEMREF(size, type, typeName)                              \
+  void bindMemRef##size##D##typeName(                                          \
+      void *vkRuntimeManager, DescriptorSetIndex setIndex,                     \
+      BindingIndex bindIndex, MemRefDescriptor<type, size> *ptr) {             \
+    bindMemRef<type, size>(vkRuntimeManager, setIndex, bindIndex, ptr);        \
+  }
 
-/// Binds the given 3D int memref to the given descriptor set and descriptor
-/// index.
-void bindMemRef3DInt(void *vkRuntimeManager, DescriptorSetIndex setIndex,
-                     BindingIndex bindIndex,
-                     MemRefDescriptor<int32_t, 3> *ptr) {
-  VulkanHostMemoryBuffer memBuffer{
-      ptr->allocated, static_cast<uint32_t>(ptr->sizes[0] * ptr->sizes[1] *
-                                            ptr->sizes[2] * sizeof(int32_t))};
-  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
-      ->setResourceData(setIndex, bindIndex, memBuffer);
-}
+DECLARE_BIND_MEMREF(1, float, Float)
+DECLARE_BIND_MEMREF(2, float, Float)
+DECLARE_BIND_MEMREF(3, float, Float)
+DECLARE_BIND_MEMREF(1, int32_t, Int32)
+DECLARE_BIND_MEMREF(2, int32_t, Int32)
+DECLARE_BIND_MEMREF(3, int32_t, Int32)
+DECLARE_BIND_MEMREF(1, int16_t, Int16)
+DECLARE_BIND_MEMREF(2, int16_t, Int16)
+DECLARE_BIND_MEMREF(3, int16_t, Int16)
+DECLARE_BIND_MEMREF(1, int8_t, Int8)
+DECLARE_BIND_MEMREF(2, int8_t, Int8)
+DECLARE_BIND_MEMREF(3, int8_t, Int8)
+DECLARE_BIND_MEMREF(1, int16_t, Half)
+DECLARE_BIND_MEMREF(2, int16_t, Half)
+DECLARE_BIND_MEMREF(3, int16_t, Half)
 
 /// Fills the given 1D float memref with the given float value.
 void _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
@@ -207,4 +173,23 @@ void _mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT
   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
               value);
 }
+
+/// Fills the given 1D int memref with the given int8 value.
+void _mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t, 1> *ptr, // NOLINT
+                                     int8_t value) {
+  std::fill_n(ptr->allocated, ptr->sizes[0], value);
+}
+
+/// Fills the given 2D int memref with the given int8 value.
+void _mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t, 2> *ptr, // NOLINT
+                                     int8_t value) {
+  std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
+}
+
+/// Fills the given 3D int memref with the given int8 value.
+void _mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t, 3> *ptr, // NOLINT
+                                     int8_t value) {
+  std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
+              value);
+}
 }


        


More information about the Mlir-commits mailing list