[Mlir-commits] [mlir] 91be358 - [mlir][GPUToVulkan] Port conversion passes and `mlir-vulkan-runner` to opaque pointers

Markus Böck llvmlistbot at llvm.org
Fri Feb 24 08:16:08 PST 2023


Author: Markus Böck
Date: 2023-02-24T17:04:42+01:00
New Revision: 91be3586b52e8c5fd2891a35524068dc139a7d23

URL: https://github.com/llvm/llvm-project/commit/91be3586b52e8c5fd2891a35524068dc139a7d23
DIFF: https://github.com/llvm/llvm-project/commit/91be3586b52e8c5fd2891a35524068dc139a7d23.diff

LOG: [mlir][GPUToVulkan] Port conversion passes and `mlir-vulkan-runner` to opaque pointers

Part of https://discourse.llvm.org/t/rfc-switching-the-llvm-dialect-and-dialect-lowerings-to-opaque-pointers/68179

This patch adds the new pass option 'use-opaque-pointers' to `-launch-func-to-vulkan` instructing the pass to emit LLVM opaque pointers instead of typed pointers.

Note that the pass, as previously implemented, relied on the fact that LLVM pointers carried an element type. The pass used this information to deduce both the rank and the element type of a "lowered-to-llvm" MemRef. Since the element type is completely erased when using LLVM opaque pointers, it is no longer possible to deduce it.

I therefore added a new attribute that the `-convert-gpu-launch-to-vulkan-launch` pass attaches to the `vulkanLaunch` call, alongside the binary blob and entry point name. It is a type array specifying the element type of each memref. This way, the `-launch-func-to-vulkan` pass can simply read the element types from the attribute.
The rank can still be deduced from the auto-generated C interface from `FinalizeMemRefToLLVM`. This is admittedly a bit fragile, but I was not sure whether it was worth the effort to also add a rank array attribute.

As a last step, opaque pointers were also enabled in the `mlir-vulkan-runner` codegen pipeline, since all conversion passes it uses now fully support them.

Differential Revision: https://reviews.llvm.org/D144460

Added: 
    mlir/test/Conversion/GPUToVulkan/typed-pointers.mlir

Modified: 
    mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
    mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
    mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
    mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir
    mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h b/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h
index 5a528df18e6cb..f69720328f2a4 100644
--- a/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h
+++ b/mlir/include/mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h
@@ -23,14 +23,12 @@ namespace mlir {
 class ModuleOp;
 template <typename T>
 class OperationPass;
+class Pass;
 
-#define GEN_PASS_DECL_CONVERTVULKANLAUNCHFUNCTOVULKANCALLS
+#define GEN_PASS_DECL_CONVERTVULKANLAUNCHFUNCTOVULKANCALLSPASS
 #define GEN_PASS_DECL_CONVERTGPULAUNCHFUNCTOVULKANLAUNCHFUNC
 #include "mlir/Conversion/Passes.h.inc"
 
-std::unique_ptr<OperationPass<ModuleOp>>
-createConvertVulkanLaunchFuncToVulkanCallsPass();
-
 std::unique_ptr<OperationPass<mlir::ModuleOp>>
 createConvertGpuLaunchFuncToVulkanLaunchFuncPass();
 

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b069627298e1d..70ecfba5ea49f 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -492,14 +492,20 @@ def ConvertGpuLaunchFuncToVulkanLaunchFunc
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
-def ConvertVulkanLaunchFuncToVulkanCalls
+def ConvertVulkanLaunchFuncToVulkanCallsPass
     : Pass<"launch-func-to-vulkan", "ModuleOp"> {
   let summary = "Convert vulkanLaunch external call to Vulkan runtime external "
                 "calls";
   let description = [{
     This pass is only intended for the mlir-vulkan-runner.
   }];
-  let constructor = "mlir::createConvertVulkanLaunchFuncToVulkanCallsPass()";
+
+  let options = [
+     Option<"useOpaquePointers", "use-opaque-pointers", "bool",
+            /*default=*/"false", "Generate LLVM IR using opaque pointers "
+            "instead of typed pointers">
+  ];
+
   let dependentDialects = ["LLVM::LLVMDialect"];
 }
 

diff  --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
index b6aff16407495..e4ac64252acc4 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
@@ -35,6 +35,7 @@ using namespace mlir;
 
 static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
 static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
+static constexpr const char *kSPIRVElementTypesAttrName = "spirv_element_types";
 static constexpr const char *kVulkanLaunch = "vulkanLaunch";
 
 namespace {
@@ -189,6 +190,18 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
   vulkanLaunchCallOp->setAttr(kSPIRVEntryPointAttrName,
                               launchOp.getKernelName());
 
+  // Add MemRef element types before they're lost when lowering to LLVM.
+  SmallVector<Type> elementTypes;
+  for (Type type : llvm::drop_begin(launchOp.getOperandTypes(),
+                                    gpu::LaunchOp::kNumConfigOperands)) {
+    // The below cast always succeeds as it has already been verified in
+    // 'declareVulkanLaunchFunc' that these are MemRefs with compatible element
+    // types.
+    elementTypes.push_back(type.cast<MemRefType>().getElementType());
+  }
+  vulkanLaunchCallOp->setAttr(kSPIRVElementTypesAttrName,
+                              builder.getTypeArrayAttr(elementTypes));
+
   launchOp.erase();
 }
 

diff  --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index 9fca86317ed11..78d1f6790c859 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -25,7 +25,7 @@
 #include "llvm/Support/FormatVariadic.h"
 
 namespace mlir {
-#define GEN_PASS_DEF_CONVERTVULKANLAUNCHFUNCTOVULKANCALLS
+#define GEN_PASS_DEF_CONVERTVULKANLAUNCHFUNCTOVULKANCALLSPASS
 #include "mlir/Conversion/Passes.h.inc"
 } // namespace mlir
 
@@ -42,6 +42,7 @@ static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups";
 static constexpr const char *kSPIRVBinary = "SPIRV_BIN";
 static constexpr const char *kSPIRVBlobAttrName = "spirv_blob";
 static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point";
+static constexpr const char *kSPIRVElementTypesAttrName = "spirv_element_types";
 static constexpr const char *kVulkanLaunch = "vulkanLaunch";
 
 namespace {
@@ -58,14 +59,17 @@ namespace {
 /// * deinitVulkan         -- deinitializes vulkan runtime
 ///
 class VulkanLaunchFuncToVulkanCallsPass
-    : public impl::ConvertVulkanLaunchFuncToVulkanCallsBase<
+    : public impl::ConvertVulkanLaunchFuncToVulkanCallsPassBase<
           VulkanLaunchFuncToVulkanCallsPass> {
 private:
   void initializeCachedTypes() {
     llvmFloatType = Float32Type::get(&getContext());
     llvmVoidType = LLVM::LLVMVoidType::get(&getContext());
-    llvmPointerType =
-        LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8));
+    if (useOpaquePointers)
+      llvmPointerType = LLVM::LLVMPointerType::get(&getContext());
+    else
+      llvmPointerType =
+          LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8));
     llvmInt32Type = IntegerType::get(&getContext(), 32);
     llvmInt64Type = IntegerType::get(&getContext(), 64);
   }
@@ -81,7 +85,9 @@ class VulkanLaunchFuncToVulkanCallsPass
     //   int64_t sizes[Rank]; // omitted when rank == 0
     //   int64_t strides[Rank]; // omitted when rank == 0
     // };
-    auto llvmPtrToElementType = LLVM::LLVMPointerType::get(elemenType);
+    auto llvmPtrToElementType = useOpaquePointers
+                                    ? llvmPointerType
+                                    : LLVM::LLVMPointerType::get(elemenType);
     auto llvmArrayRankElementSizeType =
         LLVM::LLVMArrayType::get(getInt64Type(), rank);
 
@@ -131,9 +137,8 @@ class VulkanLaunchFuncToVulkanCallsPass
   /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
   void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
 
-  /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`.
-  LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
-                                        uint32_t &rank, Type &type);
+  /// Deduces a rank from the given 'launchCallArg`.
+  LogicalResult deduceMemRefRank(Value launchCallArg, uint32_t &rank);
 
   /// Returns a string representation from the given `type`.
   StringRef stringifyType(Type type) {
@@ -154,6 +159,8 @@ class VulkanLaunchFuncToVulkanCallsPass
   }
 
 public:
+  using Base::Base;
+
   void runOnOperation() override;
 
 private:
@@ -163,8 +170,14 @@ class VulkanLaunchFuncToVulkanCallsPass
   Type llvmInt32Type;
   Type llvmInt64Type;
 
+  struct SPIRVAttributes {
+    StringAttr blob;
+    StringAttr entryPoint;
+    SmallVector<Type> elementTypes;
+  };
+
   // TODO: Use an associative array to support multiple vulkan launch calls.
-  std::pair<StringAttr, StringAttr> spirvAttributes;
+  SPIRVAttributes spirvAttributes;
   /// The number of vulkan launch configuration operands, placed at the leading
   /// positions of the operand list.
   static constexpr unsigned kVulkanLaunchNumConfigOperands = 3;
@@ -209,7 +222,24 @@ void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes(
     return signalPassFailure();
   }
 
-  spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr);
+  auto spirvElementTypesAttr =
+      vulkanLaunchCallOp->getAttrOfType<ArrayAttr>(kSPIRVElementTypesAttrName);
+  if (!spirvElementTypesAttr) {
+    vulkanLaunchCallOp.emitError()
+        << "missing " << kSPIRVElementTypesAttrName << " attribute";
+    return signalPassFailure();
+  }
+  if (llvm::any_of(spirvElementTypesAttr,
+                   [](Attribute attr) { return !isa<TypeAttr>(attr); })) {
+    vulkanLaunchCallOp.emitError()
+        << "expected " << spirvElementTypesAttr << " to be an array of types";
+    return signalPassFailure();
+  }
+
+  spirvAttributes.blob = spirvBlobAttr;
+  spirvAttributes.entryPoint = spirvEntryPointNameAttr;
+  spirvAttributes.elementTypes =
+      llvm::to_vector(spirvElementTypesAttr.getAsValueRange<mlir::TypeAttr>());
 }
 
 void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
@@ -226,17 +256,23 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
   Value descriptorSet =
       builder.create<LLVM::ConstantOp>(loc, getInt32Type(), 0);
 
-  for (const auto &en :
+  for (auto [index, ptrToMemRefDescriptor] :
        llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
            kVulkanLaunchNumConfigOperands))) {
     // Create LLVM constant for the descriptor binding index.
     Value descriptorBinding =
-        builder.create<LLVM::ConstantOp>(loc, getInt32Type(), en.index());
+        builder.create<LLVM::ConstantOp>(loc, getInt32Type(), index);
+
+    if (index >= spirvAttributes.elementTypes.size()) {
+      cInterfaceVulkanLaunchCallOp.emitError()
+          << kSPIRVElementTypesAttrName << " missing element type for "
+          << ptrToMemRefDescriptor;
+      return signalPassFailure();
+    }
 
-    auto ptrToMemRefDescriptor = en.value();
     uint32_t rank = 0;
-    Type type;
-    if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
+    Type type = spirvAttributes.elementTypes[index];
+    if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) {
       cInterfaceVulkanLaunchCallOp.emitError()
           << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
       return signalPassFailure();
@@ -246,7 +282,7 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
         llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
     // Special case for fp16 type. Since it is not a supported type in C we use
     // int16_t and bitcast the descriptor.
-    if (type.isa<Float16Type>()) {
+    if (!useOpaquePointers && type.isa<Float16Type>()) {
       auto memRefTy = getMemRefType(rank, IntegerType::get(&getContext(), 16));
       ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
           loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor);
@@ -259,15 +295,24 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
   }
 }
 
-LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
-    Value ptrToMemRefDescriptor, uint32_t &rank, Type &type) {
-  auto llvmPtrDescriptorTy =
-      ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMPointerType>();
-  if (!llvmPtrDescriptorTy)
+LogicalResult
+VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value launchCallArg,
+                                                    uint32_t &rank) {
+  // Deduce the rank from the type used to allocate the lowered MemRef.
+  auto alloca = launchCallArg.getDefiningOp<LLVM::AllocaOp>();
+  if (!alloca)
     return failure();
 
-  auto llvmDescriptorTy =
-      llvmPtrDescriptorTy.getElementType().dyn_cast<LLVM::LLVMStructType>();
+  LLVM::LLVMStructType llvmDescriptorTy;
+  if (std::optional<Type> elementType = alloca.getElemType()) {
+    llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(*elementType);
+  } else {
+    // This case is only possible if we are not using opaque pointers
+    // since opaque pointer producing allocas require an element type.
+    llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(
+        alloca.getRes().getType().getElementType());
+  }
+
   // template <typename Elem, size_t Rank>
   // struct {
   //   Elem *allocated;
@@ -279,9 +324,6 @@ LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
   if (!llvmDescriptorTy)
     return failure();
 
-  type = llvmDescriptorTy.getBody()[0]
-             .cast<LLVM::LLVMPointerType>()
-             .getElementType();
   if (llvmDescriptorTy.getBody().size() == 3) {
     rank = 0;
     return success();
@@ -339,7 +381,9 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
         auto fnType = LLVM::LLVMFunctionType::get(
             getVoidType(),
             {getPointerType(), getInt32Type(), getInt32Type(),
-             LLVM::LLVMPointerType::get(getMemRefType(i, type))},
+             useOpaquePointers
+                 ? llvmPointerType
+                 : LLVM::LLVMPointerType::get(getMemRefType(i, type))},
             /*isVarArg=*/false);
         builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
       }
@@ -368,7 +412,7 @@ Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
                                   shaderName, LLVM::Linkage::Internal,
-                                  /*TODO:useOpaquePointers=*/false);
+                                  useOpaquePointers);
 }
 
 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
@@ -385,12 +429,12 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
   // that data to runtime call.
   Value ptrToSPIRVBinary = LLVM::createGlobalString(
-      loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(),
-      LLVM::Linkage::Internal, /*TODO:useOpaquePointers=*/false);
+      loc, builder, kSPIRVBinary, spirvAttributes.blob.getValue(),
+      LLVM::Linkage::Internal, useOpaquePointers);
 
   // Create LLVM constant for the size of SPIR-V binary shader.
   Value binarySize = builder.create<LLVM::ConstantOp>(
-      loc, getInt32Type(), spirvAttributes.first.getValue().size());
+      loc, getInt32Type(), spirvAttributes.blob.getValue().size());
 
   // Create call to `bindMemRef` for each memref operand.
   createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);
@@ -402,7 +446,7 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
       ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize});
   // Create LLVM global with entry point name.
   Value entryPointName = createEntryPointNameConstant(
-      spirvAttributes.second.getValue(), loc, builder);
+      spirvAttributes.entryPoint.getValue(), loc, builder);
   // Create call to `setEntryPoint` runtime function with the given pointer to
   // entry point name.
   builder.create<LLVM::CallOp>(loc, TypeRange(), kSetEntryPoint,
@@ -428,8 +472,3 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
 
   cInterfaceVulkanLaunchCallOp.erase();
 }
-
-std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
-mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() {
-  return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>();
-}

diff  --git a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir b/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
index b99584ba126f0..c77bf238c8ca2 100644
--- a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
+++ b/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
@@ -1,63 +1,62 @@
-// RUN: mlir-opt %s -launch-func-to-vulkan | FileCheck %s
+// RUN: mlir-opt %s -launch-func-to-vulkan='use-opaque-pointers=1' | FileCheck %s
 
 // CHECK: llvm.mlir.global internal constant @kernel_spv_entry_point_name
 // CHECK: llvm.mlir.global internal constant @SPIRV_BIN
-// CHECK: %[[Vulkan_Runtime_ptr:.*]] = llvm.call @initVulkan() : () -> !llvm.ptr<i8>
+// CHECK: %[[Vulkan_Runtime_ptr:.*]] = llvm.call @initVulkan() : () -> !llvm.ptr
 // CHECK: %[[addressof_SPIRV_BIN:.*]] = llvm.mlir.addressof @SPIRV_BIN
 // CHECK: %[[SPIRV_BIN_ptr:.*]] = llvm.getelementptr %[[addressof_SPIRV_BIN]]
 // CHECK: %[[SPIRV_BIN_size:.*]] = llvm.mlir.constant
-// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
-// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i32) -> ()
+// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, i32, i32, !llvm.ptr) -> ()
+// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr, !llvm.ptr, i32) -> ()
 // CHECK: %[[addressof_entry_point:.*]] = llvm.mlir.addressof @kernel_spv_entry_point_name
 // CHECK: %[[entry_point_ptr:.*]] = llvm.getelementptr %[[addressof_entry_point]]
-// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
-// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i64, i64, i64) -> ()
-// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
+// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr, !llvm.ptr) -> ()
+// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, i64, i64, i64) -> ()
+// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr) -> ()
+// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr) -> ()
 
-// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<i16>, ptr<i16>, i64, array<1 x i64>, array<1 x i64>)>>)
+// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr, i32, i32, !llvm.ptr)
 
 module attributes {gpu.container_module} {
-  llvm.func @malloc(i64) -> !llvm.ptr<i8>
+  llvm.func @malloc(i64) -> !llvm.ptr
   llvm.func @foo() {
     %0 = llvm.mlir.constant(12 : index) : i64
-    %1 = llvm.mlir.null : !llvm.ptr<f32>
+    %1 = llvm.mlir.null : !llvm.ptr
     %2 = llvm.mlir.constant(1 : index) : i64
-    %3 = llvm.getelementptr %1[%2] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-    %4 = llvm.ptrtoint %3 : !llvm.ptr<f32> to i64
+    %3 = llvm.getelementptr %1[%2] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    %4 = llvm.ptrtoint %3 : !llvm.ptr to i64
     %5 = llvm.mul %0, %4 : i64
-    %6 = llvm.call @malloc(%5) : (i64) -> !llvm.ptr<i8>
-    %7 = llvm.bitcast %6 : !llvm.ptr<i8> to !llvm.ptr<f32>
-    %8 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %9 = llvm.insertvalue %7, %8[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %10 = llvm.insertvalue %7, %9[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %6 = llvm.call @malloc(%5) : (i64) -> !llvm.ptr
+    %8 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %9 = llvm.insertvalue %6, %8[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %10 = llvm.insertvalue %6, %9[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     %11 = llvm.mlir.constant(0 : index) : i64
-    %12 = llvm.insertvalue %11, %10[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %12 = llvm.insertvalue %11, %10[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     %13 = llvm.mlir.constant(1 : index) : i64
-    %14 = llvm.insertvalue %0, %12[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %15 = llvm.insertvalue %13, %14[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %14 = llvm.insertvalue %0, %12[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %15 = llvm.insertvalue %13, %14[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     %16 = llvm.mlir.constant(1 : index) : i64
-    %17 = llvm.extractvalue %15[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %18 = llvm.extractvalue %15[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %19 = llvm.extractvalue %15[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %20 = llvm.extractvalue %15[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %21 = llvm.extractvalue %15[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    llvm.call @vulkanLaunch(%16, %16, %16, %17, %18, %19, %20, %21) {spirv_blob = "\03\02#\07\00", spirv_entry_point = "kernel"}
-    : (i64, i64, i64, !llvm.ptr<f32>, !llvm.ptr<f32>, i64, i64, i64) -> ()
+    %17 = llvm.extractvalue %15[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %18 = llvm.extractvalue %15[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %19 = llvm.extractvalue %15[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %20 = llvm.extractvalue %15[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %21 = llvm.extractvalue %15[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    llvm.call @vulkanLaunch(%16, %16, %16, %17, %18, %19, %20, %21) {spirv_blob = "\03\02#\07\00", spirv_element_types = [f32], spirv_entry_point = "kernel"}
+    : (i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64) -> ()
     llvm.return
   }
-  llvm.func @vulkanLaunch(%arg0: i64, %arg1: i64, %arg2: i64, %arg6: !llvm.ptr<f32>, %arg7: !llvm.ptr<f32>, %arg8: i64, %arg9: i64, %arg10: i64) {
-    %0 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %1 = llvm.insertvalue %arg6, %0[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %2 = llvm.insertvalue %arg7, %1[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %3 = llvm.insertvalue %arg8, %2[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %4 = llvm.insertvalue %arg9, %3[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %5 = llvm.insertvalue %arg10, %4[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  llvm.func @vulkanLaunch(%arg0: i64, %arg1: i64, %arg2: i64, %arg6: !llvm.ptr, %arg7: !llvm.ptr, %arg8: i64, %arg9: i64, %arg10: i64) {
+    %0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %1 = llvm.insertvalue %arg6, %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %2 = llvm.insertvalue %arg7, %1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %3 = llvm.insertvalue %arg8, %2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %4 = llvm.insertvalue %arg9, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %5 = llvm.insertvalue %arg10, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
     %6 = llvm.mlir.constant(1 : index) : i64
-    %7 = llvm.alloca %6 x !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>
-    llvm.store %5, %7 : !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>
-    llvm.call @_mlir_ciface_vulkanLaunch(%arg0, %arg1, %arg2, %7) : (i64, i64, i64, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
+    %7 = llvm.alloca %6 x !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr
+    llvm.store %5, %7 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, !llvm.ptr
+    llvm.call @_mlir_ciface_vulkanLaunch(%arg0, %arg1, %arg2, %7) : (i64, i64, i64, !llvm.ptr) -> ()
     llvm.return
   }
-  llvm.func @_mlir_ciface_vulkanLaunch(i64, i64, i64, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>)
+  llvm.func @_mlir_ciface_vulkanLaunch(i64, i64, i64, !llvm.ptr)
 }

diff  --git a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir
index 3b176c3cf8346..13eb3a194df44 100644
--- a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir
+++ b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir
@@ -2,7 +2,7 @@
 
 // CHECK: %[[resource:.*]] = memref.alloc() : memref<12xf32>
 // CHECK: %[[index:.*]] = arith.constant 1 : index
-// CHECK: call @vulkanLaunch(%[[index]], %[[index]], %[[index]], %[[resource]]) {spirv_blob = "{{.*}}", spirv_entry_point = "kernel"}
+// CHECK: call @vulkanLaunch(%[[index]], %[[index]], %[[index]], %[[resource]]) {spirv_blob = "{{.*}}", spirv_element_types = [f32], spirv_entry_point = "kernel"}
 
 module attributes {gpu.container_module} {
   spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]> {

diff  --git a/mlir/test/Conversion/GPUToVulkan/typed-pointers.mlir b/mlir/test/Conversion/GPUToVulkan/typed-pointers.mlir
new file mode 100644
index 0000000000000..67bd640e5d44c
--- /dev/null
+++ b/mlir/test/Conversion/GPUToVulkan/typed-pointers.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt %s -launch-func-to-vulkan='use-opaque-pointers=0' | FileCheck %s
+
+// CHECK: llvm.mlir.global internal constant @kernel_spv_entry_point_name
+// CHECK: llvm.mlir.global internal constant @SPIRV_BIN
+// CHECK: %[[Vulkan_Runtime_ptr:.*]] = llvm.call @initVulkan() : () -> !llvm.ptr<i8>
+// CHECK: %[[addressof_SPIRV_BIN:.*]] = llvm.mlir.addressof @SPIRV_BIN
+// CHECK: %[[SPIRV_BIN_ptr:.*]] = llvm.getelementptr %[[addressof_SPIRV_BIN]]
+// CHECK: %[[SPIRV_BIN_size:.*]] = llvm.mlir.constant
+// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
+// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i32) -> ()
+// CHECK: %[[addressof_entry_point:.*]] = llvm.mlir.addressof @kernel_spv_entry_point_name
+// CHECK: %[[entry_point_ptr:.*]] = llvm.getelementptr %[[addressof_entry_point]]
+// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
+// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i64, i64, i64) -> ()
+// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
+// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
+
+// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<i16>, ptr<i16>, i64, array<1 x i64>, array<1 x i64>)>>)
+
+module attributes {gpu.container_module} {
+  llvm.func @malloc(i64) -> !llvm.ptr<i8>
+  llvm.func @foo() {
+    %0 = llvm.mlir.constant(12 : index) : i64
+    %1 = llvm.mlir.null : !llvm.ptr<f32>
+    %2 = llvm.mlir.constant(1 : index) : i64
+    %3 = llvm.getelementptr %1[%2] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+    %4 = llvm.ptrtoint %3 : !llvm.ptr<f32> to i64
+    %5 = llvm.mul %0, %4 : i64
+    %6 = llvm.call @malloc(%5) : (i64) -> !llvm.ptr<i8>
+    %7 = llvm.bitcast %6 : !llvm.ptr<i8> to !llvm.ptr<f32>
+    %8 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %9 = llvm.insertvalue %7, %8[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %10 = llvm.insertvalue %7, %9[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %11 = llvm.mlir.constant(0 : index) : i64
+    %12 = llvm.insertvalue %11, %10[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %13 = llvm.mlir.constant(1 : index) : i64
+    %14 = llvm.insertvalue %0, %12[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %15 = llvm.insertvalue %13, %14[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %16 = llvm.mlir.constant(1 : index) : i64
+    %17 = llvm.extractvalue %15[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %18 = llvm.extractvalue %15[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %19 = llvm.extractvalue %15[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %20 = llvm.extractvalue %15[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %21 = llvm.extractvalue %15[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    llvm.call @vulkanLaunch(%16, %16, %16, %17, %18, %19, %20, %21) {spirv_blob = "\03\02#\07\00", spirv_element_types = [f32], spirv_entry_point = "kernel"}
+    : (i64, i64, i64, !llvm.ptr<f32>, !llvm.ptr<f32>, i64, i64, i64) -> ()
+    llvm.return
+  }
+  llvm.func @vulkanLaunch(%arg0: i64, %arg1: i64, %arg2: i64, %arg6: !llvm.ptr<f32>, %arg7: !llvm.ptr<f32>, %arg8: i64, %arg9: i64, %arg10: i64) {
+    %0 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %1 = llvm.insertvalue %arg6, %0[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %2 = llvm.insertvalue %arg7, %1[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %3 = llvm.insertvalue %arg8, %2[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %4 = llvm.insertvalue %arg9, %3[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %5 = llvm.insertvalue %arg10, %4[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+    %6 = llvm.mlir.constant(1 : index) : i64
+    %7 = llvm.alloca %6 x !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>
+    llvm.store %5, %7 : !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>
+    llvm.call @_mlir_ciface_vulkanLaunch(%arg0, %arg1, %arg2, %7) : (i64, i64, i64, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
+    llvm.return
+  }
+  llvm.func @_mlir_ciface_vulkanLaunch(i64, i64, i64, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>)
+}

diff  --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
index caec6439ec747..d196902eb169f 100644
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -68,16 +68,25 @@ static LogicalResult runMLIRPasses(Operation *op,
   if (options.spirvWebGPUPrepare)
     modulePM.addPass(spirv::createSPIRVWebGPUPreparePass());
 
+  auto enableOpaquePointers = [](auto passOption) {
+    passOption.useOpaquePointers = true;
+    return passOption;
+  };
+
   passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
-  passManager.addPass(createFinalizeMemRefToLLVMConversionPass());
-  passManager.addPass(createConvertVectorToLLVMPass());
+  passManager.addPass(createFinalizeMemRefToLLVMConversionPass(
+      enableOpaquePointers(FinalizeMemRefToLLVMConversionPassOptions{})));
+  passManager.addPass(createConvertVectorToLLVMPass(
+      enableOpaquePointers(ConvertVectorToLLVMPassOptions{})));
   passManager.nest<func::FuncOp>().addPass(LLVM::createRequestCWrappersPass());
   ConvertFuncToLLVMPassOptions funcToLLVMOptions{};
   funcToLLVMOptions.indexBitwidth =
       DataLayout(module).getTypeSizeInBits(IndexType::get(module.getContext()));
-  passManager.addPass(createConvertFuncToLLVMPass(funcToLLVMOptions));
+  passManager.addPass(
+      createConvertFuncToLLVMPass(enableOpaquePointers(funcToLLVMOptions)));
   passManager.addPass(createReconcileUnrealizedCastsPass());
-  passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass());
+  passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass(
+      enableOpaquePointers(ConvertVulkanLaunchFuncToVulkanCallsPassOptions{})));
 
   return passManager.run(module);
 }


        


More information about the Mlir-commits mailing list