[Mlir-commits] [mlir] 8f4ab8c - [mlir][vulkan-runner] Add support for 2D memref.

Denis Khalikov llvmlistbot at llvm.org
Fri Mar 27 04:01:37 PDT 2020


Author: Denis Khalikov
Date: 2020-03-27T13:59:17+03:00
New Revision: 8f4ab8c7d7f51bc51fa3edebd508481eb27efbf1

URL: https://github.com/llvm/llvm-project/commit/8f4ab8c7d7f51bc51fa3edebd508481eb27efbf1
DIFF: https://github.com/llvm/llvm-project/commit/8f4ab8c7d7f51bc51fa3edebd508481eb27efbf1.diff

LOG: [mlir][vulkan-runner] Add support for 2D memref.

Summary:
This patch adds support for 2D memrefs in mlir-vulkan-runner.

Differential Revision: https://reviews.llvm.org/D76737

Added: 
    mlir/test/mlir-vulkan-runner/mulf.mlir

Modified: 
    mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
    mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
    mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
index 26833fd4daa7..cccd53f45992 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
@@ -55,7 +55,10 @@ class ConvertGpuLaunchFuncToVulkanLaunchFunc
   bool isSupportedType(Type type) {
     // TODO(denis0x0D): Handle other types.
     if (auto memRefType = type.dyn_cast_or_null<MemRefType>())
-      return memRefType.hasRank() && memRefType.getRank() == 1;
+      // Only f32 memrefs of rank 1 or 2 have matching runtime wrappers
+      // (`bindMemRef1DFloat` / `bindMemRef2DFloat`).
+      return memRefType.hasRank() && memRefType.getElementType().isF32() &&
+             (memRefType.getRank() == 1 || memRefType.getRank() == 2);
     return false;
   }
 
@@ -98,7 +99,9 @@ LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc(
 
   // Check that all operands have supported types except those for the launch
   // configuration.
-  for (auto type : llvm::drop_begin(vulkanLaunchTypes, 6)) {
+  auto kernelOperandTypes =
+      llvm::drop_begin(vulkanLaunchTypes, gpu::LaunchOp::kNumConfigOperands);
+  for (auto type : kernelOperandTypes) {
     if (!isSupportedType(type))
       return launchOp.emitError() << type << " is unsupported to run on Vulkan";
   }

diff  --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index f1dc52e5f856..d03adc2c64ac 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -24,10 +24,14 @@
 #include "mlir/Pass/Pass.h"
 
 #include "llvm/ADT/SmallString.h"
+#include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;
 
+// Names of the runtime wrappers (see vulkan-runtime-wrappers.cpp) that bind a
+// float memref of the given rank to a descriptor set and binding index.
 static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat";
+static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat";
 static constexpr const char *kCInterfaceVulkanLaunch =
     "_mlir_ciface_vulkanLaunch";
 static constexpr const char *kDeinitVulkan = "deinitVulkan";
@@ -87,12 +89,22 @@ class VulkanLaunchFuncToVulkanCallsPass
     auto llvmPtrToFloatType = getFloatType().getPointerTo();
     auto llvmArrayOneElementSizeType =
         LLVM::LLVMType::getArrayTy(getInt64Type(), 1);
+    auto llvmArrayTwoElementSizeType =
+        LLVM::LLVMType::getArrayTy(getInt64Type(), 2);
 
     // Create a type `!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64]}">`.
     llvmMemRef1DFloat = LLVM::LLVMType::getStructTy(
         llvmDialect,
         {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(),
          llvmArrayOneElementSizeType, llvmArrayOneElementSizeType});
+
+    // Create a type `!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64]}">`.
+    // Fields: allocated pointer, aligned pointer, offset, sizes[2], strides[2].
+    // A pointer to this struct is the last argument of `bindMemRef2DFloat`.
+    llvmMemRef2DFloat = LLVM::LLVMType::getStructTy(
+        llvmDialect,
+        {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(),
+         llvmArrayTwoElementSizeType, llvmArrayTwoElementSizeType});
   }
 
   LLVM::LLVMType getFloatType() { return llvmFloatType; }
@@ -101,6 +111,8 @@ class VulkanLaunchFuncToVulkanCallsPass
   LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
   LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
   LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; }
+  /// Returns the LLVM struct type of a 2D float memref descriptor.
+  LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; }
 
   /// Creates a LLVM global for the given `name`.
   Value createEntryPointNameConstant(StringRef name, Location loc,
@@ -134,6 +145,10 @@ class VulkanLaunchFuncToVulkanCallsPass
   /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
   void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
 
+  /// Deduces a rank from the given `ptrToMemRefDescriptor`, a pointer to an
+  /// LLVM memref descriptor struct. On success, the rank is written to `rank`.
+  LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank);
+
 public:
   void runOnModule() override;
 
@@ -145,6 +159,8 @@ class VulkanLaunchFuncToVulkanCallsPass
   LLVM::LLVMType llvmInt32Type;
   LLVM::LLVMType llvmInt64Type;
   LLVM::LLVMType llvmMemRef1DFloat;
+  // LLVM struct type of a 2D float memref descriptor; see `getMemRef2DFloat`.
+  LLVM::LLVMType llvmMemRef2DFloat;
 
   // TODO: Use an associative array to support multiple vulkan launch calls.
   std::pair<StringAttr, StringAttr> spirvAttributes;
@@ -212,16 +227,76 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
     // Create LLVM constant for the descriptor binding index.
     Value descriptorBinding = builder.create<LLVM::ConstantOp>(
         loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
+
+    auto ptrToMemRefDescriptor = en.value();
+    uint32_t rank = 0;
+    if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) {
+      cInterfaceVulkanLaunchCallOp.emitError()
+          << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
+      return signalPassFailure();
+    }
+
+    // Only `bindMemRef1DFloat` and `bindMemRef2DFloat` exist in the runtime
+    // wrappers; reject other ranks instead of emitting a call to an undeclared
+    // symbol.
+    StringRef symbolName;
+    if (rank == 1) {
+      symbolName = kBindMemRef1DFloat;
+    } else if (rank == 2) {
+      symbolName = kBindMemRef2DFloat;
+    } else {
+      cInterfaceVulkanLaunchCallOp.emitError()
+          << "unsupported memref rank " << rank << " for Vulkan binding";
+      return signalPassFailure();
+    }
+
     // Create call to `bindMemRef`.
     builder.create<LLVM::CallOp>(
         loc, ArrayRef<Type>{getVoidType()},
-        // TODO: Add support for memref with other ranks.
-        builder.getSymbolRefAttr(kBindMemRef1DFloat),
+        builder.getSymbolRefAttr(symbolName),
         ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding,
-                        en.value()});
+                        ptrToMemRefDescriptor});
   }
 }
 
+LogicalResult
+VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor,
+                                                    uint32_t &rank) {
+  auto llvmPtrDescriptorTy =
+      ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>();
+  // `getPointerElementTy` asserts on non-pointer types, so check first.
+  if (!llvmPtrDescriptorTy || !llvmPtrDescriptorTy.isPointerTy())
+    return failure();
+
+  auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy();
+  // template <typename Elem, size_t Rank>
+  // struct {
+  //   Elem *allocated;
+  //   Elem *aligned;
+  //   int64_t offset;
+  //   int64_t sizes[Rank]; // omitted when rank == 0
+  //   int64_t strides[Rank]; // omitted when rank == 0
+  // };
+  if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy())
+    return failure();
+
+  unsigned numElements = llvmDescriptorTy.getStructNumElements();
+  if (numElements == 3) {
+    rank = 0;
+    return success();
+  }
+  // Anything but the 5-field layout above is not a memref descriptor; also
+  // guard `getArrayNumElements`, which asserts on non-array types.
+  if (numElements != 5)
+    return failure();
+  auto sizesTy = llvmDescriptorTy.getStructElementType(3);
+  if (!sizesTy.isArrayTy())
+    return failure();
+
+  rank = sizesTy.getArrayNumElements();
+  return success();
+}
+
 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
   ModuleOp module = getModule();
   OpBuilder builder(module.getBody()->getTerminator());
@@ -268,6 +321,18 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
                                       /*isVarArg=*/false));
   }
 
+  // Declare `void bindMemRef2DFloat(runtime, descriptorSet, descriptorBinding,
+  // memRef2DFloatDescriptor*)` unless the module already has that symbol.
+  if (!module.lookupSymbol(kBindMemRef2DFloat)) {
+    builder.create<LLVM::LLVMFuncOp>(
+        loc, kBindMemRef2DFloat,
+        LLVM::LLVMType::getFunctionTy(getVoidType(),
+                                      {getPointerType(), getInt32Type(),
+                                       getInt32Type(),
+                                       getMemRef2DFloat().getPointerTo()},
+                                      /*isVarArg=*/false));
+  }
+
   if (!module.lookupSymbol(kInitVulkan)) {
     builder.create<LLVM::LLVMFuncOp>(
         loc, kInitVulkan,

diff  --git a/mlir/test/mlir-vulkan-runner/mulf.mlir b/mlir/test/mlir-vulkan-runner/mulf.mlir
new file mode 100644
index 000000000000..dc962108cbc3
--- /dev/null
+++ b/mlir/test/mlir-vulkan-runner/mulf.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-vulkan-runner %s --shared-libs=%vulkan_wrapper_library_dir/libvulkan-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+// Multiplies two 4x4 f32 memrefs filled with 2.0 and 3.0 element-wise on the
+// GPU (a 4x4 grid of single-invocation blocks) and checks every row is all 6s.
+// CHECK-COUNT-4: [6, 6, 6, 6]
+module attributes {
+  gpu.container_module,
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+  gpu.module @kernels {
+    gpu.func @kernel_mul(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>, %arg2 : memref<4x4xf32>)
+      attributes {gpu.kernel, spv.entry_point_abi = {local_size = dense<[1, 1, 1]>: vector<3xi32>}} {
+      %x = "gpu.block_id"() {dimension = "x"} : () -> index
+      %y = "gpu.block_id"() {dimension = "y"} : () -> index
+      %1 = load %arg0[%x, %y] : memref<4x4xf32>
+      %2 = load %arg1[%x, %y] : memref<4x4xf32>
+      %3 = mulf %1, %2 : f32
+      store %3, %arg2[%x, %y] : memref<4x4xf32>
+      gpu.return
+    }
+  }
+
+  func @main() {
+    %arg0 = alloc() : memref<4x4xf32>
+    %arg1 = alloc() : memref<4x4xf32>
+    %arg2 = alloc() : memref<4x4xf32>
+    %value0 = constant 0.0 : f32
+    %value1 = constant 2.0 : f32
+    %value2 = constant 3.0 : f32
+    %arg3 = memref_cast %arg0 : memref<4x4xf32> to memref<?x?xf32>
+    %arg4 = memref_cast %arg1 : memref<4x4xf32> to memref<?x?xf32>
+    %arg5 = memref_cast %arg2 : memref<4x4xf32> to memref<?x?xf32>
+    call @fillResource2DFloat(%arg3, %value1) : (memref<?x?xf32>, f32) -> ()
+    call @fillResource2DFloat(%arg4, %value2) : (memref<?x?xf32>, f32) -> ()
+    call @fillResource2DFloat(%arg5, %value0) : (memref<?x?xf32>, f32) -> ()
+
+    %cst1 = constant 1 : index
+    %cst4 = constant 4 : index
+    "gpu.launch_func"(%cst4, %cst4, %cst1, %cst1, %cst1, %cst1, %arg0, %arg1, %arg2) { kernel = "kernel_mul", kernel_module = @kernels }
+        : (index, index, index, index, index, index, memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>) -> ()
+    %arg6 = memref_cast %arg5 : memref<?x?xf32> to memref<*xf32>
+    call @print_memref_f32(%arg6) : (memref<*xf32>) -> ()
+    return
+  }
+  func @fillResource2DFloat(%0 : memref<?x?xf32>, %1 : f32)
+  func @print_memref_f32(%ptr : memref<*xf32>)
+}
+

diff  --git a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
index 52c11dad8c7e..7cbd864df4fd 100644
--- a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
@@ -111,9 +111,31 @@ void bindMemRef1DFloat(void *vkRuntimeManager, DescriptorSetIndex setIndex,
       ->setResourceData(setIndex, bindIndex, memBuffer);
 }
 
+/// Binds the given 2D float memref to the given descriptor set and descriptor
+/// index. The bound buffer covers sizes[0] * sizes[1] floats starting at the
+/// memref's allocated pointer.
+void bindMemRef2DFloat(void *vkRuntimeManager, DescriptorSetIndex setIndex,
+                       BindingIndex bindIndex,
+                       MemRefDescriptor<float, 2> *ptr) {
+  const auto numElements = ptr->sizes[0] * ptr->sizes[1];
+  auto *runtimeManager =
+      reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager);
+  VulkanHostMemoryBuffer memBuffer{
+      ptr->allocated, static_cast<uint32_t>(numElements * sizeof(float))};
+  runtimeManager->setResourceData(setIndex, bindIndex, memBuffer);
+}
+
 /// Fills the given 1D float memref with the given float value.
 void _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
                                       float value) {
   std::fill_n(ptr->allocated, ptr->sizes[0], value);
 }
+
+/// Fills the given 2D float memref with `value`, treating it as
+/// sizes[0] * sizes[1] contiguous elements starting at the allocated pointer.
+void _mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT
+                                      float value) {
+  const auto numElements = ptr->sizes[0] * ptr->sizes[1];
+  std::fill_n(ptr->allocated, numElements, value);
+}
 }


        


More information about the Mlir-commits mailing list