[Mlir-commits] [mlir] [MLIR][GPU] Add cooperative launch support to gpu.launch_func (PR #190639)

Jared Hoberock llvmlistbot at llvm.org
Tue Apr 7 11:35:26 PDT 2026


https://github.com/jaredhoberock updated https://github.com/llvm/llvm-project/pull/190639

>From 0bc6cfae961843e4d8d397178f54402db0d48b99 Mon Sep 17 00:00:00 2001
From: Jared Hoberock <jaredhoberock at gmail.com>
Date: Mon, 6 Apr 2026 12:35:19 -0500
Subject: [PATCH 1/2] [MLIR][GPU] Add cooperative launch support to
 gpu.launch_func

Add a `cooperative` UnitAttr to `gpu.launch_func` (and to `gpu.launch`,
from which it is propagated during kernel outlining) that enables
cooperative kernel launch semantics. Cooperative launches guarantee
that all thread blocks in the grid are co-resident on the GPU
simultaneously, enabling grid-wide synchronization patterns.

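When `cooperative` is set (with or without cluster sizes), or when
cluster sizes are given, the lowering emits a call to the new
`mgpuLaunchKernelEx` runtime function. That function uses
`cuLaunchKernelEx` with a `CUlaunchConfig` carrying
`CU_LAUNCH_ATTRIBUTE_COOPERATIVE` and/or the cluster launch attributes.
This unifies cooperative and cluster launch behind a single
attribute-driven runtime API. The new entry point lives in the existing
`#if (CUDA_VERSION >= 12000)` section of CudaRuntimeWrappers.cpp,
alongside `mgpuLaunchClusterKernel`.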

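Changes:
- GPUOps.td: add a `cooperative` UnitAttr (with an assembly format
  keyword) to `gpu.launch_func`, and a `cooperative` UnitAttr to
  `gpu.launch`
- GPUDialect.cpp: parse and print the `cooperative` keyword for
  `gpu.launch`
- KernelOutlining.cpp: propagate the cooperative attribute from
  `gpu.launch` to the outlined `gpu.launch_func`
- SelectObjectAttr.cpp: add `getKernelLaunchExFn()`, route cooperative
  and/or cluster launches through `mgpuLaunchKernelEx`
- CudaRuntimeWrappers.cpp: implement `mgpuLaunchKernelEx` using
  `cuLaunchKernelEx` with dynamic launch attributes
- GPUToLLVMConversion.cpp: propagate cooperative attribute through
  the legalization pattern
- test/Dialect/GPU/ops.mlir: round-trip tests for cooperative keyword
  with and without clusters
- test/Dialect/GPU/outlining.mlir: check that outlining preserves the
  cooperative keyword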
---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    |  8 +++
 .../GPUCommon/GPUToLLVMConversion.cpp         |  4 +-
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        |  8 +++
 .../GPU/Transforms/KernelOutlining.cpp        |  2 +
 .../ExecutionEngine/CudaRuntimeWrappers.cpp   | 70 +++++++++++++++++++
 .../LLVMIR/Dialect/GPU/SelectObjectAttr.cpp   | 42 ++++++++---
 mlir/test/Dialect/GPU/ops.mlir                | 18 +++++
 mlir/test/Dialect/GPU/outlining.mlir          | 14 ++++
 8 files changed, 157 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index f0a4dd44c8f67..635a9201ab209 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -621,6 +621,7 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
                Optional<LaunchIndx>:$clusterSizeY,
                Optional<LaunchIndx>:$clusterSizeZ,
                Optional<I32>:$dynamicSharedMemorySize,
+               UnitAttr:$cooperative,
                Variadic<AnyType>:$kernelOperands,
                Optional<AnyType>:$asyncObject)>,
     Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
@@ -663,6 +664,11 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
     arguments are present, the Op launches a kernel that clusters the given
     thread blocks. This feature is exclusive to certain architectures.
 
+    The `cooperative` attribute indicates that the kernel should be launched
+    cooperatively, guaranteeing that all thread blocks in the grid are
+    co-resident on the GPU simultaneously. This enables grid-wide
+    synchronization patterns.
+
     Example:
 
     ```mlir
@@ -789,6 +795,7 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
       `threads` `in` ` ` `(` $blockSizeX `,` $blockSizeY `,` $blockSizeZ `)`
       custom<LaunchDimType>(type($gridSizeX), ref($clusterSizeX), type($clusterSizeX), type($clusterSizeY), type($clusterSizeZ))
       (`dynamic_shared_memory_size` $dynamicSharedMemorySize^)?
+      (`cooperative` $cooperative^)?
       custom<LaunchFuncOperands>($kernelOperands, type($kernelOperands)) attr-dict
   }];
   let hasVerifier = 1;
@@ -805,6 +812,7 @@ def GPU_LaunchOp : GPU_Op<"launch", [
                Optional<Index>:$clusterSizeY,
                Optional<Index>:$clusterSizeZ,
                Optional<I32>:$dynamicSharedMemorySize,
+               UnitAttr:$cooperative,
                OptionalAttr<FlatSymbolRefAttr>:$module,
                OptionalAttr<FlatSymbolRefAttr>:$function)>,
     Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 3e99c537d0e02..290e8a3f1c896 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -1061,7 +1061,7 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
         gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
                         adaptor.getClusterSizeZ()};
   }
-  gpu::LaunchFuncOp::create(
+  auto newLaunchOp = gpu::LaunchFuncOp::create(
       rewriter, launchOp.getLoc(), launchOp.getKernelAttr(),
       gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
                       adaptor.getGridSizeZ()},
@@ -1070,6 +1070,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
       adaptor.getDynamicSharedMemorySize(),
       llvmArgumentsWithSizes.empty() ? llvmArguments : llvmArgumentsWithSizes,
       stream, clusterSize);
+  if (launchOp.getCooperative())
+    newLaunchOp.setCooperative(true);
   if (launchOp.getAsyncToken())
     rewriter.replaceOp(launchOp, {stream});
   else
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 8039f3952eea6..f2a0fa63ef11a 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -918,6 +918,9 @@ void LaunchOp::print(OpAsmPrinter &p) {
     p << ')';
   }
 
+  if (getCooperative())
+    p << " cooperative";
+
   printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
   printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
 
@@ -927,6 +930,7 @@ void LaunchOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                               LaunchOp::getOperandSegmentSizeAttr(),
                               getNumWorkgroupAttributionsAttrName(),
+                              getCooperativeAttrName(),
                               moduleAttrName, functionAttrName});
 }
 
@@ -1069,6 +1073,10 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
       return failure();
   }
 
+  // Parse optional cooperative keyword.
+  if (succeeded(parser.parseOptionalKeyword("cooperative")))
+    result.addAttribute("cooperative", parser.getBuilder().getUnitAttr());
+
   // Create the region arguments, it has kNumConfigRegionAttributes arguments
   // that correspond to block/thread identifiers and grid/block sizes, all
   // having `index` type, a variadic number of WorkGroup Attributions and
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index b9529b0d067f2..e9e60ebd39d85 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -297,6 +297,8 @@ static void convertToLaunchFuncOp(gpu::LaunchOp launchOp,
       launchOp.getDynamicSharedMemorySize(), operands,
       asyncToken ? asyncToken.getType() : nullptr,
       launchOp.getAsyncDependencies(), clusterSize);
+  if (launchOp.getCooperative())
+    launchFunc.setCooperative(true);
   launchOp.replaceAllUsesWith(launchFunc);
   launchOp.erase();
 }
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index 6307e0b59f3d2..6639d2103f8ee 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -338,6 +338,76 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) {
 
 #if (CUDA_VERSION >= 12000)
 
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuLaunchKernelEx(
+    CUfunction function, intptr_t gridX, intptr_t gridY, intptr_t gridZ,
+    intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem,
+    CUstream stream, void **params, void **extra, intptr_t clusterX,
+    intptr_t clusterY, intptr_t clusterZ, int32_t cooperative) {
+  ScopedContext scopedContext;
+  if (smem > 0) {
+    int32_t maxShmem = 0;
+    CUdevice device = getDefaultCuDevice();
+    CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
+    CUDA_REPORT_IF_ERROR(cuDeviceGetAttribute(
+        &maxShmem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
+        device));
+    if (maxShmem < smem) {
+      fprintf(stderr,
+              "Requested shared memory (%dkb) is larger than maximum allowed "
+              "shared memory (%dkb) for this device\n",
+              smem, maxShmem);
+    }
+    CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
+        function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem));
+  }
+
+  CUlaunchConfig config;
+  config.gridDimX = gridX;
+  config.gridDimY = gridY;
+  config.gridDimZ = gridZ;
+  config.blockDimX = blockX;
+  config.blockDimY = blockY;
+  config.blockDimZ = blockZ;
+  config.sharedMemBytes = smem;
+  config.hStream = stream;
+
+  CUlaunchAttribute launchAttrs[3];
+  int numAttrs = 0;
+
+  bool hasCluster = clusterX > 0 && clusterY > 0 && clusterZ > 0;
+  if (hasCluster) {
+    launchAttrs[numAttrs].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
+    launchAttrs[numAttrs].value.clusterDim.x = clusterX;
+    launchAttrs[numAttrs].value.clusterDim.y = clusterY;
+    launchAttrs[numAttrs].value.clusterDim.z = clusterZ;
+    numAttrs++;
+
+    launchAttrs[numAttrs].id =
+        CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
+    launchAttrs[numAttrs].value.clusterSchedulingPolicyPreference =
+        CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
+    numAttrs++;
+  }
+
+  if (cooperative) {
+    launchAttrs[numAttrs].id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE;
+    launchAttrs[numAttrs].value.cooperative = 1;
+    numAttrs++;
+  }
+
+  config.numAttrs = numAttrs;
+  config.attrs = launchAttrs;
+
+  debug_print("Launching kernel (cooperative=%d, cluster=%d),"
+              "grid=%ld,%ld,%ld, "
+              "threads: %ld, %ld, %ld, "
+              "smem: %dkb\n",
+              cooperative, hasCluster, gridX, gridY, gridZ, blockX, blockY,
+              blockZ, smem);
+
+  CUDA_REPORT_IF_ERROR(cuLaunchKernelEx(&config, function, params, extra));
+}
+
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuLaunchClusterKernel(
     CUfunction function, intptr_t clusterX, intptr_t clusterY,
     intptr_t clusterZ, intptr_t gridX, intptr_t gridY, intptr_t gridZ,
diff --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
index c25e9a3c36973..a776928c14817 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
@@ -215,6 +215,9 @@ class LaunchKernel {
   // Get the kernel launch callee.
   FunctionCallee getClusterKernelLaunchFn();
 
+  // Get the extended kernel launch callee (cooperative and/or cluster).
+  FunctionCallee getKernelLaunchExFn();
+
   // Get the module function callee.
   FunctionCallee getModuleFunctionFn();
 
@@ -311,6 +314,20 @@ llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
           false));
 }
 
+llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchExFn() {
+  // mgpuLaunchKernelEx(function, gridX, gridY, gridZ, blockX, blockY, blockZ,
+  //   smem, stream, params, extra,
+  //   clusterX, clusterY, clusterZ, cooperative)
+  return module.getOrInsertFunction(
+      "mgpuLaunchKernelEx",
+      FunctionType::get(
+          voidTy,
+          ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
+                            intPtrTy, intPtrTy, i32Ty, ptrTy, ptrTy, ptrTy,
+                            intPtrTy, intPtrTy, intPtrTy, i32Ty}),
+          false));
+}
+
 llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
   return module.getOrInsertFunction(
       "mgpuModuleGetFunction",
@@ -452,15 +469,24 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
   // Create the launch call.
   Value *nullPtr = ConstantPointerNull::get(ptrTy);
 
-  // Launch kernel with clusters if cluster size is specified.
-  if (op.hasClusterSize()) {
-    mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
-    Value *cx = llvmValue(cluster.x), *cy = llvmValue(cluster.y),
-          *cz = llvmValue(cluster.z);
+  // Use mgpuLaunchKernelEx when cooperative or cluster launch is requested.
+  if (op.getCooperative() || op.hasClusterSize()) {
+    Value *cx = ConstantInt::get(intPtrTy, 0);
+    Value *cy = ConstantInt::get(intPtrTy, 0);
+    Value *cz = ConstantInt::get(intPtrTy, 0);
+    if (op.hasClusterSize()) {
+      mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
+      cx = llvmValue(cluster.x);
+      cy = llvmValue(cluster.y);
+      cz = llvmValue(cluster.z);
+    }
+    Value *cooperativeFlag =
+        ConstantInt::get(i32Ty, op.getCooperative() ? 1 : 0);
     builder.CreateCall(
-        getClusterKernelLaunchFn(),
-        ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
-                           dynamicMemorySize, stream, argArray, nullPtr}));
+        getKernelLaunchExFn(),
+        ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by, bz,
+                           dynamicMemorySize, stream, argArray, nullPtr,
+                           cx, cy, cz, cooperativeFlag}));
   } else {
     builder.CreateCall(getKernelLaunchFn(),
                        ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by,
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index cbafc376fb89a..11cea6f82d7b5 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -17,6 +17,18 @@ module attributes {gpu.container_module} {
     return
   }
 
+  // CHECK-LABEL:func @launch_cooperative(%{{.*}}: index)
+  func.func @launch_cooperative(%sz : index) {
+    // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) cooperative
+    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
+               threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz)
+               cooperative {
+      // CHECK: gpu.terminator
+      gpu.terminator
+    }
+    return
+  }
+
   // CHECK-LABEL:func @launch_with_module_func_attr(%{{.*}}: index)
   func.func @launch_with_module_func_attr(%sz : index) {
     // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) module(@test_module) function(@test_kernel_func)
@@ -233,6 +245,12 @@ module attributes {gpu.container_module} {
     // CHECK: gpu.launch_func @kernels::@kernel_1 clusters in (%{{.*}}, %{{.*}}, %{{.*}}) blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) args(%{{.*}} : f32, %{{.*}} : memref<?xf32, 1>)
     gpu.launch_func @kernels::@kernel_1 clusters in (%cst, %cst, %cst) blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) args(%0 : f32, %1 : memref<?xf32, 1>)
 
+    // CHECK: gpu.launch_func @kernels::@kernel_1 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) cooperative args(%{{.*}} : f32, %{{.*}} : memref<?xf32, 1>)
+    gpu.launch_func @kernels::@kernel_1 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) cooperative args(%0 : f32, %1 : memref<?xf32, 1>)
+
+    // CHECK: gpu.launch_func @kernels::@kernel_1 clusters in (%{{.*}}, %{{.*}}, %{{.*}}) blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) cooperative args(%{{.*}} : f32, %{{.*}} : memref<?xf32, 1>)
+    gpu.launch_func @kernels::@kernel_1 clusters in (%cst, %cst, %cst) blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) cooperative args(%0 : f32, %1 : memref<?xf32, 1>)
+
     gpu.launch_func @kernels::@kernel_1 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) dynamic_shared_memory_size %c0 args(%0 : f32, %1 : memref<?xf32, 1>)
 
     // CHECK: gpu.launch_func @kernels::@kernel_2 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}})
diff --git a/mlir/test/Dialect/GPU/outlining.mlir b/mlir/test/Dialect/GPU/outlining.mlir
index 25220dff7a5bb..f708561bc2f01 100644
--- a/mlir/test/Dialect/GPU/outlining.mlir
+++ b/mlir/test/Dialect/GPU/outlining.mlir
@@ -705,3 +705,17 @@ module attributes {gpu.container_module} {
     return
   }
 }
+
+// -----
+
+// CHECK-LABEL: func @launch_cooperative
+func.func @launch_cooperative() {
+  %cst = arith.constant 8 : index
+  // CHECK: gpu.launch_func @launch_cooperative_kernel::@launch_cooperative_kernel blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) cooperative
+  gpu.launch blocks(%bx, %by, %bz) in (%gx = %cst, %gy = %cst, %gz = %cst)
+             threads(%tx, %ty, %tz) in (%bxs = %cst, %bys = %cst, %bzs = %cst)
+             cooperative {
+    gpu.terminator
+  }
+  return
+}

>From 60a2d3ba68b15fd7b78f767715af242185bc57fc Mon Sep 17 00:00:00 2001
From: Jared Hoberock <jaredhoberock at gmail.com>
Date: Tue, 7 Apr 2026 13:23:32 -0500
Subject: [PATCH 2/2] [MLIR][GPU] Preserve mgpuLaunchClusterKernel for
 cluster-only launches

Route only cooperative launches (with or without a cluster) through the
new mgpuLaunchKernelEx. Cluster-only (non-cooperative) launches keep
their existing path through mgpuLaunchClusterKernel, so that runtime
entry point stays in use rather than becoming dead code.

Also fix two clang-format complaints flagged by the formatter check
and add lit coverage for the new cooperative paths in
mlir/test/Target/LLVMIR/gpu.mlir:
 - cooperative without cluster -> mgpuLaunchKernelEx with cluster=0
 - cooperative with cluster    -> mgpuLaunchKernelEx with real cluster

The pre-existing cluster-only test case is unchanged apart from a
trailing-whitespace fix. It passes again now that its lowering path is
restored.

Assisted-by: Claude (Anthropic)
---
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        |  4 +-
 .../ExecutionEngine/CudaRuntimeWrappers.cpp   | 13 ++++--
 .../LLVMIR/Dialect/GPU/SelectObjectAttr.cpp   | 22 ++++++---
 mlir/test/Target/LLVMIR/gpu.mlir              | 46 ++++++++++++++++++-
 4 files changed, 71 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index f2a0fa63ef11a..a6e1d4adcf20e 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -930,8 +930,8 @@ void LaunchOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                               LaunchOp::getOperandSegmentSizeAttr(),
                               getNumWorkgroupAttributionsAttrName(),
-                              getCooperativeAttrName(),
-                              moduleAttrName, functionAttrName});
+                              getCooperativeAttrName(), moduleAttrName,
+                              functionAttrName});
 }
 
 // Parse the size assignment blocks for blocks and threads.  These have the form
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index 6639d2103f8ee..2f0fe2b3bf972 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -338,11 +338,14 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) {
 
 #if (CUDA_VERSION >= 12000)
 
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuLaunchKernelEx(
-    CUfunction function, intptr_t gridX, intptr_t gridY, intptr_t gridZ,
-    intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem,
-    CUstream stream, void **params, void **extra, intptr_t clusterX,
-    intptr_t clusterY, intptr_t clusterZ, int32_t cooperative) {
+// Cooperative launch entry point, optionally with a cluster. Pass
+// `clusterX/Y/Z = 0` to launch cooperatively without a cluster.
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuLaunchKernelEx(CUfunction function, intptr_t gridX, intptr_t gridY,
+                   intptr_t gridZ, intptr_t blockX, intptr_t blockY,
+                   intptr_t blockZ, int32_t smem, CUstream stream,
+                   void **params, void **extra, intptr_t clusterX,
+                   intptr_t clusterY, intptr_t clusterZ, int32_t cooperative) {
   ScopedContext scopedContext;
   if (smem > 0) {
     int32_t maxShmem = 0;
diff --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
index a776928c14817..88cb7317b88ba 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
@@ -469,8 +469,11 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
   // Create the launch call.
   Value *nullPtr = ConstantPointerNull::get(ptrTy);
 
-  // Use mgpuLaunchKernelEx when cooperative or cluster launch is requested.
-  if (op.getCooperative() || op.hasClusterSize()) {
+  // Cooperative launches go through mgpuLaunchKernelEx, which also handles
+  // an optional cluster. Cluster-only (non-cooperative) launches keep their
+  // existing path through mgpuLaunchClusterKernel. Plain launches go through
+  // mgpuLaunchKernel.
+  if (op.getCooperative()) {
     Value *cx = ConstantInt::get(intPtrTy, 0);
     Value *cy = ConstantInt::get(intPtrTy, 0);
     Value *cz = ConstantInt::get(intPtrTy, 0);
@@ -480,13 +483,20 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
       cy = llvmValue(cluster.y);
       cz = llvmValue(cluster.z);
     }
-    Value *cooperativeFlag =
-        ConstantInt::get(i32Ty, op.getCooperative() ? 1 : 0);
+    Value *cooperativeFlag = ConstantInt::get(i32Ty, 1);
     builder.CreateCall(
         getKernelLaunchExFn(),
         ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by, bz,
-                           dynamicMemorySize, stream, argArray, nullPtr,
-                           cx, cy, cz, cooperativeFlag}));
+                           dynamicMemorySize, stream, argArray, nullPtr, cx, cy,
+                           cz, cooperativeFlag}));
+  } else if (op.hasClusterSize()) {
+    mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
+    Value *cx = llvmValue(cluster.x), *cy = llvmValue(cluster.y),
+          *cz = llvmValue(cluster.z);
+    builder.CreateCall(
+        getClusterKernelLaunchFn(),
+        ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
+                           dynamicMemorySize, stream, argArray, nullPtr}));
   } else {
     builder.CreateCall(getKernelLaunchFn(),
                        ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by,
diff --git a/mlir/test/Target/LLVMIR/gpu.mlir b/mlir/test/Target/LLVMIR/gpu.mlir
index 0d29a95b12266..45c4c9e13ec84 100644
--- a/mlir/test/Target/LLVMIR/gpu.mlir
+++ b/mlir/test/Target/LLVMIR/gpu.mlir
@@ -93,7 +93,7 @@ module attributes {gpu.container_module} {
     %c1 = llvm.mlir.constant(1 : index) : i64
     %c2 = llvm.mlir.constant(2 : index) : i64
     %c3 = llvm.mlir.constant(3 : index) : i64
-    gpu.launch_func @kernel_module::@kernel 
+    gpu.launch_func @kernel_module::@kernel
         clusters in (%c1, %c1, %c1)
         blocks in (%c2, %c2, %c2)
         threads in (%c3, %c3, %c3) : i64
@@ -103,6 +103,50 @@ module attributes {gpu.container_module} {
 
 // -----
 
+// Test cooperative launch without a cluster: lowers to mgpuLaunchKernelEx
+// with cluster dims = 0 and the cooperative flag set to 1.
+module attributes {gpu.container_module} {
+  gpu.binary @kernel_module  [#gpu.object<#nvvm.target, "BLOB">]
+  llvm.func @cooperative_no_cluster() {
+  // CHECK: call void @mgpuLaunchKernelEx(
+  // CHECK-SAME: i64 2, i64 2, i64 2,
+  // CHECK-SAME: i64 3, i64 3, i64 3,
+  // CHECK-SAME: i64 0, i64 0, i64 0, i32 1)
+    %c2 = llvm.mlir.constant(2 : index) : i64
+    %c3 = llvm.mlir.constant(3 : index) : i64
+    gpu.launch_func @kernel_module::@kernel
+        blocks in (%c2, %c2, %c2)
+        threads in (%c3, %c3, %c3) : i64
+        cooperative
+    llvm.return
+  }
+}
+
+// -----
+
+// Test cooperative launch combined with a cluster: lowers to
+// mgpuLaunchKernelEx with the real cluster dims and the cooperative flag.
+module attributes {gpu.container_module} {
+  gpu.binary @kernel_module  [#gpu.object<#nvvm.target, "BLOB">]
+  llvm.func @cooperative_with_cluster() {
+  // CHECK: call void @mgpuLaunchKernelEx(
+  // CHECK-SAME: i64 2, i64 2, i64 2,
+  // CHECK-SAME: i64 3, i64 3, i64 3,
+  // CHECK-SAME: i64 1, i64 1, i64 1, i32 1)
+    %c1 = llvm.mlir.constant(1 : index) : i64
+    %c2 = llvm.mlir.constant(2 : index) : i64
+    %c3 = llvm.mlir.constant(3 : index) : i64
+    gpu.launch_func @kernel_module::@kernel
+        clusters in (%c1, %c1, %c1)
+        blocks in (%c2, %c2, %c2)
+        threads in (%c3, %c3, %c3) : i64
+        cooperative
+    llvm.return
+  }
+}
+
+// -----
+
 // Checking that ELF section is populated
 module attributes {gpu.container_module} {
   // CHECK: @cuda_device_mod_binary = internal constant [4 x i8] c"BLOB", section "__nv_rel_fatbin", align 8



More information about the Mlir-commits mailing list