[Mlir-commits] [mlir] 1009177 - [mlir][vulkan-runner] Add support for integer types.

Denis Khalikov llvmlistbot at llvm.org
Wed Apr 22 09:43:17 PDT 2020


Author: Denis Khalikov
Date: 2020-04-22T19:42:39+03:00
New Revision: 1009177d498f45b59a3e6490a5e222cafc7993a7

URL: https://github.com/llvm/llvm-project/commit/1009177d498f45b59a3e6490a5e222cafc7993a7
DIFF: https://github.com/llvm/llvm-project/commit/1009177d498f45b59a3e6490a5e222cafc7993a7.diff

LOG: [mlir][vulkan-runner] Add support for integer types.

Summary:
Add support for memrefs whose element type is an integer type,
and add a simple test.

Differential Revision: https://reviews.llvm.org/D78560

Added: 
    mlir/test/mlir-vulkan-runner/addi.mlir

Modified: 
    mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
    mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
    mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
index 26588049b939..d6908680d798 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
@@ -54,10 +54,12 @@ class ConvertGpuLaunchFuncToVulkanLaunchFunc
 
   /// Checks where the given type is supported by Vulkan runtime.
   bool isSupportedType(Type type) {
-    // TODO(denis0x0D): Handle other types.
-    if (auto memRefType = type.dyn_cast_or_null<MemRefType>())
+    if (auto memRefType = type.dyn_cast_or_null<MemRefType>()) {
+      auto elementType = memRefType.getElementType();
       return memRefType.hasRank() &&
-             (memRefType.getRank() >= 1 && memRefType.getRank() <= 3);
+             (memRefType.getRank() >= 1 && memRefType.getRank() <= 3) &&
+             (elementType.isIntOrFloat());
+    }
     return false;
   }
 

diff  --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index 4a481bf959da..bc13d177a62e 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -30,6 +30,9 @@ using namespace mlir;
 static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat";
 static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat";
 static constexpr const char *kBindMemRef3DFloat = "bindMemRef3DFloat";
+static constexpr const char *kBindMemRef1DInt = "bindMemRef1DInt";
+static constexpr const char *kBindMemRef2DInt = "bindMemRef2DInt";
+static constexpr const char *kBindMemRef3DInt = "bindMemRef3DInt";
 static constexpr const char *kCInterfaceVulkanLaunch =
     "_mlir_ciface_vulkanLaunch";
 static constexpr const char *kDeinitVulkan = "deinitVulkan";
@@ -73,12 +76,15 @@ class VulkanLaunchFuncToVulkanCallsPass
     llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
     llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
     llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
-    llvmMemRef1DFloat = getMemRefType(1);
-    llvmMemRef2DFloat = getMemRefType(2);
-    llvmMemRef3DFloat = getMemRefType(3);
+    llvmMemRef1DFloat = getMemRefType(1, llvmFloatType);
+    llvmMemRef2DFloat = getMemRefType(2, llvmFloatType);
+    llvmMemRef3DFloat = getMemRefType(3, llvmFloatType);
+    llvmMemRef1DInt = getMemRefType(1, llvmInt32Type);
+    llvmMemRef2DInt = getMemRefType(2, llvmInt32Type);
+    llvmMemRef3DInt = getMemRefType(3, llvmInt32Type);
   }
 
-  LLVM::LLVMType getMemRefType(uint32_t rank) {
+  LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) {
     // According to the MLIR doc memref argument is converted into a
     // pointer-to-struct argument of type:
     // template <typename Elem, size_t Rank>
@@ -89,15 +95,16 @@ class VulkanLaunchFuncToVulkanCallsPass
     //   int64_t sizes[Rank]; // omitted when rank == 0
     //   int64_t strides[Rank]; // omitted when rank == 0
     // };
-    auto llvmPtrToFloatType = getFloatType().getPointerTo();
+    auto llvmPtrToElementType = elemenType.getPointerTo();
     auto llvmArrayRankElementSizeType =
         LLVM::LLVMType::getArrayTy(getInt64Type(), rank);
 
     // Create a type
-    // `!llvm<"{ float*, float*, i64, [`rank` x i64], [`rank` x i64]}">`.
+    // `!llvm<"{ `element-type`*, `element-type`*, i64,
+    // [`rank` x i64], [`rank` x i64]}">`.
     return LLVM::LLVMType::getStructTy(
         llvmDialect,
-        {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(),
+        {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
          llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
   }
 
@@ -109,6 +116,9 @@ class VulkanLaunchFuncToVulkanCallsPass
   LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; }
   LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; }
   LLVM::LLVMType getMemRef3DFloat() { return llvmMemRef3DFloat; }
+  LLVM::LLVMType getMemRef1DInt() { return llvmMemRef1DInt; }
+  LLVM::LLVMType getMemRef2DInt() { return llvmMemRef2DInt; }
+  LLVM::LLVMType getMemRef3DInt() { return llvmMemRef3DInt; }
 
   /// Creates a LLVM global for the given `name`.
   Value createEntryPointNameConstant(StringRef name, Location loc,
@@ -142,8 +152,19 @@ class VulkanLaunchFuncToVulkanCallsPass
   /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
   void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
 
-  /// Deduces a rank from the given 'ptrToMemRefDescriptor`.
-  LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank);
+  /// Deduces the rank and element type from the given `ptrToMemRefDescriptor`.
+  LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
+                                        uint32_t &rank, LLVM::LLVMType &type);
+
+  /// Returns the string representation of the given `type`.
+  StringRef stringifyType(LLVM::LLVMType type) {
+    if (type.isFloatTy())
+      return "Float";
+    if (type.isIntegerTy())
+      return "Int";
+
+    llvm_unreachable("unsupported type");
+  }
 
 public:
   void runOnOperation() override;
@@ -158,6 +179,9 @@ class VulkanLaunchFuncToVulkanCallsPass
   LLVM::LLVMType llvmMemRef1DFloat;
   LLVM::LLVMType llvmMemRef2DFloat;
   LLVM::LLVMType llvmMemRef3DFloat;
+  LLVM::LLVMType llvmMemRef1DInt;
+  LLVM::LLVMType llvmMemRef2DInt;
+  LLVM::LLVMType llvmMemRef3DInt;
 
   // TODO: Use an associative array to support multiple vulkan launch calls.
   std::pair<StringAttr, StringAttr> spirvAttributes;
@@ -231,13 +255,15 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
 
     auto ptrToMemRefDescriptor = en.value();
     uint32_t rank = 0;
-    if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) {
+    LLVM::LLVMType type;
+    if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
       cInterfaceVulkanLaunchCallOp.emitError()
           << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
       return signalPassFailure();
     }
 
-    auto symbolName = llvm::formatv("bindMemRef{0}DFloat", rank).str();
+    auto symbolName =
+        llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
     // Create call to `bindMemRef`.
     builder.create<LLVM::CallOp>(
         loc, ArrayRef<Type>{getVoidType()},
@@ -248,9 +274,8 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
   }
 }
 
-LogicalResult
-VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor,
-                                                    uint32_t &rank) {
+LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
+    Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) {
   auto llvmPtrDescriptorTy =
       ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>();
   if (!llvmPtrDescriptorTy)
@@ -267,11 +292,12 @@ VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor,
   // };
   if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy())
     return failure();
+
+  type = llvmDescriptorTy.getStructElementType(0).getPointerElementTy();
   if (llvmDescriptorTy.getStructNumElements() == 3) {
     rank = 0;
     return success();
   }
-
   rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements();
   return success();
 }
@@ -312,35 +338,23 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
                                       /*isVarArg=*/false));
   }
 
-  if (!module.lookupSymbol(kBindMemRef1DFloat)) {
-    builder.create<LLVM::LLVMFuncOp>(
-        loc, kBindMemRef1DFloat,
-        LLVM::LLVMType::getFunctionTy(getVoidType(),
-                                      {getPointerType(), getInt32Type(),
-                                       getInt32Type(),
-                                       getMemRef1DFloat().getPointerTo()},
-                                      /*isVarArg=*/false));
+#define CREATE_VULKAN_BIND_FUNC(MemRefType)                                    \
+  if (!module.lookupSymbol(kBind##MemRefType)) {                               \
+    builder.create<LLVM::LLVMFuncOp>(                                          \
+        loc, kBind##MemRefType,                                                \
+        LLVM::LLVMType::getFunctionTy(getVoidType(),                           \
+                                      {getPointerType(), getInt32Type(),       \
+                                       getInt32Type(),                         \
+                                       get##MemRefType().getPointerTo()},      \
+                                      /*isVarArg=*/false));                    \
   }
 
-  if (!module.lookupSymbol(kBindMemRef2DFloat)) {
-    builder.create<LLVM::LLVMFuncOp>(
-        loc, kBindMemRef2DFloat,
-        LLVM::LLVMType::getFunctionTy(getVoidType(),
-                                      {getPointerType(), getInt32Type(),
-                                       getInt32Type(),
-                                       getMemRef2DFloat().getPointerTo()},
-                                      /*isVarArg=*/false));
-  }
-
-  if (!module.lookupSymbol(kBindMemRef3DFloat)) {
-    builder.create<LLVM::LLVMFuncOp>(
-        loc, kBindMemRef3DFloat,
-        LLVM::LLVMType::getFunctionTy(getVoidType(),
-                                      {getPointerType(), getInt32Type(),
-                                       getInt32Type(),
-                                       getMemRef3DFloat().getPointerTo()},
-                                      /*isVarArg=*/false));
-  }
+  CREATE_VULKAN_BIND_FUNC(MemRef1DFloat);
+  CREATE_VULKAN_BIND_FUNC(MemRef2DFloat);
+  CREATE_VULKAN_BIND_FUNC(MemRef3DFloat);
+  CREATE_VULKAN_BIND_FUNC(MemRef1DInt);
+  CREATE_VULKAN_BIND_FUNC(MemRef2DInt);
+  CREATE_VULKAN_BIND_FUNC(MemRef3DInt);
 
   if (!module.lookupSymbol(kInitVulkan)) {
     builder.create<LLVM::LLVMFuncOp>(

diff  --git a/mlir/test/mlir-vulkan-runner/addi.mlir b/mlir/test/mlir-vulkan-runner/addi.mlir
new file mode 100644
index 000000000000..c690120718b2
--- /dev/null
+++ b/mlir/test/mlir-vulkan-runner/addi.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-vulkan-runner %s --shared-libs=%vulkan_wrapper_library_dir/libvulkan-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+// CHECK-COUNT-64: [3, 3, 3, 3, 3, 3, 3, 3]
+module attributes {
+  gpu.container_module,
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+  gpu.module @kernels {
+    gpu.func @kernel_addi(%arg0 : memref<8xi32>, %arg1 : memref<8x8xi32>, %arg2 : memref<8x8x8xi32>)
+      kernel attributes { spv.entry_point_abi = {local_size = dense<[1, 1, 1]>: vector<3xi32>}} {
+      %x = "gpu.block_id"() {dimension = "x"} : () -> index
+      %y = "gpu.block_id"() {dimension = "y"} : () -> index
+      %z = "gpu.block_id"() {dimension = "z"} : () -> index
+      %0 = load %arg0[%x] : memref<8xi32>
+      %1 = load %arg1[%y, %x] : memref<8x8xi32>
+      %2 = addi %0, %1 : i32
+      store %2, %arg2[%z, %y, %x] : memref<8x8x8xi32>
+      gpu.return
+    }
+  }
+
+  func @main() {
+    %arg0 = alloc() : memref<8xi32>
+    %arg1 = alloc() : memref<8x8xi32>
+    %arg2 = alloc() : memref<8x8x8xi32>
+    %value0 = constant 0 : i32
+    %value1 = constant 1 : i32
+    %value2 = constant 2 : i32
+    %arg3 = memref_cast %arg0 : memref<8xi32> to memref<?xi32>
+    %arg4 = memref_cast %arg1 : memref<8x8xi32> to memref<?x?xi32>
+    %arg5 = memref_cast %arg2 : memref<8x8x8xi32> to memref<?x?x?xi32>
+    call @fillResource1DInt(%arg3, %value1) : (memref<?xi32>, i32) -> ()
+    call @fillResource2DInt(%arg4, %value2) : (memref<?x?xi32>, i32) -> ()
+    call @fillResource3DInt(%arg5, %value0) : (memref<?x?x?xi32>, i32) -> ()
+
+    %cst1 = constant 1 : index
+    %cst8 = constant 8 : index
+    "gpu.launch_func"(%cst8, %cst8, %cst8, %cst1, %cst1, %cst1, %arg0, %arg1, %arg2) { kernel = @kernels::@kernel_addi }
+        : (index, index, index, index, index, index, memref<8xi32>, memref<8x8xi32>, memref<8x8x8xi32>) -> ()
+    %arg6 = memref_cast %arg5 : memref<?x?x?xi32> to memref<*xi32>
+    call @print_memref_i32(%arg6) : (memref<*xi32>) -> ()
+    return
+  }
+  func @fillResource1DInt(%0 : memref<?xi32>, %1 : i32)
+  func @fillResource2DInt(%0 : memref<?x?xi32>, %1 : i32)
+  func @fillResource3DInt(%0 : memref<?x?x?xi32>, %1 : i32)
+  func @print_memref_i32(%ptr : memref<*xi32>)
+}
+

diff  --git a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
index 4c428ef0349a..b1848de00690 100644
--- a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
@@ -135,6 +135,41 @@ void bindMemRef3DFloat(void *vkRuntimeManager, DescriptorSetIndex setIndex,
       ->setResourceData(setIndex, bindIndex, memBuffer);
 }
 
+/// Binds the given 1D int memref to the given descriptor set and descriptor
+/// index.
+void bindMemRef1DInt(void *vkRuntimeManager, DescriptorSetIndex setIndex,
+                     BindingIndex bindIndex,
+                     MemRefDescriptor<int32_t, 1> *ptr) {
+  VulkanHostMemoryBuffer memBuffer{
+      ptr->allocated, static_cast<uint32_t>(ptr->sizes[0] * sizeof(int32_t))};
+  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+      ->setResourceData(setIndex, bindIndex, memBuffer);
+}
+
+/// Binds the given 2D int memref to the given descriptor set and descriptor
+/// index.
+void bindMemRef2DInt(void *vkRuntimeManager, DescriptorSetIndex setIndex,
+                     BindingIndex bindIndex,
+                     MemRefDescriptor<int32_t, 2> *ptr) {
+  VulkanHostMemoryBuffer memBuffer{
+      ptr->allocated,
+      static_cast<uint32_t>(ptr->sizes[0] * ptr->sizes[1] * sizeof(int32_t))};
+  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+      ->setResourceData(setIndex, bindIndex, memBuffer);
+}
+
+/// Binds the given 3D int memref to the given descriptor set and descriptor
+/// index.
+void bindMemRef3DInt(void *vkRuntimeManager, DescriptorSetIndex setIndex,
+                     BindingIndex bindIndex,
+                     MemRefDescriptor<int32_t, 3> *ptr) {
+  VulkanHostMemoryBuffer memBuffer{
+      ptr->allocated, static_cast<uint32_t>(ptr->sizes[0] * ptr->sizes[1] *
+                                            ptr->sizes[2] * sizeof(int32_t))};
+  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+      ->setResourceData(setIndex, bindIndex, memBuffer);
+}
+
 /// Fills the given 1D float memref with the given float value.
 void _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
                                       float value) {
@@ -153,4 +188,23 @@ void _mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT
   std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
               value);
 }
+
+/// Fills the given 1D int memref with the given int value.
+void _mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT
+                                    int32_t value) {
+  std::fill_n(ptr->allocated, ptr->sizes[0], value);
+}
+
+/// Fills the given 2D int memref with the given int value.
+void _mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT
+                                    int32_t value) {
+  std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
+}
+
+/// Fills the given 3D int memref with the given int value.
+void _mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT
+                                    int32_t value) {
+  std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
+              value);
+}
 }


        


More information about the Mlir-commits mailing list