[Mlir-commits] [mlir] [mlir][spirv] Add GpuToLLVM cconv suited to Vulkan, migrate last tests (PR #123384)

Andrea Faulds llvmlistbot at llvm.org
Mon Jan 20 09:59:14 PST 2025


https://github.com/andfau-amd updated https://github.com/llvm/llvm-project/pull/123384

From 8fb279b409614e424faa7ca4ae9ce8b4e0d2b876 Mon Sep 17 00:00:00 2001
From: Andrea Faulds <andrea.faulds at amd.com>
Date: Mon, 20 Jan 2025 18:53:11 +0100
Subject: [PATCH] [mlir][spirv] Add GpuToLLVM cconv suited to Vulkan, migrate
 last tests

This commit is a follow-up to 99a562b3cb17e89273ba0fe77129f2fb17a19381,
which migrated some of the mlir-vulkan-runner tests to mlir-cpu-runner
using a new pipeline and set of wrappers. That commit could not migrate
all the tests, because the existing calling conventions/ABIs for kernel
arguments generated by GPUToLLVMConversionPass were not a good fit for
the Vulkan runtime. This commit adds a suitable calling convention and
migrates the remaining tests. With this commit, mlir-vulkan-runner and
many related components are unused; they will be removed in a later
commit (see #73457).

The old calling conventions require both the caller (host LLVM code)
and callee (device code) to have compile-time knowledge of the precise
argument types. This works for CUDA, ROCm and SYCL, where there is a
C-like calling convention agreed between the host and device code, and
the runtime passes through arguments as raw data without comprehension.
For Vulkan, however, the interface declared by the shader/kernel is in a
more abstract form, so the device code has indirect access to the
argument data, and the runtime must process the arguments to set up and
bind appropriately-sized buffer descriptors.

This commit introduces a new calling convention option to meet the
Vulkan runtime's needs. It lowers each memref argument to a
{void*, size_t} pair (pointer and buffer size in bytes), which the
runtime can interpret without knowing the original argument types.
Unlike the stopgap measure in the previous commit, this scheme supports
memrefs of various ranks and element types, provided the shape is static
and the element type is an integer or float whose bit width is a
multiple of 8. This unblocks migrating the remaining tests.
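
As a rough illustration of the pair layout described above, the sketch
below computes a static byte size and unpacks a params array of the form
{ &ptr0, &size0, &ptr1, &size1, ... }. HostBuffer, staticSizeInBytes and
collectBuffers are hypothetical names used only for this example; they
are not part of MLIR or of the Vulkan runtime wrappers.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a runtime-side buffer record (illustration only).
struct HostBuffer {
  void *ptr;
  std::size_t sizeInBytes;
};

// Byte size of a statically shaped memref: (bitWidth / 8) * number of elements.
// Assumes the element bit width is a multiple of 8 and all dims are static.
std::uint64_t staticSizeInBytes(unsigned bitWidth,
                                const std::vector<std::int64_t> &shape) {
  assert(bitWidth % 8 == 0 && "element type must be byte-aligned");
  std::uint64_t numElements = 1;
  for (std::int64_t dim : shape)
    numElements *= static_cast<std::uint64_t>(dim);
  return static_cast<std::uint64_t>(bitWidth / 8) * numElements;
}

// Walks a params array laid out as { &ptr0, &size0, &ptr1, &size1, ... }.
// Each entry points at the argument value, so it is dereferenced once.
std::vector<HostBuffer> collectBuffers(void **params, std::size_t paramsCount) {
  assert(paramsCount % 2 == 0 && "expected (pointer, size) pairs");
  std::vector<HostBuffer> buffers;
  buffers.reserve(paramsCount / 2);
  for (std::size_t i = 0; i < paramsCount; i += 2) {
    void *base = *static_cast<void **>(params[i]);
    std::size_t size = *static_cast<std::size_t *>(params[i + 1]);
    buffers.push_back({base, size});
  }
  return buffers;
}
```

With these assumptions, memref<8xf32>, memref<8x8xi32> and
memref<4x12xi8> come out as 32, 256 and 48 bytes, matching the constants
checked in the new lit test.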
---
 .../mlir/Conversion/GPUCommon/GPUCommonPass.h |  8 +-
 mlir/include/mlir/Conversion/Passes.td        | 10 ++-
 .../GPUCommon/GPUToLLVMConversion.cpp         | 90 ++++++++++++-------
 ...launch-func-bare-ptr-intersperse-size.mlir | 25 ++++++
 .../lib/Pass/TestVulkanRunnerPipeline.cpp     |  7 +-
 mlir/test/mlir-vulkan-runner/addi.mlir        |  4 +-
 mlir/test/mlir-vulkan-runner/addi8.mlir       |  4 +-
 mlir/test/mlir-vulkan-runner/mulf.mlir        |  4 +-
 mlir/test/mlir-vulkan-runner/subf.mlir        |  4 +-
 .../vulkan-runtime-wrappers.cpp               | 25 +++---
 10 files changed, 113 insertions(+), 68 deletions(-)
 create mode 100644 mlir/test/Conversion/GPUCommon/lower-launch-func-bare-ptr-intersperse-size.mlir

diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
index cf0c96f0eba000..312f1fc5b20c95 100644
--- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
+++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
@@ -62,10 +62,10 @@ struct FunctionCallBuilder {
 
 /// Collect a set of patterns to convert from the GPU dialect to LLVM and
 /// populate converter for gpu types.
-void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                         RewritePatternSet &patterns,
-                                         bool kernelBarePtrCallConv = false,
-                                         bool typeCheckKernelArgs = false);
+void populateGpuToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    bool kernelBarePtrCallConv = false,
+    bool kernelIntersperseSizeCallConv = false);
 
 /// A function that maps a MemorySpace enum to a target-specific integer value.
 using MemorySpaceMapping = std::function<unsigned(gpu::AddressSpace)>;
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 0f42ffb3a80266..34f7ab46b62982 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -518,11 +518,13 @@ def GpuToLLVMConversionPass : Pass<"gpu-to-llvm", "ModuleOp"> {
              "Use bare pointers to pass memref arguments to kernels. "
              "The kernel must use the same setting for this option."
           >,
-    Option<"typeCheckKernelArgs", "type-check-kernel-args", "bool",
+    Option<"kernelIntersperseSizeCallConv", "intersperse-sizes-for-kernels", "bool",
            /*default=*/"false",
-             "Require all kernel arguments to be memrefs of rank 1 and with a "
-             "32-bit element size. This is a temporary option that will be "
-             "removed; TODO(https://github.com/llvm/llvm-project/issues/73457)."
+           "Inserts a size_t argument following each memref argument, "
+           "containing the static size in bytes of the buffer. Incompatible "
+           "arguments are rejected. This is intended for use by the Vulkan "
+           "runtime with the kernel bare pointer calling convention, to enable "
+           "dynamic binding of buffers as arguments without static type info."
           >
   ];
 
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index ca9883a79dc168..8017eb6bb383b4 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -428,10 +428,10 @@ class LegalizeLaunchFuncOpPattern
 public:
   LegalizeLaunchFuncOpPattern(const LLVMTypeConverter &typeConverter,
                               bool kernelBarePtrCallConv,
-                              bool typeCheckKernelArgs)
+                              bool kernelIntersperseSizeCallConv)
       : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
         kernelBarePtrCallConv(kernelBarePtrCallConv),
-        typeCheckKernelArgs(typeCheckKernelArgs) {}
+        kernelIntersperseSizeCallConv(kernelIntersperseSizeCallConv) {}
 
 private:
   LogicalResult
@@ -439,7 +439,7 @@ class LegalizeLaunchFuncOpPattern
                   ConversionPatternRewriter &rewriter) const override;
 
   bool kernelBarePtrCallConv;
-  bool typeCheckKernelArgs;
+  bool kernelIntersperseSizeCallConv;
 };
 
 /// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
@@ -566,8 +566,9 @@ void GpuToLLVMConversionPass::runOnOperation() {
   populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
   populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                     target);
-  populateGpuToLLVMConversionPatterns(
-      converter, patterns, kernelBarePtrCallConv, typeCheckKernelArgs);
+  populateGpuToLLVMConversionPatterns(converter, patterns,
+                                      kernelBarePtrCallConv,
+                                      kernelIntersperseSizeCallConv);
 
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
@@ -970,33 +971,55 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
   else if (launchOp.getAsyncToken())
     stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
 
-  if (typeCheckKernelArgs) {
-    // The current non-bare-pointer ABI is a bad fit for `mgpuLaunchKernel`,
-    // which takes an untyped list of arguments. The type check here prevents
-    // accidentally violating the assumption made in vulkan-runtime-wrappers.cpp
-    // and creating a unchecked runtime ABI mismatch.
-    // TODO(https://github.com/llvm/llvm-project/issues/73457): Change the ABI
-    // here to remove the need for this type check.
-    for (Value arg : launchOp.getKernelOperands()) {
-      if (auto memrefTy = dyn_cast<MemRefType>(arg.getType())) {
-        if (memrefTy.getRank() != 1 ||
-            memrefTy.getElementTypeBitWidth() != 32) {
-          return rewriter.notifyMatchFailure(
-              launchOp, "Operand to launch op is not a rank-1 memref with "
-                        "32-bit element type.");
-        }
-      } else {
+  // Lower the kernel operands to match kernel parameters.
+  // Note: If `useBarePtrCallConv` is set in the type converter's options,
+  // the value of `kernelBarePtrCallConv` will be ignored.
+  OperandRange origArguments = launchOp.getKernelOperands();
+  SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
+      loc, origArguments, adaptor.getKernelOperands(), rewriter,
+      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
+  SmallVector<Value, 8> llvmArgumentsWithSizes;
+
+  // Intersperse size information if requested.
+  if (kernelIntersperseSizeCallConv) {
+    if (origArguments.size() != llvmArguments.size()) {
+      // This shouldn't happen if the bare-pointer calling convention is used.
+      return rewriter.notifyMatchFailure(
+          launchOp,
+          "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
+    }
+
+    llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
+    for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
+      auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
+      if (!memrefTy) {
         return rewriter.notifyMatchFailure(
             launchOp, "Operand to launch op is not a memref.");
       }
+
+      if (!memrefTy.hasStaticShape() ||
+          !memrefTy.getElementType().isIntOrFloat()) {
+        return rewriter.notifyMatchFailure(
+            launchOp, "Operand to launch op is not a memref with a static "
+                      "shape and an integer or float element type.");
+      }
+
+      unsigned bitwidth = memrefTy.getElementTypeBitWidth();
+      if (bitwidth % 8 != 0) {
+        return rewriter.notifyMatchFailure(
+            launchOp, "Operand to launch op is not a memref with a "
+                      "byte-aligned element type.");
+      }
+
+      uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) *
+                            static_cast<uint64_t>(memrefTy.getNumElements());
+
+      Value sizeArg = rewriter.create<LLVM::ConstantOp>(
+          loc, getIndexType(), rewriter.getIndexAttr(staticSize));
+      llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer.
+      llvmArgumentsWithSizes.push_back(sizeArg);
     }
   }
-  // Lower the kernel operands to match kernel parameters.
-  // Note: If `useBarePtrCallConv` is set in the type converter's options,
-  // the value of `kernelBarePtrCallConv` will be ignored.
-  SmallVector<Value, 4> arguments = getTypeConverter()->promoteOperands(
-      loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(), rewriter,
-      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
 
   std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
   if (launchOp.hasClusterSize()) {
@@ -1010,7 +1033,9 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
                       adaptor.getGridSizeZ()},
       gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
                       adaptor.getBlockSizeZ()},
-      adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
+      adaptor.getDynamicSharedMemorySize(),
+      llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
+      stream, clusterSize);
   if (launchOp.getAsyncToken())
     rewriter.replaceOp(launchOp, {stream});
   else
@@ -1760,10 +1785,9 @@ LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
 }
 
-void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                               RewritePatternSet &patterns,
-                                               bool kernelBarePtrCallConv,
-                                               bool typeCheckKernelArgs) {
+void mlir::populateGpuToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    bool kernelBarePtrCallConv, bool kernelIntersperseSizeCallConv) {
   addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
   addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
   addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
@@ -1801,7 +1825,7 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                ConvertSpMatGetSizeOpToGpuRuntimeCallPattern,
                ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
   patterns.add<LegalizeLaunchFuncOpPattern>(converter, kernelBarePtrCallConv,
-                                            typeCheckKernelArgs);
+                                            kernelIntersperseSizeCallConv);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/GPUCommon/lower-launch-func-bare-ptr-intersperse-size.mlir b/mlir/test/Conversion/GPUCommon/lower-launch-func-bare-ptr-intersperse-size.mlir
new file mode 100644
index 00000000000000..171b13da227136
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/lower-launch-func-bare-ptr-intersperse-size.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s --gpu-to-llvm="use-bare-pointers-for-kernels=1 intersperse-sizes-for-kernels=1" -split-input-file | FileCheck %s
+
+module attributes {gpu.container_module, spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+  llvm.func @malloc(i64) -> !llvm.ptr
+  gpu.binary @kernels  [#gpu.object<#spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>, "">]
+  func.func @main() attributes {llvm.emit_c_interface} {
+    // CHECK: [[RANK1UMD:%.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    %rank1UndefMemrefDescriptor = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[RANK2UMD:%.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    %rank2UndefMemrefDescriptor = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+    %c1 = arith.constant 1 : index
+    // CHECK: [[PTR1:%.*]] = llvm.extractvalue [[RANK1UMD]][1]
+    // CHECK: [[PTR2:%.*]] = llvm.extractvalue [[RANK2UMD]][1]
+    // CHECK: [[PTR3:%.*]] = llvm.extractvalue [[RANK2UMD]][1]
+    // CHECK: [[SIZE1:%.*]] = llvm.mlir.constant(32 : index) : i64
+    // CHECK: [[SIZE2:%.*]] = llvm.mlir.constant(256 : index) : i64
+    // CHECK: [[SIZE3:%.*]] = llvm.mlir.constant(48 : index) : i64
+    %6 = builtin.unrealized_conversion_cast %rank1UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<8xf32>
+    %10 = builtin.unrealized_conversion_cast %rank2UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<8x8xi32>
+    %14 = builtin.unrealized_conversion_cast %rank2UndefMemrefDescriptor : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<4x12xi8>
+    // CHECK: gpu.launch_func  @kernels::@kernel_add blocks in ({{.*}}) threads in ({{.*}}) : i64 args([[PTR1]] : !llvm.ptr, [[SIZE1]] : i64, [[PTR2]] : !llvm.ptr, [[SIZE2]] : i64, [[PTR3]] : !llvm.ptr, [[SIZE3]] : i64)
+    gpu.launch_func  @kernels::@kernel_add blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)  args(%6 : memref<8xf32>, %10 : memref<8x8xi32>, %14 : memref<4x12xi8>)
+    return
+  }
+}
diff --git a/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp b/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp
index a3624eb31e26e5..fa6a5a7140d8ca 100644
--- a/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp
+++ b/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp
@@ -68,12 +68,11 @@ void buildTestVulkanRunnerPipeline(OpPassManager &passManager,
     passManager.addPass(createFinalizeMemRefToLLVMConversionPass());
     passManager.nest<func::FuncOp>().addPass(
         LLVM::createRequestCWrappersPass());
-    // vulkan-runtime-wrappers.cpp uses the non-bare-pointer calling convention,
-    // and the type check is needed to prevent accidental ABI mismatches.
+    // vulkan-runtime-wrappers.cpp requires these calling convention options.
     GpuToLLVMConversionPassOptions opt;
     opt.hostBarePtrCallConv = false;
-    opt.kernelBarePtrCallConv = false;
-    opt.typeCheckKernelArgs = true;
+    opt.kernelBarePtrCallConv = true;
+    opt.kernelIntersperseSizeCallConv = true;
     passManager.addPass(createGpuToLLVMConversionPass(opt));
   }
 }
diff --git a/mlir/test/mlir-vulkan-runner/addi.mlir b/mlir/test/mlir-vulkan-runner/addi.mlir
index 7e212a4fb179c2..e60eb7696770a8 100644
--- a/mlir/test/mlir-vulkan-runner/addi.mlir
+++ b/mlir/test/mlir-vulkan-runner/addi.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
 
 // CHECK-COUNT-64: [3, 3, 3, 3, 3, 3, 3, 3]
 module attributes {
diff --git a/mlir/test/mlir-vulkan-runner/addi8.mlir b/mlir/test/mlir-vulkan-runner/addi8.mlir
index e0b1a8e8875c02..6374f29fd0ed2d 100644
--- a/mlir/test/mlir-vulkan-runner/addi8.mlir
+++ b/mlir/test/mlir-vulkan-runner/addi8.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
 
 // CHECK-COUNT-64: [3, 3, 3, 3, 3, 3, 3, 3]
 module attributes {
diff --git a/mlir/test/mlir-vulkan-runner/mulf.mlir b/mlir/test/mlir-vulkan-runner/mulf.mlir
index 22fa034a9b455f..bd1b8d9abf4c9f 100644
--- a/mlir/test/mlir-vulkan-runner/mulf.mlir
+++ b/mlir/test/mlir-vulkan-runner/mulf.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
 
 // CHECK-COUNT-4: [6, 6, 6, 6]
 module attributes {
diff --git a/mlir/test/mlir-vulkan-runner/subf.mlir b/mlir/test/mlir-vulkan-runner/subf.mlir
index 23496ef3abc00d..e8cee0e021a277 100644
--- a/mlir/test/mlir-vulkan-runner/subf.mlir
+++ b/mlir/test/mlir-vulkan-runner/subf.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-vulkan-runner-pipeline \
-// RUN:   | mlir-vulkan-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
+// RUN: mlir-opt %s -test-vulkan-runner-pipeline=to-llvm \
+// RUN:   | mlir-cpu-runner - --shared-libs=%vulkan-runtime-wrappers,%mlir_runner_utils --entry-point-result=void | FileCheck %s
 
 // CHECK-COUNT-32: [2.2, 2.2, 2.2, 2.2]
 module attributes {
diff --git a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
index ffd1114cec6aa3..8d1bac3b6f2869 100644
--- a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
@@ -169,26 +169,21 @@ mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ,
                  void ** /*extra*/, size_t paramsCount) {
   auto manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
 
-  // The non-bare-pointer memref ABI interacts badly with mgpuLaunchKernel's
-  // signature:
-  // - The memref descriptor struct gets split into several elements, each
-  //   passed as their own "param".
-  // - No metadata is provided as to the rank or element type/size of a memref.
-  //   Here we assume that all MemRefs have rank 1 and an element size of
-  //   4 bytes. This means each descriptor struct will have five members.
-  // TODO(https://github.com/llvm/llvm-project/issues/73457): Refactor the
-  //       ABI/API of mgpuLaunchKernel to use a different ABI for memrefs, so
-  //       that other memref types can also be used. This will allow migrating
-  //       the remaining tests and removal of mlir-vulkan-runner.
-  const size_t paramsPerMemRef = 5;
+  // GpuToLLVMConversionPass with the kernelBarePtrCallConv and
+  // kernelIntersperseSizeCallConv options will set up the params array like:
+  // { &memref_ptr0, &memref_size0, &memref_ptr1, &memref_size1, ... }
+  const size_t paramsPerMemRef = 2;
   if (paramsCount % paramsPerMemRef != 0) {
-    abort();
+    abort(); // This would indicate a serious calling convention mismatch.
   }
   const DescriptorSetIndex setIndex = 0;
   BindingIndex bindIndex = 0;
   for (size_t i = 0; i < paramsCount; i += paramsPerMemRef) {
-    auto memref = static_cast<MemRefDescriptor<uint32_t, 1> *>(params[i]);
-    bindMemRef<uint32_t, 1>(manager, setIndex, bindIndex, memref);
+    void *memrefBufferBasePtr = *static_cast<void **>(params[i + 0]);
+    size_t memrefBufferSize = *static_cast<size_t *>(params[i + 1]);
+    VulkanHostMemoryBuffer memBuffer{memrefBufferBasePtr,
+                                     static_cast<uint32_t>(memrefBufferSize)};
+    manager->setResourceData(setIndex, bindIndex, memBuffer);
     ++bindIndex;
   }
 