[Mlir-commits] [mlir] edf5cae - [mlir][gpu] Support Cluster of Thread Blocks in `gpu.launch_func` (#72871)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 27 02:05:12 PST 2023


Author: Guray Ozen
Date: 2023-11-27T11:05:07+01:00
New Revision: edf5cae7391cdb097a090ea142dfa7ac6ac03555

URL: https://github.com/llvm/llvm-project/commit/edf5cae7391cdb097a090ea142dfa7ac6ac03555
DIFF: https://github.com/llvm/llvm-project/commit/edf5cae7391cdb097a090ea142dfa7ac6ac03555.diff

LOG: [mlir][gpu] Support Cluster of Thread Blocks in `gpu.launch_func` (#72871)

NVIDIA Hopper architecture introduced the Cooperative Group Array (CGA).
It is a new level of parallelism, allowing clustering of Cooperative
Thread Arrays (CTA) to synchronize and communicate through shared memory
while running concurrently.

This PR enables support for CGA within the `gpu.launch_func` in the GPU
dialect. It extends `gpu.launch_func` to accommodate this functionality.

The GPU dialect remains architecture-agnostic, so we've added CGA
functionality as optional parameters. This lets us reuse mechanisms that
the GPU dialect already provides, such as outlining and kernel launching,
which makes it a practical and convenient choice.

An example of this implementation can be seen below:

```
gpu.launch_func @kernel_module::@kernel
                clusters in (%1, %0, %0) // <-- Optional
                blocks in (%0, %0, %0)
                threads in (%0, %0, %0)
```

The PR also introduces index and dimension ops specific to clusters,
and lowers them to the corresponding NVVM ops:

```
%cidX = gpu.cluster_id  x
%cidY = gpu.cluster_id  y
%cidZ = gpu.cluster_id  z

%cdimX = gpu.cluster_dim  x
%cdimY = gpu.cluster_dim  y
%cdimZ = gpu.cluster_dim  z
```

We will introduce cluster support in `gpu.launch` Op in an upcoming PR. 

See [the documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cluster-of-cooperative-thread-arrays)
provided by NVIDIA for details.

Added: 
    mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
    mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
    mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
    mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
    mlir/test/Dialect/GPU/invalid.mlir
    mlir/test/Dialect/GPU/ops.mlir
    mlir/test/Target/LLVMIR/gpu.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index e11c5c393648de7..826df0012fb8f0a 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -53,6 +53,32 @@ class GPU_IndexOp<string mnemonic, list<Trait> traits = []> :
   let assemblyFormat = "$dimension attr-dict";
 }
 
+def GPU_ClusterDimOp : GPU_IndexOp<"cluster_dim"> {
+  let description = [{
+    Returns the number of thread blocks in the cluster along
+    the x, y, or z `dimension`.
+
+    Example:
+
+    ```mlir
+    %cDimX = gpu.cluster_dim x
+    ```
+  }];
+}
+
+def GPU_ClusterIdOp : GPU_IndexOp<"cluster_id"> {
+  let description = [{
+    Returns the cluster id, i.e. the index of the current cluster within the 
+    grid along the x, y, or z `dimension`.
+
+    Example:
+
+    ```mlir
+    %cIdY = gpu.cluster_id y
+    ```
+  }];
+}
+
 def GPU_BlockDimOp : GPU_IndexOp<"block_dim"> {
   let description = [{
     Returns the number of threads in the thread block (aka the block size) along
@@ -467,8 +493,15 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
                      "blockSizeY", "blockSizeZ"]>]>,
     Arguments<(ins Variadic<GPU_AsyncToken>:$asyncDependencies,
                SymbolRefAttr:$kernel,
-               LaunchIndx:$gridSizeX, LaunchIndx:$gridSizeY, LaunchIndx:$gridSizeZ,
-               LaunchIndx:$blockSizeX, LaunchIndx:$blockSizeY, LaunchIndx:$blockSizeZ,
+               LaunchIndx:$gridSizeX, 
+               LaunchIndx:$gridSizeY, 
+               LaunchIndx:$gridSizeZ,
+               LaunchIndx:$blockSizeX, 
+               LaunchIndx:$blockSizeY, 
+               LaunchIndx:$blockSizeZ,
+               Optional<LaunchIndx>:$clusterSizeX,
+               Optional<LaunchIndx>:$clusterSizeY,
+               Optional<LaunchIndx>:$clusterSizeZ,
                Optional<I32>:$dynamicSharedMemorySize,
                Variadic<AnyType>:$kernelOperands,
                Optional<AnyType>:$asyncObject)>,
@@ -506,6 +539,12 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
     The remaining operands if present are passed as arguments to the kernel
     function.
 
+    The `gpu.launch_func` also supports kernel launching with clusters if 
+    supported by the target architecture. The cluster size can be set by 
+    `clusterSizeX`, `clusterSizeY`, and `clusterSizeZ` arguments. When these 
+    arguments are present, the Op launches a kernel that clusters the given 
+    thread blocks. This feature is exclusive to certain architectures.
+
     Example:
 
     ```mlir
@@ -535,6 +574,15 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
           %gDimY = gpu.grid_dim y
           %gDimZ = gpu.grid_dim z
 
+          // (Optional) Cluster id and cluster size, only on supported architectures
+          %cIdX = gpu.cluster_id x
+          %cIdY = gpu.cluster_id y
+          %cIdZ = gpu.cluster_id z
+
+          %cDimX = gpu.cluster_dim x
+          %cDimY = gpu.cluster_dim y
+          %cDimZ = gpu.cluster_dim z
+
           "some_op"(%bx, %tx) : (index, index) -> ()
           %42 = load %arg1[%bx] : memref<?xf32, 1>
         }
@@ -545,6 +593,7 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
           async                           // (Optional) Don't block host, return token.
           [%t0]                           // (Optional) Execute only after %t0 has completed.
           @kernels::@kernel_1             // Kernel function.
+          clusters in (%cst, %cst, %cst)  // (Optional) Cluster size, only on supported architectures.
           blocks in (%cst, %cst, %cst)    // Grid size.
           threads in (%cst, %cst, %cst)   // Block size.
           dynamic_shared_memory_size %s   // (Optional) Amount of dynamic shared
@@ -562,11 +611,13 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
       "KernelDim3":$blockSize, "Value":$dynamicSharedMemorySize,
       "ValueRange":$kernelOperands,
       CArg<"Type", "nullptr">:$asyncTokenType,
-      CArg<"ValueRange", "{}">:$asyncDependencies)>,
+      CArg<"ValueRange", "{}">:$asyncDependencies,
+      CArg<"std::optional<KernelDim3>", "std::nullopt">:$clusterSize)>,
     OpBuilder<(ins "SymbolRefAttr":$kernel, "KernelDim3":$gridSize,
       "KernelDim3":$blockSize, "Value":$dynamicSharedMemorySize,
       "ValueRange":$kernelOperands,
-      CArg<"Value", "nullptr">:$asyncObject)>
+      CArg<"Value", "nullptr">:$asyncObject,
+      CArg<"std::optional<KernelDim3>", "std::nullopt">:$clusterSize)>
   ];
 
   let extraClassDeclaration = [{
@@ -576,12 +627,23 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
     /// The name of the kernel.
     StringAttr getKernelName();
 
+    /// Returns true if cluster size is specified.
+    bool hasClusterSize() {
+      if (getClusterSizeX() && getClusterSizeY() && getClusterSizeZ())
+        return true;
+      return false;
+    }
+
     /// The number of operands passed to the kernel function.
     unsigned getNumKernelOperands();
 
     /// The i-th operand passed to the kernel function.
     Value getKernelOperand(unsigned i);
 
+    /// Get the SSA values passed as operands to specify the cluster size.
+    /// When the cluster sizes are not specified, it asserts.
+    KernelDim3 getClusterSizeOperandValues();
+
     /// Get the SSA values passed as operands to specify the grid size.
     KernelDim3 getGridSizeOperandValues();
 
@@ -597,10 +659,11 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
   let assemblyFormat = [{
       custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
       (`<` $asyncObject^ `:` type($asyncObject) `>`)?
-      $kernel
+      $kernel      
+      ( `clusters` `in` ` ` `(` $clusterSizeX^ `,` $clusterSizeY `,` $clusterSizeZ `)` )?
       `blocks` `in` ` ` `(` $gridSizeX `,` $gridSizeY `,` $gridSizeZ `)`
       `threads` `in` ` ` `(` $blockSizeX `,` $blockSizeY `,` $blockSizeZ `)`
-      custom<LaunchDimType>(type($gridSizeX))
+      custom<LaunchDimType>(type($gridSizeX), ref($clusterSizeX), type($clusterSizeX), type($clusterSizeY), type($clusterSizeZ))
       (`dynamic_shared_memory_size` $dynamicSharedMemorySize^)?
       custom<LaunchFuncOperands>($kernelOperands, type($kernelOperands)) attr-dict
   }];

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 3dd8aae81c5933f..2da97c20e9c984e 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -1128,13 +1128,19 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
         loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(),
         rewriter, /*useBarePtrCallConv=*/kernelBarePtrCallConv);
 
+    std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
+    if (launchOp.hasClusterSize()) {
+      clusterSize =
+          gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
+                          adaptor.getClusterSizeZ()};
+    }
     rewriter.create<gpu::LaunchFuncOp>(
         launchOp.getLoc(), launchOp.getKernelAttr(),
         gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
                         adaptor.getGridSizeZ()},
         gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
                         adaptor.getBlockSizeZ()},
-        adaptor.getDynamicSharedMemorySize(), arguments, stream);
+        adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
     if (launchOp.getAsyncToken())
       rewriter.replaceOp(launchOp, {stream});
     else

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 86a77f557cb9579..9456784c406aebb 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -313,17 +313,20 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns) {
   populateWithGenerated(patterns);
   patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
-  patterns
-      .add<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
-                                       NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
-           GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
-                                       NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
-           GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
-                                       NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
-           GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
-                                       NVVM::GridDimYOp, NVVM::GridDimZOp>,
-           GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
-          converter);
+  patterns.add<
+      GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
+                                  NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
+      GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
+                                  NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
+      GPUIndexIntrinsicOpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
+                                  NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>,
+      GPUIndexIntrinsicOpLowering<gpu::ClusterDimOp, NVVM::ClusterDimXOp,
+                                  NVVM::ClusterDimYOp, NVVM::ClusterDimZOp>,
+      GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
+                                  NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
+      GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
+                                  NVVM::GridDimYOp, NVVM::GridDimZOp>,
+      GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(converter);
 
   patterns.add<GPUDynamicSharedMemoryOpLowering>(
       converter, NVVM::kSharedMemoryAlignmentBit);

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 9517c053c8360ef..1b6db1fb0c79f7c 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -32,6 +32,7 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/StringSaver.h"
+#include <cassert>
 
 using namespace mlir;
 using namespace mlir::gpu;
@@ -985,7 +986,8 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                          GPUFuncOp kernelFunc, KernelDim3 gridSize,
                          KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                          ValueRange kernelOperands, Type asyncTokenType,
-                         ValueRange asyncDependencies) {
+                         ValueRange asyncDependencies,
+                         std::optional<KernelDim3> clusterSize) {
   result.addOperands(asyncDependencies);
   if (asyncTokenType)
     result.types.push_back(builder.getType<AsyncTokenType>());
@@ -993,6 +995,8 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
   // Add grid and block sizes as op operands, followed by the data operands.
   result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                       getBlockSize.y, getBlockSize.z});
+  if (clusterSize.has_value())
+    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
   if (dynamicSharedMemorySize)
     result.addOperands(dynamicSharedMemorySize);
   result.addOperands(kernelOperands);
@@ -1008,6 +1012,11 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
   for (auto &sz : prop.operandSegmentSizes)
     sz = 1;
   prop.operandSegmentSizes[0] = asyncDependencies.size();
+  if (!clusterSize.has_value()) {
+    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
+    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
+    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
+  }
   prop.operandSegmentSizes[segmentSizesLen - 3] =
       dynamicSharedMemorySize ? 1 : 0;
   prop.operandSegmentSizes[segmentSizesLen - 2] =
@@ -1018,10 +1027,13 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
 void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                          SymbolRefAttr kernel, KernelDim3 gridSize,
                          KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
-                         ValueRange kernelOperands, Value asyncObject) {
+                         ValueRange kernelOperands, Value asyncObject,
+                         std::optional<KernelDim3> clusterSize) {
   // Add grid and block sizes as op operands, followed by the data operands.
   result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                       getBlockSize.y, getBlockSize.z});
+  if (clusterSize.has_value())
+    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
   if (dynamicSharedMemorySize)
     result.addOperands(dynamicSharedMemorySize);
   result.addOperands(kernelOperands);
@@ -1034,6 +1046,11 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
   for (auto &sz : prop.operandSegmentSizes)
     sz = 1;
   prop.operandSegmentSizes[0] = 0;
+  if (!clusterSize.has_value()) {
+    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
+    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
+    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
+  }
   prop.operandSegmentSizes[segmentSizesLen - 3] =
       dynamicSharedMemorySize ? 1 : 0;
   prop.operandSegmentSizes[segmentSizesLen - 2] =
@@ -1067,6 +1084,13 @@ KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
   return KernelDim3{operands[3], operands[4], operands[5]};
 }
 
+KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
+  assert(hasClusterSize() &&
+         "cluster size is not set, check hasClusterSize() first");
+  auto operands = getOperands().drop_front(getAsyncDependencies().size());
+  return KernelDim3{operands[6], operands[7], operands[8]};
+}
+
 LogicalResult LaunchFuncOp::verify() {
   auto module = (*this)->getParentOfType<ModuleOp>();
   if (!module)
@@ -1078,21 +1102,35 @@ LogicalResult LaunchFuncOp::verify() {
                        GPUDialect::getContainerModuleAttrName() +
                        "' attribute");
 
+  if (hasClusterSize()) {
+    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
+        getClusterSizeZ().getType() != getClusterSizeX().getType())
+      return emitOpError()
+             << "expects types of the cluster dimensions must be the same";
+  }
+
   return success();
 }
 
-static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy) {
+static ParseResult
+parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
+                   std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
+                   Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
   if (succeeded(parser.parseOptionalColon())) {
     if (parser.parseType(dimTy))
       return failure();
   } else {
     dimTy = IndexType::get(parser.getContext());
   }
+  if (clusterValue.has_value()) {
+    clusterXTy = clusterYTy = clusterZTy = dimTy;
+  }
   return success();
 }
 
-static void printLaunchDimType(OpAsmPrinter &printer, Operation *op,
-                               Type dimTy) {
+static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
+                               Value clusterValue, Type clusterXTy,
+                               Type clusterYTy, Type clusterZTy) {
   if (!dimTy.isIndex())
     printer << ": " << dimTy;
 }

diff  --git a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
index cb2d66d5b0d32da..69017efb9a0e67c 100644
--- a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
@@ -19,6 +19,8 @@ using namespace mlir::gpu;
 
 // Maximum grid and block dimensions of all known GPUs are less than 2^32.
 static constexpr uint64_t kMaxDim = std::numeric_limits<uint32_t>::max();
+// Maximum cluster size
+static constexpr uint64_t kMaxClusterDim = 8;
 // Maximum subgroups are no larger than 128.
 static constexpr uint64_t kMaxSubgroupSize = 128;
 
@@ -82,6 +84,17 @@ static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
   return std::nullopt;
 }
 
+void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+                                     SetIntRangeFn setResultRange) {
+  setResultRange(getResult(), getIndexRange(1, kMaxClusterDim));
+}
+
+void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+                                    SetIntRangeFn setResultRange) {
+  uint64_t max = kMaxClusterDim;
+  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
+}
+
 void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                    SetIntRangeFn setResultRange) {
   std::optional<uint64_t> knownVal =

diff  --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index a8e743c519135f7..9b63d2a22a7a31f 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -194,6 +194,60 @@ mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY,
                                       extra));
 }
 
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuLaunchClusterKernel(
+    CUfunction function, intptr_t clusterX, intptr_t clusterY,
+    intptr_t clusterZ, intptr_t gridX, intptr_t gridY, intptr_t gridZ,
+    intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem,
+    CUstream stream, void **params, void **extra, size_t /*paramsCount*/) {
+  ScopedContext scopedContext;
+  if (smem > 0) {
+    // Avoid checking driver as it's more expensive than if statement
+    int32_t maxShmem = 0;
+    CUdevice device = getDefaultCuDevice();
+    CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
+    CUDA_REPORT_IF_ERROR(cuDeviceGetAttribute(
+        &maxShmem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
+        device));
+    if (maxShmem < smem) {
+      fprintf(stderr,
+              "Requested shared memory (%dkb) is larger than maximum allowed "
+              "shared memory (%dkb) for this device\n",
+              smem, maxShmem);
+    }
+    CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
+        function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem));
+  }
+  CUlaunchConfig config;
+  config.gridDimX = gridX;
+  config.gridDimY = gridY;
+  config.gridDimZ = gridZ;
+  config.blockDimX = blockX;
+  config.blockDimY = blockY;
+  config.blockDimZ = blockZ;
+  config.sharedMemBytes = smem;
+  config.hStream = stream;
+  CUlaunchAttribute launchAttr[2];
+  launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
+  launchAttr[0].value.clusterDim.x = clusterX;
+  launchAttr[0].value.clusterDim.y = clusterY;
+  launchAttr[0].value.clusterDim.z = clusterZ;
+  launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
+  launchAttr[1].value.clusterSchedulingPolicyPreference =
+      CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
+  config.numAttrs = 2;
+  config.attrs = launchAttr;
+
+  debug_print("Launching kernel,"
+              "cluster: %ld, %ld, %ld, "
+              "grid=%ld,%ld,%ld, "
+              "threads: %ld, %ld, %ld, "
+              "smem: %dkb\n",
+              clusterX, clusterY, clusterZ, gridX, gridY, gridZ, blockX, blockY,
+              blockZ, smem);
+
+  CUDA_REPORT_IF_ERROR(cuLaunchKernelEx(&config, function, params, extra));
+}
+
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUstream mgpuStreamCreate() {
   ScopedContext scopedContext;
   CUstream stream = nullptr;

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
index 47fe6973778cd7f..2acccb7c2fafa47 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
@@ -136,6 +136,9 @@ class LaunchKernel {
   // Get the kernel launch callee.
   FunctionCallee getKernelLaunchFn();
 
+  // Get the cluster kernel launch callee.
+  FunctionCallee getClusterKernelLaunchFn();
+
   // Get the module function callee.
   FunctionCallee getModuleFunctionFn();
 
@@ -228,6 +231,17 @@ llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
           false));
 }
 
+llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
+  return module.getOrInsertFunction(
+      "mgpuLaunchClusterKernel",
+      FunctionType::get(
+          voidTy,
+          ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
+                            intPtrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
+                            i32Ty, ptrTy, ptrTy, ptrTy}),
+          false));
+}
+
 llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
   return module.getOrInsertFunction(
       "mgpuModuleGetFunction",
@@ -401,10 +415,22 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
 
   // Create the launch call.
   Value *nullPtr = ConstantPointerNull::get(ptrTy);
-  builder.CreateCall(
-      getKernelLaunchFn(),
-      ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by, bz,
-                         dynamicMemorySize, stream, argArray, nullPtr}));
+
+  // Launch kernel with clusters if cluster size is specified.
+  if (op.hasClusterSize()) {
+    mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
+    Value *cx = llvmValue(cluster.x), *cy = llvmValue(cluster.y),
+          *cz = llvmValue(cluster.z);
+    builder.CreateCall(
+        getClusterKernelLaunchFn(),
+        ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
+                           dynamicMemorySize, stream, argArray, nullPtr}));
+  } else {
+    builder.CreateCall(
+        getKernelLaunchFn(),
+        ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by, bz,
+                           dynamicMemorySize, stream, argArray, nullPtr}));
+  }
 
   // Sync & destroy the stream, for synchronous launches.
   if (handleStream) {

diff  --git a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
index f5462b579b5eb0c..c0b05ef08603332 100644
--- a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
@@ -96,3 +96,41 @@ module attributes {gpu.container_module} {
     return
   }
 }
+
+
+// -----
+
+module attributes {gpu.container_module} {
+  // CHECK: gpu.module
+  gpu.module @kernel_module [#nvvm.target] {
+    llvm.func @kernel(%arg0: i32, %arg1: !llvm.ptr,
+        %arg2: !llvm.ptr, %arg3: i64, %arg4: i64,
+        %arg5: i64) attributes {gpu.kernel} {
+      llvm.return
+    }
+  }
+
+  func.func @foo(%buffer: memref<?xf32>) {
+  // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : index) : i64
+  // CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32
+  // CHECK: [[C256:%.*]] = llvm.mlir.constant(256 : i32) : i32
+  // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : i64
+    %c8 = arith.constant 8 : index    
+    %c32 = arith.constant 32 : i32
+    %c256 = arith.constant 256 : i32
+    %c2 = arith.constant 2 : index
+
+  // CHECK: gpu.launch_func @kernel_module::@kernel
+  // CHECK: clusters in ([[C2]], [[C2]], [[C2]])
+  // CHECK: blocks in ([[C8]], [[C8]], [[C8]]) threads in ([[C8]], [[C8]], [[C8]]) : i64
+  // CHECK: dynamic_shared_memory_size [[C256]]
+  // CHECK: args([[C32]] : i32, %{{.*}} : !llvm.ptr, %{{.*}} : !llvm.ptr, %{{.*}} : i64, %{{.*}} : i64, %{{.*}} : i64)
+    gpu.launch_func @kernel_module::@kernel
+        clusters in (%c2, %c2, %c2)
+        blocks in (%c8, %c8, %c8)
+        threads in (%c8, %c8, %c8)
+        dynamic_shared_memory_size %c256
+        args(%c32 : i32, %buffer : memref<?xf32>)
+    return
+  }
+}

diff  --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index df9921ef14d3b51..3a2197ad4d5a172 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -57,7 +57,7 @@ module attributes {gpu.container_module} {
   func.func @launch_func_missing_callee_attribute(%sz : index) {
     // expected-error at +1 {{'gpu.launch_func' op requires attribute 'kernel'}}
     "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz)
-        {operandSegmentSizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 0, 0>}
+        {operandSegmentSizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0>}
         : (index, index, index, index, index, index) -> ()
     return
   }

diff  --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index c638e0b21ab6f1f..481934364156376 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -152,6 +152,9 @@ module attributes {gpu.container_module} {
     // CHECK: gpu.launch_func @kernels::@kernel_1 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) args(%{{.*}} : f32, %{{.*}} : memref<?xf32, 1>)
     gpu.launch_func @kernels::@kernel_1 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) args(%0 : f32, %1 : memref<?xf32, 1>)
 
+    // CHECK: gpu.launch_func @kernels::@kernel_1 clusters in (%{{.*}}, %{{.*}}, %{{.*}}) blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) args(%{{.*}} : f32, %{{.*}} : memref<?xf32, 1>)
+    gpu.launch_func @kernels::@kernel_1 clusters in (%cst, %cst, %cst) blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) args(%0 : f32, %1 : memref<?xf32, 1>)
+
     gpu.launch_func @kernels::@kernel_1 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) dynamic_shared_memory_size %c0 args(%0 : f32, %1 : memref<?xf32, 1>)
 
     // CHECK: gpu.launch_func @kernels::@kernel_2 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}})

diff  --git a/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir b/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir
new file mode 100644
index 000000000000000..5beba48813480f5
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s \
+// RUN:  -test-lower-to-nvvm="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
+// RUN:  | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --shared-libs=%mlir_c_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN:  | FileCheck %s
+
+// CHECK: clusterIdx: (1, 1, 0) in Cluster Dimension: (2, 2, 1) blockIdx: (3, 3, 0) 
+
+module attributes {gpu.container_module} {
+  func.func @main() {
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c4 = arith.constant 4 : index    
+    gpu.launch_func  @gpumodule::@kernel_cluster clusters in(%c2,%c2,%c1)  blocks in (%c4, %c4, %c1) threads in (%c1, %c1, %c1)  
+    return
+  }
+  gpu.module @gpumodule {
+    gpu.func @kernel_cluster() kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 2, 2, 1>} {
+      %cidX = gpu.cluster_id  x
+      %cidY = gpu.cluster_id  y
+      %cidZ = gpu.cluster_id  z
+      %cdimX = gpu.cluster_dim  x
+      %cdimY = gpu.cluster_dim  y
+      %cdimZ = gpu.cluster_dim  z
+      %bidX = gpu.block_id  x
+      %bidY = gpu.block_id  y
+      %bidZ = gpu.block_id  z
+      %cidX_i32 = index.casts %cidX : index to i32
+      %cidY_i32 = index.casts %cidY : index to i32
+      %cidZ_i32 = index.casts %cidZ : index to i32
+      %cdimX_i32 = index.casts %cdimX : index to i32
+      %cdimY_i32 = index.casts %cdimY : index to i32
+      %cdimZ_i32 = index.casts %cdimZ : index to i32
+      %bidX_i32 = index.casts %bidX : index to i32
+      %bidY_i32 = index.casts %bidY : index to i32
+      %bidZ_i32 = index.casts %bidZ : index to i32
+
+      %c3 = arith.constant 3 : index
+      %cnd1 =  arith.cmpi eq, %bidX, %c3 : index
+      %cnd2 =  arith.cmpi eq, %bidY, %c3 : index
+      scf.if %cnd1 {
+        scf.if %cnd2 {
+          gpu.printf "clusterIdx: (%d, %d, %d) in Cluster Dimension: (%d, %d, %d) blockIdx: (%d, %d, %d) \n" 
+            %cidX_i32,
+            %cidY_i32,
+            %cidZ_i32,
+            %cdimX_i32,
+            %cdimY_i32,
+            %cdimZ_i32,
+            %bidX_i32,
+            %bidY_i32,
+            %bidZ_i32      
+            : 
+            i32, i32, i32, i32, i32, i32, i32, i32, i32
+        }
+      }
+      
+      gpu.return
+    }
+  }
+}
+

diff  --git a/mlir/test/Target/LLVMIR/gpu.mlir b/mlir/test/Target/LLVMIR/gpu.mlir
index fddbbee962c1aee..190b53bcf2084b4 100644
--- a/mlir/test/Target/LLVMIR/gpu.mlir
+++ b/mlir/test/Target/LLVMIR/gpu.mlir
@@ -75,3 +75,22 @@ module attributes {gpu.container_module} {
   llvm.func @mgpuStreamSynchronize(!llvm.ptr)
   llvm.func @mgpuStreamDestroy(!llvm.ptr)
 }
+
+// -----
+
+// Test cluster/block/thread syntax.
+module attributes {gpu.container_module} {
+  // CHECK: @kernel_module_bin_cst = internal constant [4 x i8] c"BLOB", align 8
+  gpu.binary @kernel_module  [#gpu.object<#nvvm.target, "BLOB">]
+  llvm.func @foo() {
+  // CHECK: [[S2:%.*]] = alloca ptr, i64 0, align 8
+  // CHECK: [[S3:%.*]] = call ptr @mgpuModuleLoad(ptr @kernel_module_bin_cst)
+  // CHECK: [[S4:%.*]] = call ptr @mgpuModuleGetFunction(ptr [[S3]], ptr @kernel_module_kernel_kernel_name)
+  // CHECK: [[S5:%.*]] = call ptr @mgpuStreamCreate()
+  // CHECK: call void @mgpuLaunchClusterKernel(ptr [[S4]], i64 2, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i32 0, ptr [[S5]], ptr [[S2]], ptr null)
+    %0 = llvm.mlir.constant(1 : index) : i64
+    %1 = llvm.mlir.constant(2 : index) : i64
+    gpu.launch_func @kernel_module::@kernel clusters in (%1, %0, %0) blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64
+    llvm.return
+  }
+}


        


More information about the Mlir-commits mailing list