[Mlir-commits] [mlir] [MLIR][GPU] Add cooperative launch support to gpu.launch_func (PR #190639)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 6 11:08:40 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-gpu
Author: Jared Hoberock (jaredhoberock)
<details>
<summary>Changes</summary>
Add a `cooperative` UnitAttr to `gpu.launch_func` that enables cooperative
kernel launch semantics. Cooperative launches guarantee that all thread blocks
in the grid are co-resident on the GPU simultaneously, enabling grid-wide
synchronization patterns.
## Implementation
When `cooperative` is set, cluster sizes are given, or both, the lowering emits
a call to the new `mgpuLaunchKernelEx` runtime function, which uses
`cuLaunchKernelEx` with a `CUlaunchConfig` and, for cooperative launches,
`CU_LAUNCH_ATTRIBUTE_COOPERATIVE`.
This unifies cooperative and cluster launch through a single attribute-driven
API, guarded behind `CUDA_VERSION >= 12000`.
## Changes
- **GPUOps.td**: add `cooperative` UnitAttr and assembly format keyword
- **SelectObjectAttr.cpp**: add `getKernelLaunchExFn()`, route cooperative
and/or cluster launches through `mgpuLaunchKernelEx`
- **CudaRuntimeWrappers.cpp**: implement `mgpuLaunchKernelEx` using
`cuLaunchKernelEx` with dynamic launch attributes
- **GPUToLLVMConversion.cpp**: propagate cooperative attribute through
the legalization pattern
- **test/Dialect/GPU/ops.mlir**: round-trip tests for cooperative keyword
with and without clusters
## Context
MLIR currently has no support for cooperative kernel launches. Flang works
around this with a CUF-specific attribute (PRs #124325, #124362), but
there is no first-class support in the GPU dialect. This patch adds it
at the `gpu.launch_func` level so all frontends can use it.
Cooperative launch requires `cudaLaunchCooperativeKernel` (CUDA 9+) or
`cuLaunchKernelEx` with `CU_LAUNCH_ATTRIBUTE_COOPERATIVE` (CUDA 12+).
This implementation uses the latter since the cluster launch path already
depends on CUDA 12.
---
Full diff: https://github.com/llvm/llvm-project/pull/190639.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (+7)
- (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+3-1)
- (modified) mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp (+70)
- (modified) mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp (+34-8)
- (modified) mlir/test/Dialect/GPU/ops.mlir (+6)
``````````diff
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index f0a4dd44c8f67..4e24e70c47a9f 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -621,6 +621,7 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
Optional<LaunchIndx>:$clusterSizeY,
Optional<LaunchIndx>:$clusterSizeZ,
Optional<I32>:$dynamicSharedMemorySize,
+ UnitAttr:$cooperative,
Variadic<AnyType>:$kernelOperands,
Optional<AnyType>:$asyncObject)>,
Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
@@ -663,6 +664,11 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
arguments are present, the Op launches a kernel that clusters the given
thread blocks. This feature is exclusive to certain architectures.
+ The `cooperative` attribute indicates that the kernel should be launched
+ cooperatively, guaranteeing that all thread blocks in the grid are
+ co-resident on the GPU simultaneously. This enables grid-wide
+ synchronization patterns.
+
Example:
```mlir
@@ -789,6 +795,7 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
`threads` `in` ` ` `(` $blockSizeX `,` $blockSizeY `,` $blockSizeZ `)`
custom<LaunchDimType>(type($gridSizeX), ref($clusterSizeX), type($clusterSizeX), type($clusterSizeY), type($clusterSizeZ))
(`dynamic_shared_memory_size` $dynamicSharedMemorySize^)?
+ (`cooperative` $cooperative^)?
custom<LaunchFuncOperands>($kernelOperands, type($kernelOperands)) attr-dict
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 3e99c537d0e02..290e8a3f1c896 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -1061,7 +1061,7 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
adaptor.getClusterSizeZ()};
}
- gpu::LaunchFuncOp::create(
+ auto newLaunchOp = gpu::LaunchFuncOp::create(
rewriter, launchOp.getLoc(), launchOp.getKernelAttr(),
gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
adaptor.getGridSizeZ()},
@@ -1070,6 +1070,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
adaptor.getDynamicSharedMemorySize(),
llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
stream, clusterSize);
+ if (launchOp.getCooperative())
+ newLaunchOp.setCooperative(true);
if (launchOp.getAsyncToken())
rewriter.replaceOp(launchOp, {stream});
else
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index 6307e0b59f3d2..6639d2103f8ee 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -338,6 +338,76 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) {
#if (CUDA_VERSION >= 12000)
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuLaunchKernelEx(
+ CUfunction function, intptr_t gridX, intptr_t gridY, intptr_t gridZ,
+ intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem,
+ CUstream stream, void **params, void **extra, intptr_t clusterX,
+ intptr_t clusterY, intptr_t clusterZ, int32_t cooperative) {
+ ScopedContext scopedContext;
+ if (smem > 0) {
+ int32_t maxShmem = 0;
+ CUdevice device = getDefaultCuDevice();
+ CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
+ CUDA_REPORT_IF_ERROR(cuDeviceGetAttribute(
+ &maxShmem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
+ device));
+ if (maxShmem < smem) {
+ fprintf(stderr,
+ "Requested shared memory (%dkb) is larger than maximum allowed "
+ "shared memory (%dkb) for this device\n",
+ smem, maxShmem);
+ }
+ CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
+ function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem));
+ }
+
+ CUlaunchConfig config;
+ config.gridDimX = gridX;
+ config.gridDimY = gridY;
+ config.gridDimZ = gridZ;
+ config.blockDimX = blockX;
+ config.blockDimY = blockY;
+ config.blockDimZ = blockZ;
+ config.sharedMemBytes = smem;
+ config.hStream = stream;
+
+ CUlaunchAttribute launchAttrs[3];
+ int numAttrs = 0;
+
+ bool hasCluster = clusterX > 0 && clusterY > 0 && clusterZ > 0;
+ if (hasCluster) {
+ launchAttrs[numAttrs].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
+ launchAttrs[numAttrs].value.clusterDim.x = clusterX;
+ launchAttrs[numAttrs].value.clusterDim.y = clusterY;
+ launchAttrs[numAttrs].value.clusterDim.z = clusterZ;
+ numAttrs++;
+
+ launchAttrs[numAttrs].id =
+ CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
+ launchAttrs[numAttrs].value.clusterSchedulingPolicyPreference =
+ CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
+ numAttrs++;
+ }
+
+ if (cooperative) {
+ launchAttrs[numAttrs].id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE;
+ launchAttrs[numAttrs].value.cooperative = 1;
+ numAttrs++;
+ }
+
+ config.numAttrs = numAttrs;
+ config.attrs = launchAttrs;
+
+ debug_print("Launching kernel (cooperative=%d, cluster=%d),"
+ "grid=%ld,%ld,%ld, "
+ "threads: %ld, %ld, %ld, "
+ "smem: %dkb\n",
+ cooperative, hasCluster, gridX, gridY, gridZ, blockX, blockY,
+ blockZ, smem);
+
+ CUDA_REPORT_IF_ERROR(cuLaunchKernelEx(&config, function, params, extra));
+}
+
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuLaunchClusterKernel(
CUfunction function, intptr_t clusterX, intptr_t clusterY,
intptr_t clusterZ, intptr_t gridX, intptr_t gridY, intptr_t gridZ,
diff --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
index c25e9a3c36973..a776928c14817 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
@@ -215,6 +215,9 @@ class LaunchKernel {
// Get the kernel launch callee.
FunctionCallee getClusterKernelLaunchFn();
+ // Get the extended kernel launch callee (cooperative and/or cluster).
+ FunctionCallee getKernelLaunchExFn();
+
// Get the module function callee.
FunctionCallee getModuleFunctionFn();
@@ -311,6 +314,20 @@ llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
false));
}
+llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchExFn() {
+ // mgpuLaunchKernelEx(function, gridX, gridY, gridZ, blockX, blockY, blockZ,
+ // smem, stream, params, extra,
+ // clusterX, clusterY, clusterZ, cooperative)
+ return module.getOrInsertFunction(
+ "mgpuLaunchKernelEx",
+ FunctionType::get(
+ voidTy,
+ ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
+ intPtrTy, intPtrTy, i32Ty, ptrTy, ptrTy, ptrTy,
+ intPtrTy, intPtrTy, intPtrTy, i32Ty}),
+ false));
+}
+
llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
return module.getOrInsertFunction(
"mgpuModuleGetFunction",
@@ -452,15 +469,24 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
// Create the launch call.
Value *nullPtr = ConstantPointerNull::get(ptrTy);
- // Launch kernel with clusters if cluster size is specified.
- if (op.hasClusterSize()) {
- mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
- Value *cx = llvmValue(cluster.x), *cy = llvmValue(cluster.y),
- *cz = llvmValue(cluster.z);
+ // Use mgpuLaunchKernelEx when cooperative or cluster launch is requested.
+ if (op.getCooperative() || op.hasClusterSize()) {
+ Value *cx = ConstantInt::get(intPtrTy, 0);
+ Value *cy = ConstantInt::get(intPtrTy, 0);
+ Value *cz = ConstantInt::get(intPtrTy, 0);
+ if (op.hasClusterSize()) {
+ mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
+ cx = llvmValue(cluster.x);
+ cy = llvmValue(cluster.y);
+ cz = llvmValue(cluster.z);
+ }
+ Value *cooperativeFlag =
+ ConstantInt::get(i32Ty, op.getCooperative() ? 1 : 0);
builder.CreateCall(
- getClusterKernelLaunchFn(),
- ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
- dynamicMemorySize, stream, argArray, nullPtr}));
+ getKernelLaunchExFn(),
+ ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by, bz,
+ dynamicMemorySize, stream, argArray, nullPtr,
+ cx, cy, cz, cooperativeFlag}));
} else {
builder.CreateCall(getKernelLaunchFn(),
ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by,
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index cbafc376fb89a..5f858d187c520 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -233,6 +233,12 @@ module attributes {gpu.container_module} {
// CHECK: gpu.launch_func @kernels::@kernel_1 clusters in (%{{.*}}, %{{.*}}, %{{.*}}) blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) args(%{{.*}} : f32, %{{.*}} : memref<?xf32, 1>)
gpu.launch_func @kernels::@kernel_1 clusters in (%cst, %cst, %cst) blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) args(%0 : f32, %1 : memref<?xf32, 1>)
+ // CHECK: gpu.launch_func @kernels::@kernel_1 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) cooperative args(%{{.*}} : f32, %{{.*}} : memref<?xf32, 1>)
+ gpu.launch_func @kernels::@kernel_1 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) cooperative args(%0 : f32, %1 : memref<?xf32, 1>)
+
+ // CHECK: gpu.launch_func @kernels::@kernel_1 clusters in (%{{.*}}, %{{.*}}, %{{.*}}) blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) cooperative args(%{{.*}} : f32, %{{.*}} : memref<?xf32, 1>)
+ gpu.launch_func @kernels::@kernel_1 clusters in (%cst, %cst, %cst) blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) cooperative args(%0 : f32, %1 : memref<?xf32, 1>)
+
gpu.launch_func @kernels::@kernel_1 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) dynamic_shared_memory_size %c0 args(%0 : f32, %1 : memref<?xf32, 1>)
// CHECK: gpu.launch_func @kernels::@kernel_2 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}})
``````````
</details>
https://github.com/llvm/llvm-project/pull/190639
More information about the Mlir-commits mailing list