[Mlir-commits] [mlir] [mlir][gpu] Support Cluster of Thread Blocks in `gpu.launch_func` (PR #72871)

Guray Ozen llvmlistbot at llvm.org
Mon Nov 27 02:03:46 PST 2023


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/72871

>From 4958087cbc277fa7f8ee6911bd536cbb374741fb Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Mon, 20 Nov 2023 14:58:43 +0100
Subject: [PATCH 1/4] [mlir][gpu] Support Cluster of Cooperative Thread Arrays
 in `gpu.launch_func`

This PR adds support for clusters of Cooperative Thread Arrays (aka Cooperative Group Arrays, CGA) to `gpu.launch_func` in the GPU dialect. The NVIDIA H100 architecture introduced CTA clusters as a new level of parallelism: the CTAs in a cluster run concurrently and can synchronize and communicate through shared memory.
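
To make the grid → cluster → block hierarchy concrete, here is a small host-side C++ sketch. It is not part of this patch, and the `Dim3` struct and `clustersInGrid` helper are hypothetical names. The sketch counts blocks and clusters for the launch shape used by the sm90 integration test below. It assumes each grid dimension (in blocks) is a multiple of the corresponding cluster dimension, which, to our understanding, CUDA expects for cluster launches.

```cpp
#include <cstdint>

// A 3-D extent, mirroring the (x, y, z) operand triples of gpu.launch_func.
struct Dim3 {
  int64_t x, y, z;
  constexpr int64_t volume() const { return x * y * z; }
};

// Number of clusters along each grid dimension. Assumes every grid
// dimension is an exact multiple of the matching cluster dimension.
constexpr Dim3 clustersInGrid(Dim3 grid, Dim3 cluster) {
  return Dim3{grid.x / cluster.x, grid.y / cluster.y, grid.z / cluster.z};
}

// Launch shape of the sm90 integration test:
//   clusters in (2, 2, 1) blocks in (4, 4, 1) threads in (1, 1, 1)
constexpr Dim3 kGrid{4, 4, 1};
constexpr Dim3 kCluster{2, 2, 1};

static_assert(kGrid.volume() == 16, "16 thread blocks in the grid");
static_assert(kCluster.volume() == 4, "4 thread blocks per cluster");
static_assert(clustersInGrid(kGrid, kCluster).volume() == 4, "4 clusters");
```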

This PR extends `gpu.launch_func` to support this feature. The GPU dialect already provides the mechanisms for kernel outlining and launching, so it is the natural place to add cluster support to kernel launches.

An example of the new syntax:

```
gpu.launch_func @kernel_module::@kernel
                clusters in (%1, %0, %0) // <-- Optional
                blocks in (%0, %0, %0)
                threads in (%0, %0, %0)
```
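
Before emitting such a launch, a frontend may want to sanity-check the cluster shape. The helper below is a hypothetical C++ sketch, not an API added by this patch. It assumes that every cluster dimension must evenly divide the grid. It also assumes a cluster holds at most 8 thread blocks, which is CUDA's portable cluster size limit; the range inference in this patch likewise uses 8 as `kMaxClusterDim`. Larger, non-portable cluster sizes need an explicit opt-in and are not modeled.

```cpp
#include <cstdint>

// Portable upper bound on the number of thread blocks in one cluster.
constexpr int64_t kMaxPortableClusterSize = 8;

// Hypothetical check for `clusters in (cx, cy, cz) blocks in (gx, gy, gz)`:
// all extents are positive, each cluster dimension evenly divides the
// matching grid dimension, and the cluster has at most 8 blocks in total.
constexpr bool isValidClusterLaunch(int64_t cx, int64_t cy, int64_t cz,
                                    int64_t gx, int64_t gy, int64_t gz) {
  return cx > 0 && cy > 0 && cz > 0 && gx > 0 && gy > 0 && gz > 0 &&
         gx % cx == 0 && gy % cy == 0 && gz % cz == 0 &&
         cx * cy * cz <= kMaxPortableClusterSize;
}

// 2x2x1 clusters over a 4x4x1 grid, as in the sm90 integration test.
static_assert(isValidClusterLaunch(2, 2, 1, 4, 4, 1), "valid shape");
// A cluster of 3 blocks along x does not tile a grid of 4 blocks.
static_assert(!isValidClusterLaunch(3, 1, 1, 4, 4, 1), "does not divide");
// 4x4x1 = 16 blocks exceeds the portable limit of 8.
static_assert(!isValidClusterLaunch(4, 4, 1, 8, 8, 1), "too large");
```

Such a check would arguably belong in the op verifier. In this patch, however, the verifier only checks that the three cluster operands have the same type, so the shape is validated only at launch time by the driver.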

The PR also introduces cluster-specific index and dimension Ops, which are lowered to the corresponding NVVM Ops:

```
gpu.cluster_id x
gpu.cluster_id y
gpu.cluster_id z

gpu.cluster_dim x
gpu.cluster_dim y
gpu.cluster_dim z
```
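
This sketch relies on two assumptions. The first is that `gpu.cluster_dim` returns the number of thread blocks in the cluster, as the op descriptions in this patch state. The second is that clusters tile the grid in contiguous boxes. Under both, a block's cluster index and its position inside the cluster follow from integer division and remainder. The C++ sketch below uses hypothetical helper names and reproduces the values that the sm90 integration test prints for block (3, 3, 0).

```cpp
#include <cstdint>

struct Index3 {
  int64_t x, y, z;
};

// Index of the cluster that contains `blockId`, given the cluster size in
// thread blocks along each dimension.
constexpr Index3 clusterIdOf(Index3 blockId, Index3 clusterDim) {
  return Index3{blockId.x / clusterDim.x, blockId.y / clusterDim.y,
                blockId.z / clusterDim.z};
}

// Position of the block inside its cluster.
constexpr Index3 blockInClusterOf(Index3 blockId, Index3 clusterDim) {
  return Index3{blockId.x % clusterDim.x, blockId.y % clusterDim.y,
                blockId.z % clusterDim.z};
}

// Block (3, 3, 0) with 2x2x1 clusters, matching the test's expected line:
// "clusterIdx: (1, 1, 0) in Cluster Dimension: (2, 2, 1) blockIdx: (3, 3, 0)"
constexpr Index3 kCid = clusterIdOf(Index3{3, 3, 0}, Index3{2, 2, 1});
static_assert(kCid.x == 1 && kCid.y == 1 && kCid.z == 0, "cluster (1, 1, 0)");
constexpr Index3 kPos = blockInClusterOf(Index3{3, 3, 0}, Index3{2, 2, 1});
static_assert(kPos.x == 1 && kPos.y == 1 && kPos.z == 0, "slot (1, 1, 0)");
```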

Cluster support for the `gpu.launch` Op will be added in an upcoming PR.

For details, please refer to NVIDIA's PTX documentation:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cluster-of-cooperative-thread-arrays
---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    | 75 +++++++++++++++++--
 .../GPUCommon/GPUToLLVMConversion.cpp         | 28 ++++++-
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        | 25 ++++---
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 45 +++++++++--
 .../GPU/IR/InferIntRangeInterfaceImpls.cpp    | 13 ++++
 .../ExecutionEngine/CudaRuntimeWrappers.cpp   | 54 +++++++++++++
 .../LLVMIR/Dialect/GPU/SelectObjectAttr.cpp   | 34 ++++++++-
 ...ower-launch-func-to-gpu-runtime-calls.mlir | 38 ++++++++++
 mlir/test/Dialect/GPU/invalid.mlir            |  2 +-
 .../GPU/CUDA/sm90/cga_cluster.mlir            | 65 ++++++++++++++++
 mlir/test/Target/LLVMIR/gpu.mlir              | 19 +++++
 11 files changed, 370 insertions(+), 28 deletions(-)
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 632cdd96c6d4c2b..f093e4392520263 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -53,6 +53,32 @@ class GPU_IndexOp<string mnemonic, list<Trait> traits = []> :
   let assemblyFormat = "$dimension attr-dict";
 }
 
+def GPU_ClusterDimOp : GPU_IndexOp<"cluster_dim"> {
+  let description = [{
+    Returns the number of thread blocks in the cluster along
+    the x, y, or z `dimension`.
+
+    Example:
+
+    ```mlir
+    %cDimX = gpu.cluster_dim x
+    ```
+  }];
+}
+
+def GPU_ClusterIdOp : GPU_IndexOp<"cluster_id"> {
+  let description = [{
+    Returns the cluster id, i.e. the index of the current cluster within the
+    grid along the x, y, or z `dimension`.
+
+    Example:
+
+    ```mlir
+    %cIdY = gpu.cluster_id y
+    ```
+  }];
+}
+
 def GPU_BlockDimOp : GPU_IndexOp<"block_dim"> {
   let description = [{
     Returns the number of threads in the thread block (aka the block size) along
@@ -441,8 +467,15 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
                      "blockSizeY", "blockSizeZ"]>]>,
     Arguments<(ins Variadic<GPU_AsyncToken>:$asyncDependencies,
                SymbolRefAttr:$kernel,
-               LaunchIndx:$gridSizeX, LaunchIndx:$gridSizeY, LaunchIndx:$gridSizeZ,
-               LaunchIndx:$blockSizeX, LaunchIndx:$blockSizeY, LaunchIndx:$blockSizeZ,
+               LaunchIndx:$gridSizeX,
+               LaunchIndx:$gridSizeY,
+               LaunchIndx:$gridSizeZ,
+               LaunchIndx:$blockSizeX,
+               LaunchIndx:$blockSizeY,
+               LaunchIndx:$blockSizeZ,
+               Optional<LaunchIndx>:$clusterSizeX,
+               Optional<LaunchIndx>:$clusterSizeY,
+               Optional<LaunchIndx>:$clusterSizeZ,
                Optional<I32>:$dynamicSharedMemorySize,
                Variadic<AnyType>:$kernelOperands,
                Optional<AnyType>:$asyncObject)>,
@@ -480,6 +513,12 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
     The remaining operands if present are passed as arguments to the kernel
     function.
 
+    The `gpu.launch_func` also supports launching kernels with clusters of
+    thread blocks if the target architecture supports them. The cluster size
+    is set by the `clusterSizeX`, `clusterSizeY`, and `clusterSizeZ`
+    operands. When these operands are present, the Op launches the kernel
+    with the given thread blocks grouped into clusters of that size.
+
     Example:
 
     ```mlir
@@ -509,6 +548,15 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
           %gDimY = gpu.grid_dim y
           %gDimZ = gpu.grid_dim z
 
+          // (Optional) Cluster ids and dimensions, only on supported architectures.
+          %cIdX = gpu.cluster_id x
+          %cIdY = gpu.cluster_id y
+          %cIdZ = gpu.cluster_id z
+
+          %cDimX = gpu.cluster_dim x
+          %cDimY = gpu.cluster_dim y
+          %cDimZ = gpu.cluster_dim z
+
           "some_op"(%bx, %tx) : (index, index) -> ()
           %42 = load %arg1[%bx] : memref<?xf32, 1>
         }
@@ -519,6 +567,7 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
           async                           // (Optional) Don't block host, return token.
           [%t0]                           // (Optional) Execute only after %t0 has completed.
           @kernels::@kernel_1             // Kernel function.
+          clusters in (%cst, %cst, %cst)  // (Optional) Cluster size, only on supported architectures.
           blocks in (%cst, %cst, %cst)    // Grid size.
           threads in (%cst, %cst, %cst)   // Block size.
           dynamic_shared_memory_size %s   // (Optional) Amount of dynamic shared
@@ -536,11 +585,13 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
       "KernelDim3":$blockSize, "Value":$dynamicSharedMemorySize,
       "ValueRange":$kernelOperands,
       CArg<"Type", "nullptr">:$asyncTokenType,
-      CArg<"ValueRange", "{}">:$asyncDependencies)>,
+      CArg<"ValueRange", "{}">:$asyncDependencies,
+      CArg<"std::optional<KernelDim3>", "std::nullopt">:$clusterSize)>,
     OpBuilder<(ins "SymbolRefAttr":$kernel, "KernelDim3":$gridSize,
       "KernelDim3":$blockSize, "Value":$dynamicSharedMemorySize,
       "ValueRange":$kernelOperands,
-      CArg<"Value", "nullptr">:$asyncObject)>
+      CArg<"Value", "nullptr">:$asyncObject,
+      CArg<"std::optional<KernelDim3>", "std::nullopt">:$clusterSize)>
   ];
 
   let extraClassDeclaration = [{
@@ -550,12 +601,23 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
     /// The name of the kernel.
     StringAttr getKernelName();
 
+    /// Checks whether the cluster size is specified, i.e. whether the
+    /// optional cluster size operands are present.
+    bool hasClusterSize() {
+      if (getClusterSizeX() && getClusterSizeY() && getClusterSizeZ())
+        return true;
+      return false;
+    }
+
     /// The number of operands passed to the kernel function.
     unsigned getNumKernelOperands();
 
     /// The i-th operand passed to the kernel function.
     Value getKernelOperand(unsigned i);
 
+    /// Get the SSA values passed as operands to specify the cluster size.
+    KernelDim3 getClusterSizeOperandValues();
+
     /// Get the SSA values passed as operands to specify the grid size.
     KernelDim3 getGridSizeOperandValues();
 
@@ -571,10 +633,11 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
   let assemblyFormat = [{
       custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
       (`<` $asyncObject^ `:` type($asyncObject) `>`)?
-      $kernel
+      $kernel
+      ( `clusters` `in` ` ` `(` $clusterSizeX^ `,` $clusterSizeY `,` $clusterSizeZ `)` )?
       `blocks` `in` ` ` `(` $gridSizeX `,` $gridSizeY `,` $gridSizeZ `)`
       `threads` `in` ` ` `(` $blockSizeX `,` $blockSizeY `,` $blockSizeZ `)`
-      custom<LaunchDimType>(type($gridSizeX))
+      custom<LaunchDimType>(type($gridSizeX), ref($clusterSizeX), type($clusterSizeX), type($clusterSizeY), type($clusterSizeZ))
       (`dynamic_shared_memory_size` $dynamicSharedMemorySize^)?
       custom<LaunchFuncOperands>($kernelOperands, type($kernelOperands)) attr-dict
   }];
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 7bac8f5a8f0e03b..381d5100a7a3fdc 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -121,6 +121,26 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
           llvmPointerType, /* void **extra */
           llvmInt64Type    /* size_t paramsCount */
       }};
+  FunctionCallBuilder launchClusterKernelCallBuilder = {
+      "mgpuLaunchClusterKernel",
+      llvmVoidType,
+      {
+          llvmPointerType, /* void* f */
+          llvmIntPtrType,  /* intptr_t clusterXDim */
+          llvmIntPtrType,  /* intptr_t clusteryDim */
+          llvmIntPtrType,  /* intptr_t clusterZDim */
+          llvmIntPtrType,  /* intptr_t gridXDim */
+          llvmIntPtrType,  /* intptr_t gridyDim */
+          llvmIntPtrType,  /* intptr_t gridZDim */
+          llvmIntPtrType,  /* intptr_t blockXDim */
+          llvmIntPtrType,  /* intptr_t blockYDim */
+          llvmIntPtrType,  /* intptr_t blockZDim */
+          llvmInt32Type,   /* unsigned int sharedMemBytes */
+          llvmPointerType, /* void *hstream */
+          llvmPointerType, /* void **kernelParams */
+          llvmPointerType, /* void **extra */
+          llvmInt64Type    /* size_t paramsCount */
+      }};
   FunctionCallBuilder streamCreateCallBuilder = {
       "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
   FunctionCallBuilder streamDestroyCallBuilder = {
@@ -1128,13 +1148,19 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
         loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(),
         rewriter, /*useBarePtrCallConv=*/kernelBarePtrCallConv);
 
+    std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
+    if (launchOp.hasClusterSize()) {
+      clusterSize =
+          gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
+                          adaptor.getClusterSizeZ()};
+    }
     rewriter.create<gpu::LaunchFuncOp>(
         launchOp.getLoc(), launchOp.getKernelAttr(),
         gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
                         adaptor.getGridSizeZ()},
         gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
                         adaptor.getBlockSizeZ()},
-        adaptor.getDynamicSharedMemorySize(), arguments, stream);
+        adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
     if (launchOp.getAsyncToken())
       rewriter.replaceOp(launchOp, {stream});
     else
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 935e3d2a4095003..5b353b0c3bcbd76 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -313,17 +313,20 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns) {
   populateWithGenerated(patterns);
   patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
-  patterns
-      .add<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
-                                       NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
-           GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
-                                       NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
-           GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
-                                       NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
-           GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
-                                       NVVM::GridDimYOp, NVVM::GridDimZOp>,
-           GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
-          converter);
+  patterns.add<
+      GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
+                                  NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
+      GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
+                                  NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
+      GPUIndexIntrinsicOpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
+                                  NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>,
+      GPUIndexIntrinsicOpLowering<gpu::ClusterDimOp, NVVM::ClusterDimXOp,
+                                  NVVM::ClusterDimYOp, NVVM::ClusterDimZOp>,
+      GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
+                                  NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
+      GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
+                                  NVVM::GridDimYOp, NVVM::GridDimZOp>,
+      GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(converter);
 
   // Explicitly drop memory space when lowering private memory
   // attributions since NVVM models it as `alloca`s in the default
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index e0a2b93df3d1fd6..156a993131ad379 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -983,7 +983,8 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                          GPUFuncOp kernelFunc, KernelDim3 gridSize,
                          KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                          ValueRange kernelOperands, Type asyncTokenType,
-                         ValueRange asyncDependencies) {
+                         ValueRange asyncDependencies,
+                         std::optional<KernelDim3> clusterSize) {
   result.addOperands(asyncDependencies);
   if (asyncTokenType)
     result.types.push_back(builder.getType<AsyncTokenType>());
@@ -991,6 +992,8 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
   // Add grid and block sizes as op operands, followed by the data operands.
   result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                       getBlockSize.y, getBlockSize.z});
+  if (clusterSize.has_value())
+    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
   if (dynamicSharedMemorySize)
     result.addOperands(dynamicSharedMemorySize);
   result.addOperands(kernelOperands);
@@ -1006,6 +1009,11 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
   for (auto &sz : prop.operandSegmentSizes)
     sz = 1;
   prop.operandSegmentSizes[0] = asyncDependencies.size();
+  if (!clusterSize.has_value()) {
+    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
+    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
+    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
+  }
   prop.operandSegmentSizes[segmentSizesLen - 3] =
       dynamicSharedMemorySize ? 1 : 0;
   prop.operandSegmentSizes[segmentSizesLen - 2] =
@@ -1016,10 +1024,13 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
 void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                          SymbolRefAttr kernel, KernelDim3 gridSize,
                          KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
-                         ValueRange kernelOperands, Value asyncObject) {
+                         ValueRange kernelOperands, Value asyncObject,
+                         std::optional<KernelDim3> clusterSize) {
   // Add grid and block sizes as op operands, followed by the data operands.
   result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                       getBlockSize.y, getBlockSize.z});
+  if (clusterSize.has_value())
+    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
   if (dynamicSharedMemorySize)
     result.addOperands(dynamicSharedMemorySize);
   result.addOperands(kernelOperands);
@@ -1032,6 +1043,11 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
   for (auto &sz : prop.operandSegmentSizes)
     sz = 1;
   prop.operandSegmentSizes[0] = 0;
+  if (!clusterSize.has_value()) {
+    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
+    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
+    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
+  }
   prop.operandSegmentSizes[segmentSizesLen - 3] =
       dynamicSharedMemorySize ? 1 : 0;
   prop.operandSegmentSizes[segmentSizesLen - 2] =
@@ -1065,6 +1081,11 @@ KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
   return KernelDim3{operands[3], operands[4], operands[5]};
 }
 
+KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
+  auto operands = getOperands().drop_front(getAsyncDependencies().size());
+  return KernelDim3{operands[6], operands[7], operands[8]};
+}
+
 LogicalResult LaunchFuncOp::verify() {
   auto module = (*this)->getParentOfType<ModuleOp>();
   if (!module)
@@ -1076,21 +1097,35 @@ LogicalResult LaunchFuncOp::verify() {
                        GPUDialect::getContainerModuleAttrName() +
                        "' attribute");
 
+  if (getClusterSizeX()) {
+    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
+        getClusterSizeZ().getType() != getClusterSizeX().getType())
+      return emitOpError()
+             << "expects the types of all cluster dimensions to be the same";
+  }
+
   return success();
 }
 
-static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy) {
+static ParseResult
+parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
+                   std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
+                   Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
   if (succeeded(parser.parseOptionalColon())) {
     if (parser.parseType(dimTy))
       return failure();
   } else {
     dimTy = IndexType::get(parser.getContext());
   }
+  if (clusterValue.has_value()) {
+    clusterXTy = clusterYTy = clusterZTy = dimTy;
+  }
   return success();
 }
 
-static void printLaunchDimType(OpAsmPrinter &printer, Operation *op,
-                               Type dimTy) {
+static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
+                               Value clusterValue, Type clusterXTy,
+                               Type clusterYTy, Type clusterZTy) {
   if (!dimTy.isIndex())
     printer << ": " << dimTy;
 }
diff --git a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
index cb2d66d5b0d32da..69017efb9a0e67c 100644
--- a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
@@ -19,6 +19,8 @@ using namespace mlir::gpu;
 
 // Maximum grid and block dimensions of all known GPUs are less than 2^32.
 static constexpr uint64_t kMaxDim = std::numeric_limits<uint32_t>::max();
+// Maximum cluster size in thread blocks (the portable cluster size limit).
+static constexpr uint64_t kMaxClusterDim = 8;
 // Maximum subgroups are no larger than 128.
 static constexpr uint64_t kMaxSubgroupSize = 128;
 
@@ -82,6 +84,17 @@ static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
   return std::nullopt;
 }
 
+void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+                                     SetIntRangeFn setResultRange) {
+  setResultRange(getResult(), getIndexRange(1, kMaxClusterDim));
+}
+
+void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+                                    SetIntRangeFn setResultRange) {
+  uint64_t max = kMaxDim; // Cluster ids index clusters within the grid.
+  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
+}
+
 void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                    SetIntRangeFn setResultRange) {
   std::optional<uint64_t> knownVal =
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index a8e743c519135f7..9b63d2a22a7a31f 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -194,6 +194,60 @@ mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY,
                                       extra));
 }
 
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuLaunchClusterKernel(
+    CUfunction function, intptr_t clusterX, intptr_t clusterY,
+    intptr_t clusterZ, intptr_t gridX, intptr_t gridY, intptr_t gridZ,
+    intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem,
+    CUstream stream, void **params, void **extra, size_t /*paramsCount*/) {
+  ScopedContext scopedContext;
+  if (smem > 0) {
+    // Avoid checking driver as it's more expensive than if statement
+    int32_t maxShmem = 0;
+    CUdevice device = getDefaultCuDevice();
+    CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
+    CUDA_REPORT_IF_ERROR(cuDeviceGetAttribute(
+        &maxShmem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
+        device));
+    if (maxShmem < smem) {
+      fprintf(stderr,
+              "Requested shared memory (%d bytes) is larger than maximum "
+              "allowed shared memory (%d bytes) for this device\n",
+              smem, maxShmem);
+    }
+    CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
+        function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem));
+  }
+  CUlaunchConfig config;
+  config.gridDimX = gridX;
+  config.gridDimY = gridY;
+  config.gridDimZ = gridZ;
+  config.blockDimX = blockX;
+  config.blockDimY = blockY;
+  config.blockDimZ = blockZ;
+  config.sharedMemBytes = smem;
+  config.hStream = stream;
+  CUlaunchAttribute launchAttr[2];
+  launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
+  launchAttr[0].value.clusterDim.x = clusterX;
+  launchAttr[0].value.clusterDim.y = clusterY;
+  launchAttr[0].value.clusterDim.z = clusterZ;
+  launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
+  launchAttr[1].value.clusterSchedulingPolicyPreference =
+      CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
+  config.numAttrs = 2;
+  config.attrs = launchAttr;
+
+  debug_print("Launching kernel, "
+              "cluster: %ld, %ld, %ld, "
+              "grid: %ld, %ld, %ld, "
+              "threads: %ld, %ld, %ld, "
+              "smem: %d bytes\n",
+              clusterX, clusterY, clusterZ, gridX, gridY, gridZ, blockX, blockY,
+              blockZ, smem);
+
+  CUDA_REPORT_IF_ERROR(cuLaunchKernelEx(&config, function, params, extra));
+}
+
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUstream mgpuStreamCreate() {
   ScopedContext scopedContext;
   CUstream stream = nullptr;
diff --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
index 47fe6973778cd7f..2acccb7c2fafa47 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
@@ -136,6 +136,9 @@ class LaunchKernel {
   // Get the kernel launch callee.
   FunctionCallee getKernelLaunchFn();
 
+  // Get the cluster kernel launch callee.
+  FunctionCallee getClusterKernelLaunchFn();
+
   // Get the module function callee.
   FunctionCallee getModuleFunctionFn();
 
@@ -228,6 +231,17 @@ llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
           false));
 }
 
+llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
+  return module.getOrInsertFunction(
+      "mgpuLaunchClusterKernel",
+      FunctionType::get(
+          voidTy,
+          ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
+                            intPtrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
+                            i32Ty, ptrTy, ptrTy, ptrTy}),
+          false));
+}
+
 llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
   return module.getOrInsertFunction(
       "mgpuModuleGetFunction",
@@ -401,10 +415,22 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
 
   // Create the launch call.
   Value *nullPtr = ConstantPointerNull::get(ptrTy);
-  builder.CreateCall(
-      getKernelLaunchFn(),
-      ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by, bz,
-                         dynamicMemorySize, stream, argArray, nullPtr}));
+
+  // Launch kernel with clusters if cluster size is specified.
+  if (op.hasClusterSize()) {
+    mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
+    Value *cx = llvmValue(cluster.x), *cy = llvmValue(cluster.y),
+          *cz = llvmValue(cluster.z);
+    builder.CreateCall(
+        getClusterKernelLaunchFn(),
+        ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
+                           dynamicMemorySize, stream, argArray, nullPtr}));
+  } else {
+    builder.CreateCall(
+        getKernelLaunchFn(),
+        ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by, bz,
+                           dynamicMemorySize, stream, argArray, nullPtr}));
+  }
 
   // Sync & destroy the stream, for synchronous launches.
   if (handleStream) {
diff --git a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
index f5462b579b5eb0c..c0b05ef08603332 100644
--- a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
@@ -96,3 +96,41 @@ module attributes {gpu.container_module} {
     return
   }
 }
+
+
+// -----
+
+module attributes {gpu.container_module} {
+  // CHECK: gpu.module
+  gpu.module @kernel_module [#nvvm.target] {
+    llvm.func @kernel(%arg0: i32, %arg1: !llvm.ptr,
+        %arg2: !llvm.ptr, %arg3: i64, %arg4: i64,
+        %arg5: i64) attributes {gpu.kernel} {
+      llvm.return
+    }
+  }
+
+  func.func @foo(%buffer: memref<?xf32>) {
+  // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : index) : i64
+  // CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32
+  // CHECK: [[C256:%.*]] = llvm.mlir.constant(256 : i32) : i32
+  // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : i64
+    %c8 = arith.constant 8 : index
+    %c32 = arith.constant 32 : i32
+    %c256 = arith.constant 256 : i32
+    %c2 = arith.constant 2 : index
+
+  // CHECK: gpu.launch_func @kernel_module::@kernel
+  // CHECK: clusters in ([[C2]], [[C2]], [[C2]])
+  // CHECK: blocks in ([[C8]], [[C8]], [[C8]]) threads in ([[C8]], [[C8]], [[C8]]) : i64
+  // CHECK: dynamic_shared_memory_size [[C256]]
+  // CHECK: args([[C32]] : i32, %{{.*}} : !llvm.ptr, %{{.*}} : !llvm.ptr, %{{.*}} : i64, %{{.*}} : i64, %{{.*}} : i64)
+    gpu.launch_func @kernel_module::@kernel
+        clusters in (%c2, %c2, %c2)
+        blocks in (%c8, %c8, %c8)
+        threads in (%c8, %c8, %c8)
+        dynamic_shared_memory_size %c256
+        args(%c32 : i32, %buffer : memref<?xf32>)
+    return
+  }
+}
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 680e604151d77fd..2b2170c31050b86 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -57,7 +57,7 @@ module attributes {gpu.container_module} {
   func.func @launch_func_missing_callee_attribute(%sz : index) {
     // expected-error at +1 {{'gpu.launch_func' op requires attribute 'kernel'}}
     "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz)
-        {operandSegmentSizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 0, 0>}
+        {operandSegmentSizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0>}
         : (index, index, index, index, index, index) -> ()
     return
   }
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir b/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir
new file mode 100644
index 000000000000000..5beba48813480f5
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s \
+// RUN:  -test-lower-to-nvvm="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
+// RUN:  | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --shared-libs=%mlir_c_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN:  | FileCheck %s
+
+// CHECK: clusterIdx: (1, 1, 0) in Cluster Dimension: (2, 2, 1) blockIdx: (3, 3, 0) 
+
+module attributes {gpu.container_module} {
+  func.func @main() {
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c4 = arith.constant 4 : index
+    gpu.launch_func @gpumodule::@kernel_cluster clusters in (%c2, %c2, %c1) blocks in (%c4, %c4, %c1) threads in (%c1, %c1, %c1)
+    return
+  }
+  gpu.module @gpumodule {
+    gpu.func @kernel_cluster() kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 4, 4, 1>} {
+      %cidX = gpu.cluster_id x
+      %cidY = gpu.cluster_id y
+      %cidZ = gpu.cluster_id z
+      %cdimX = gpu.cluster_dim x
+      %cdimY = gpu.cluster_dim y
+      %cdimZ = gpu.cluster_dim z
+      %bidX = gpu.block_id x
+      %bidY = gpu.block_id y
+      %bidZ = gpu.block_id z
+      %cidX_i32 = index.casts %cidX : index to i32
+      %cidY_i32 = index.casts %cidY : index to i32
+      %cidZ_i32 = index.casts %cidZ : index to i32
+      %cdimX_i32 = index.casts %cdimX : index to i32
+      %cdimY_i32 = index.casts %cdimY : index to i32
+      %cdimZ_i32 = index.casts %cdimZ : index to i32
+      %bidX_i32 = index.casts %bidX : index to i32
+      %bidY_i32 = index.casts %bidY : index to i32
+      %bidZ_i32 = index.casts %bidZ : index to i32
+
+      %c3 = arith.constant 3 : index
+      %cnd1 = arith.cmpi eq, %bidX, %c3 : index
+      %cnd2 = arith.cmpi eq, %bidY, %c3 : index
+      scf.if %cnd1 {
+        scf.if %cnd2 {
+          gpu.printf "clusterIdx: (%d, %d, %d) in Cluster Dimension: (%d, %d, %d) blockIdx: (%d, %d, %d) \n"
+            %cidX_i32,
+            %cidY_i32,
+            %cidZ_i32,
+            %cdimX_i32,
+            %cdimY_i32,
+            %cdimZ_i32,
+            %bidX_i32,
+            %bidY_i32,
+            %bidZ_i32
+            :
+            i32, i32, i32, i32, i32, i32, i32, i32, i32
+        }
+      }
+
+      gpu.return
+    }
+  }
+}
+
diff --git a/mlir/test/Target/LLVMIR/gpu.mlir b/mlir/test/Target/LLVMIR/gpu.mlir
index fddbbee962c1aee..190b53bcf2084b4 100644
--- a/mlir/test/Target/LLVMIR/gpu.mlir
+++ b/mlir/test/Target/LLVMIR/gpu.mlir
@@ -75,3 +75,22 @@ module attributes {gpu.container_module} {
   llvm.func @mgpuStreamSynchronize(!llvm.ptr)
   llvm.func @mgpuStreamDestroy(!llvm.ptr)
 }
+
+// -----
+
+// Test cluster/block/thread syntax.
+module attributes {gpu.container_module} {
+  // CHECK: @kernel_module_bin_cst = internal constant [4 x i8] c"BLOB", align 8
+  gpu.binary @kernel_module  [#gpu.object<#nvvm.target, "BLOB">]
+  llvm.func @foo() {
+  // CHECK: [[S2:%.*]] = alloca ptr, i64 0, align 8
+  // CHECK: [[S3:%.*]] = call ptr @mgpuModuleLoad(ptr @kernel_module_bin_cst)
+  // CHECK: [[S4:%.*]] = call ptr @mgpuModuleGetFunction(ptr [[S3]], ptr @kernel_module_kernel_kernel_name)
+  // CHECK: [[S5:%.*]] = call ptr @mgpuStreamCreate()
+  // CHECK: call void @mgpuLaunchClusterKernel(ptr [[S4]], i64 2, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i32 0, ptr [[S5]], ptr [[S2]], ptr null)
+    %0 = llvm.mlir.constant(1 : index) : i64
+    %1 = llvm.mlir.constant(2 : index) : i64
+    gpu.launch_func @kernel_module::@kernel clusters in (%1, %0, %0) blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64
+    llvm.return
+  }
+}
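
The integration test kernel prints `gpu.cluster_id` next to the block index for the block at (3, 3, z). As a rough host-side model of that relationship, assuming clusters tile the grid contiguously along each dimension (as CUDA thread block clusters do when every grid dimension is a multiple of the cluster dimension), the cluster that owns a block and the block's position inside it can be computed per dimension. `Dim3`, `clusterIdOf` and `blockInClusterOf` are hypothetical names used only for illustration; this is not how the NVVM lowering obtains the values, which come from PTX special registers.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical host-side model of cluster indexing; these helpers are not
// part of MLIR. Assumes each grid dimension is a multiple of the matching
// cluster dimension, so clusters tile the grid contiguously.
using Dim3 = std::array<int64_t, 3>;

// Index of the cluster that contains the block `blockIdx`, per dimension.
Dim3 clusterIdOf(const Dim3 &blockIdx, const Dim3 &clusterSize) {
  Dim3 id{};
  for (int d = 0; d < 3; ++d) {
    assert(clusterSize[d] > 0 && "cluster dimensions must be positive");
    id[d] = blockIdx[d] / clusterSize[d];
  }
  return id;
}

// Position of the block `blockIdx` inside its own cluster, per dimension.
Dim3 blockInClusterOf(const Dim3 &blockIdx, const Dim3 &clusterSize) {
  Dim3 pos{};
  for (int d = 0; d < 3; ++d) {
    assert(clusterSize[d] > 0 && "cluster dimensions must be positive");
    pos[d] = blockIdx[d] % clusterSize[d];
  }
  return pos;
}
```

With, say, 2x2x1 clusters (the cluster shape used by the integration test is not visible in this excerpt), block (3, 3, 0) lands in cluster (1, 1, 0) at position (1, 1, 0) inside that cluster.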

>From 14dba1bf562b782d0a594abf2d80500ca01a8926 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 21 Nov 2023 09:36:14 +0100
Subject: [PATCH 2/4] remove deprecated lowering

---
 .../GPUCommon/GPUToLLVMConversion.cpp         | 20 -------------------
 1 file changed, 20 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 381d5100a7a3fdc..55936b12c8607b0 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -121,26 +121,6 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
           llvmPointerType, /* void **extra */
           llvmInt64Type    /* size_t paramsCount */
       }};
-  FunctionCallBuilder launchClusterKernelCallBuilder = {
-      "mgpuLaunchClusterKernel",
-      llvmVoidType,
-      {
-          llvmPointerType, /* void* f */
-          llvmIntPtrType,  /* intptr_t clusterXDim */
-          llvmIntPtrType,  /* intptr_t clusteryDim */
-          llvmIntPtrType,  /* intptr_t clusterZDim */
-          llvmIntPtrType,  /* intptr_t gridXDim */
-          llvmIntPtrType,  /* intptr_t gridyDim */
-          llvmIntPtrType,  /* intptr_t gridZDim */
-          llvmIntPtrType,  /* intptr_t blockXDim */
-          llvmIntPtrType,  /* intptr_t blockYDim */
-          llvmIntPtrType,  /* intptr_t blockZDim */
-          llvmInt32Type,   /* unsigned int sharedMemBytes */
-          llvmPointerType, /* void *hstream */
-          llvmPointerType, /* void **kernelParams */
-          llvmPointerType, /* void **extra */
-          llvmInt64Type    /* size_t paramsCount */
-      }};
   FunctionCallBuilder streamCreateCallBuilder = {
       "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
   FunctionCallBuilder streamDestroyCallBuilder = {
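
The removed builder still documents the argument order of `mgpuLaunchClusterKernel`: three cluster dimensions, followed by the grid and block dimensions. Whichever lowering path emits that call, the launch shape must be one the driver accepts. Below is a hedged sketch of a host-side sanity check for constant operands. `isPortableClusterLaunch` is hypothetical, and the two rules it encodes come from the CUDA Programming Guide (each grid dimension should be a multiple of the cluster dimension, and 8 blocks per cluster is the portable maximum), not from this PR. It also assumes `grid` holds the total number of blocks, i.e. the `blocks in` operands.

```cpp
#include <array>
#include <cstdint>

// Hypothetical helper; not part of MLIR or of the CUDA runtime wrappers.
// `grid` is assumed to hold the total grid size in blocks (the `blocks in`
// operands) and `cluster` the blocks per cluster (the `clusters in` operands).
bool isPortableClusterLaunch(const std::array<int64_t, 3> &grid,
                             const std::array<int64_t, 3> &cluster) {
  // Portable cluster-size limit per the CUDA Programming Guide; larger
  // clusters require an explicit non-portable opt-in.
  constexpr int64_t kPortableMaxClusterBlocks = 8;
  int64_t blocksPerCluster = 1;
  for (int d = 0; d < 3; ++d) {
    if (grid[d] <= 0 || cluster[d] <= 0)
      return false;
    // Each grid dimension must be a multiple of the cluster dimension.
    if (grid[d] % cluster[d] != 0)
      return false;
    blocksPerCluster *= cluster[d];
    // Bail out early so the running product stays small.
    if (blocksPerCluster > kPortableMaxClusterBlocks)
      return false;
  }
  return true;
}
```

For example, a (4, 4, 1) grid with (2, 2, 1) clusters passes, while a (3, 1, 1) grid with (2, 1, 1) clusters does not. The cluster (2, 1, 1) with grid (1, 1, 1) in the LLVM IR translation test would not satisfy the divisibility rule, which does not matter there because that test only checks translation.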

>From b649b31ed8138c63d9f464428afb777cac68f087 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 24 Nov 2023 09:01:55 +0100
Subject: [PATCH 3/4] fix @qcolombet comments

---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 10 +++++-----
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp     |  3 +++
 mlir/test/Dialect/GPU/ops.mlir             |  3 +++
 3 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index f093e4392520263..44ca7f924d23290 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -601,12 +601,11 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
     /// The name of the kernel.
     StringAttr getKernelName();
 
-    /// Has cluster
+    /// Returns true if cluster size is specified.
     bool hasClusterSize() {
-      auto totalSize = getOperands().size();
-      totalSize -= getKernelOperands().size();
-      totalSize -= getAsyncDependencies().size();
-      return totalSize > 7;
+      if (getClusterSizeX() && getClusterSizeY() && getClusterSizeZ())
+        return true;
+      return false;
     }
 
     /// The number of operands passed to the kernel function.
@@ -616,6 +615,7 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
     Value getKernelOperand(unsigned i);
 
     /// Get the SSA values passed as operands to specify the cluster size.
+    /// Asserts if the cluster sizes are not specified.
     KernelDim3 getClusterSizeOperandValues();
 
     /// Get the SSA values passed as operands to specify the grid size.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 156a993131ad379..7941219e07833eb 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -31,6 +31,7 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/StringSaver.h"
+#include <cassert>
 
 using namespace mlir;
 using namespace mlir::gpu;
@@ -1082,6 +1083,8 @@ KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
 }
 
 KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
+  assert(hasClusterSize() &&
+         "cluster size is not set, check hasClusterSize() first");
   auto operands = getOperands().drop_front(getAsyncDependencies().size());
   return KernelDim3{operands[6], operands[7], operands[8]};
 }
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index c638e0b21ab6f1f..481934364156376 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -152,6 +152,9 @@ module attributes {gpu.container_module} {
     // CHECK: gpu.launch_func @kernels::@kernel_1 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) args(%{{.*}} : f32, %{{.*}} : memref<?xf32, 1>)
     gpu.launch_func @kernels::@kernel_1 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) args(%0 : f32, %1 : memref<?xf32, 1>)
 
+    // CHECK: gpu.launch_func @kernels::@kernel_1 clusters in (%{{.*}}, %{{.*}}, %{{.*}}) blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) args(%{{.*}} : f32, %{{.*}} : memref<?xf32, 1>)
+    gpu.launch_func @kernels::@kernel_1 clusters in (%cst, %cst, %cst) blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) args(%0 : f32, %1 : memref<?xf32, 1>)
+
     gpu.launch_func @kernels::@kernel_1 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) dynamic_shared_memory_size %c0 args(%0 : f32, %1 : memref<?xf32, 1>)
 
     // CHECK: gpu.launch_func @kernels::@kernel_2 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}})
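
Patch 3 replaces the operand-counting heuristic in `hasClusterSize()` with an explicit check that all three optional cluster operands are present, and makes `getClusterSizeOperandValues()` assert when they are not. Counting is fragile because other optional operands, such as `dynamic_shared_memory_size`, also change the operand total. The same all-or-nothing contract can be sketched outside MLIR with `std::optional`. `ClusterOperands` and its members are hypothetical stand-ins for the generated accessors, not the real `LaunchFuncOp` API, and an `int` stands in for an SSA value.

```cpp
#include <array>
#include <cassert>
#include <optional>

// Hypothetical stand-in for the three optional cluster-size operands of
// gpu.launch_func; an int models an SSA value here.
struct ClusterOperands {
  std::optional<int> x, y, z;

  // Mirrors the patched hasClusterSize(): true only when all three are set.
  bool hasClusterSize() const {
    return x.has_value() && y.has_value() && z.has_value();
  }

  // Mirrors getClusterSizeOperandValues(): callers must check first.
  std::array<int, 3> clusterSizeValues() const {
    assert(hasClusterSize() &&
           "cluster size is not set, check hasClusterSize() first");
    return {*x, *y, *z};
  }
};
```

Callers such as the verifier are then expected to guard with `hasClusterSize()` before reading the values, which is what the final patch in the series switches `LaunchFuncOp::verify()` to do.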

>From 4b092f1bc8b8078a1c630dce6e6c0804d00ef551 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Mon, 27 Nov 2023 11:02:09 +0100
Subject: [PATCH 4/4] use hasClusterSize

---
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7941219e07833eb..6e46fcb0d74bd8e 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1100,7 +1100,7 @@ LogicalResult LaunchFuncOp::verify() {
                        GPUDialect::getContainerModuleAttrName() +
                        "' attribute");
 
-  if (getClusterSizeX()) {
+  if (hasClusterSize()) {
     if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
         getClusterSizeZ().getType() != getClusterSizeX().getType())
       return emitOpError()


