[Mlir-commits] [mlir] [mlir][spirv] Add GpuToLLVM cconv suited to Vulkan, migrate last tests (PR #123384)

Andrea Faulds llvmlistbot at llvm.org
Mon Jan 20 06:26:27 PST 2025


================
@@ -970,33 +971,56 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
   else if (launchOp.getAsyncToken())
     stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
 
-  if (typeCheckKernelArgs) {
-    // The current non-bare-pointer ABI is a bad fit for `mgpuLaunchKernel`,
-    // which takes an untyped list of arguments. The type check here prevents
-    // accidentally violating the assumption made in vulkan-runtime-wrappers.cpp
-    // and creating a unchecked runtime ABI mismatch.
-    // TODO(https://github.com/llvm/llvm-project/issues/73457): Change the ABI
-    // here to remove the need for this type check.
-    for (Value arg : launchOp.getKernelOperands()) {
-      if (auto memrefTy = dyn_cast<MemRefType>(arg.getType())) {
-        if (memrefTy.getRank() != 1 ||
-            memrefTy.getElementTypeBitWidth() != 32) {
-          return rewriter.notifyMatchFailure(
-              launchOp, "Operand to launch op is not a rank-1 memref with "
-                        "32-bit element type.");
-        }
-      } else {
+  // Lower the kernel operands to match kernel parameters.
+  // Note: If `useBarePtrCallConv` is set in the type converter's options,
+  // the value of `kernelBarePtrCallConv` will be ignored.
+  OperandRange origArguments = launchOp.getKernelOperands();
+  SmallVector<Value, 4> llvmArguments = getTypeConverter()->promoteOperands(
+      loc, origArguments, adaptor.getKernelOperands(), rewriter,
+      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
+
+  // Intersperse size information if requested.
+  if (kernelIntersperseSizeCallConv) {
+    if (origArguments.size() != llvmArguments.size()) {
+      // This shouldn't happen if the bare-pointer calling convention is used.
+      return rewriter.notifyMatchFailure(
+          launchOp,
+          "Cannot add sizes to arguments with one-to-many LLVM IR expansion.");
+    }
+
+    SmallVector<Value, 8> llvmArgumentsWithSizes;
+    llvmArgumentsWithSizes.reserve(llvmArguments.size() * 2);
+    for (auto [llvmArg, origArg] : zip_equal(llvmArguments, origArguments)) {
+      auto memrefTy = dyn_cast<MemRefType>(origArg.getType());
+      if (!memrefTy) {
         return rewriter.notifyMatchFailure(
             launchOp, "Operand to launch op is not a memref.");
       }
+
+      if (!memrefTy.hasStaticShape() ||
+          !memrefTy.getElementType().isIntOrFloat()) {
+        return rewriter.notifyMatchFailure(
+            launchOp, "Operand to launch op is not a memref with a static "
+                      "shape and an integer or float element type.");
+      }
+
+      unsigned bitwidth = memrefTy.getElementTypeBitWidth();
+      if (bitwidth % 8 != 0) {
+        return rewriter.notifyMatchFailure(
+            launchOp, "Operand to launch op is not a memref with a "
+                      "byte-aligned element type.");
+      }
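The following is not part of the patch. It is a minimal, self-contained sketch that models the interspersing step outside MLIR, so it can be reasoned about without the type converter. It assumes two things. With the bare-pointer convention, a ranked memref lowers to a single pointer. Without it, the memref descriptor is unpacked into 3 + 2 * rank values: allocated pointer, aligned pointer, offset, `rank` sizes, and `rank` strides. The names `LoweredValue`, `countLoweredValues`, and `intersperseSizes` are hypothetical, and lowered values are represented as string labels. The sketch shows why pairing each argument with a size is only well defined when every original operand lowers to exactly one value.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a lowered LLVM value: just a textual label.
using LoweredValue = std::string;

// Number of values a ranked memref operand lowers to (assumption, see above):
// 1 bare pointer, or an unpacked descriptor of 3 + 2 * rank values.
std::size_t countLoweredValues(unsigned rank, bool barePtrCallConv) {
  return barePtrCallConv ? 1 : 3 + 2 * static_cast<std::size_t>(rank);
}

// Builds [arg0, size0, arg1, size1, ...]. Returns std::nullopt when the
// lowering was not one-to-one, mirroring the match failure in the patch.
std::optional<std::vector<LoweredValue>>
intersperseSizes(std::size_t numOrigArgs,
                 const std::vector<LoweredValue> &loweredArgs,
                 const std::vector<LoweredValue> &sizes) {
  if (loweredArgs.size() != numOrigArgs || sizes.size() != numOrigArgs)
    return std::nullopt;
  std::vector<LoweredValue> result;
  result.reserve(loweredArgs.size() * 2);
  for (std::size_t i = 0; i < loweredArgs.size(); ++i) {
    result.push_back(loweredArgs[i]); // the argument itself
    result.push_back(sizes[i]);       // followed by its size
  }
  assert(result.size() == 2 * numOrigArgs);
  return result;
}
```

For example, under these assumptions a rank-1 memref lowers to 1 value with bare pointers but to 5 values otherwise. In the second case the argument count no longer matches the operand count, so there is no single value to attach a size to.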
----------------
andfau-amd wrote:

Is there a rule that guarantees elements narrower than 8 bits are padded to 8 bits?

I doubt the Vulkan target can do anything useful with types whose width is not a multiple of 8 bits, though, so this simple check may be fine.
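The concern can be made concrete with a small sketch. It assumes that the size interspersed after each memref is its total byte count, computed as the element count times the element width in bytes. The code that computes the size is not shown in this excerpt, so that part is an assumption. Under that assumption, the count is only exact when the element bit width is a multiple of 8. The sketch rejects other widths rather than assuming any padding rule. `StaticMemRefInfo` and `memrefByteSize` are hypothetical names, not MLIR APIs.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical description of a statically shaped memref type.
struct StaticMemRefInfo {
  std::vector<int64_t> shape; // all dimensions static; empty for rank 0
  unsigned elementBitWidth;   // e.g. 32 for f32/i32, 1 for i1
};

// Total size in bytes, or std::nullopt if the element type is not
// byte-aligned (the same condition as `bitwidth % 8 != 0` in the patch).
std::optional<uint64_t> memrefByteSize(const StaticMemRefInfo &info) {
  if (info.elementBitWidth % 8 != 0)
    return std::nullopt; // no padding rule is assumed for sub-byte types
  uint64_t numElements = 1; // a rank-0 memref holds one element
  for (int64_t dim : info.shape) {
    assert(dim >= 0 && "static dimensions are non-negative");
    numElements *= static_cast<uint64_t>(dim);
  }
  return numElements * (info.elementBitWidth / 8);
}
```

With this approach, a `memref<4xf32>` would report 16 bytes and a `memref<2x3xi16>` would report 12. A `memref<8xi1>` would be rejected, because its byte size depends on a packing or padding rule that the check deliberately avoids relying on.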

https://github.com/llvm/llvm-project/pull/123384

