[Mlir-commits] [mlir] [MLIR][GPU] Add support for non-portable cluster size attribute (PR #95545)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 14 06:33:14 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Pradeep Kumar (schwarzschild-radius)

<details>
<summary>Changes</summary>

This commit adds support for the `nonPortableClusterSize` attribute on the `gpu.launch` and `gpu.launch_func` ops. This is required to enable cluster sizes greater than 8 on Hopper GPUs. This commit also relaxes the constraint on the `gpu.cluster_dim_blocks` op to allow larger cluster sizes. Added a test case in `non-portable-cluster-launch.mlir`.

---

Patch is 20.01 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95545.diff


9 Files Affected:

- (modified) mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (+26-7) 
- (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+2-1) 
- (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+11-3) 
- (modified) mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp (+2-1) 
- (modified) mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp (+2-1) 
- (modified) mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp (+12-5) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp (+9-2) 
- (added) mlir/test/Integration/GPU/CUDA/sm90/non-portable-cluster-launch.mlir (+124) 
- (modified) mlir/test/Target/LLVMIR/gpu.mlir (+1-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 9c5f7ecd8cbe8..1d97f2c03b2cb 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -544,7 +544,8 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
                Optional<LaunchIndx>:$clusterSizeZ,
                Optional<I32>:$dynamicSharedMemorySize,
                Variadic<AnyType>:$kernelOperands,
-               Optional<AnyType>:$asyncObject)>,
+               Optional<AnyType>:$asyncObject,
+               OptionalAttr<BoolAttr>:$nonPortableClusterSize)>,
     Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
   let summary = "Launches a function as a GPU kernel";
 
@@ -585,6 +586,10 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
     arguments are present, the Op launches a kernel that clusters the given
     thread blocks. This feature is exclusive to certain architectures.
 
+    The `gpu.launch_func` op also supports the following optional runtime attributes:
+    - nonPortableClusterSize - launch the kernel with a non-portable cluster
+    size (only supported on certain architectures)
+
     Example:
 
     ```mlir
@@ -640,7 +645,7 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
                                           // memory to allocate for a workgroup.
           args(%arg0 : f32,               // (Optional) Kernel arguments.
                %arg1 : memref<?xf32, 1>)
-    }
+    } { nonPortableClusterSize = true }   // Attributes
     ```
   }];
 
@@ -652,12 +657,14 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
       "ValueRange":$kernelOperands,
       CArg<"Type", "nullptr">:$asyncTokenType,
       CArg<"ValueRange", "{}">:$asyncDependencies,
-      CArg<"std::optional<KernelDim3>", "std::nullopt">:$clusterSize)>,
+      CArg<"std::optional<KernelDim3>", "std::nullopt">:$clusterSize,
+      CArg<"BoolAttr", "{}">:$nonPortableClusterSize)>,
     OpBuilder<(ins "SymbolRefAttr":$kernel, "KernelDim3":$gridSize,
       "KernelDim3":$blockSize, "Value":$dynamicSharedMemorySize,
       "ValueRange":$kernelOperands,
       CArg<"Value", "nullptr">:$asyncObject,
-      CArg<"std::optional<KernelDim3>", "std::nullopt">:$clusterSize)>
+      CArg<"std::optional<KernelDim3>", "std::nullopt">:$clusterSize,
+      CArg<"BoolAttr", "{}">:$nonPortableClusterSize)>
   ];
 
   let extraClassDeclaration = [{
@@ -720,7 +727,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [
                Optional<Index>:$clusterSizeX, 
                Optional<Index>:$clusterSizeY, 
                Optional<Index>:$clusterSizeZ,
-               Optional<I32>:$dynamicSharedMemorySize)>,
+               Optional<I32>:$dynamicSharedMemorySize,
+               OptionalAttr<BoolAttr>:$nonPortableClusterSize)>,
     Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
   let summary = "GPU kernel launch operation";
 
@@ -815,10 +823,20 @@ def GPU_LaunchOp : GPU_Op<"launch", [
                blocks(%bx, %by, %bz) in (%sz_bx = %3, %sz_by = %4, %sz_bz = %5)
                threads(%tx, %ty, %tz) in (%sz_tx = %6, %sz_ty = %7, %sz_tz = %8)
     {
-      // Cluster, block and thread identifiers, as well as cluster/block/grid 
+      // Cluster, block and thread identifiers, as well as cluster/block/grid
       // sizes are immediately usable inside body region.
       "some_op"(%cx, %bx, %tx) : (index, index, index) -> ()
     }
+
+    // Launch with non-portable cluster size attribute.
+    gpu.launch clusters(%cx, %cy, %cz) in (%sz_cx = %0, %sz_cy = %1, %sz_cz = %2)
+               blocks(%bx, %by, %bz) in (%sz_bx = %3, %sz_by = %4, %sz_bz = %5)
+               threads(%tx, %ty, %tz) in (%sz_tx = %6, %sz_ty = %7, %sz_tz = %8)
+    {
+      // Cluster, block and thread identifiers, as well as cluster/block/grid
+      // sizes are immediately usable inside body region.
+      "some_op"(%cx, %bx, %tx) : (index, index, index) -> ()
+    } { nonPortableClusterSize = true }
     ```
 
     Rationale: using operation/block arguments gives analyses a clear way of
@@ -843,7 +861,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [
       CArg<"TypeRange", "{}">:$privateAttributions,
       CArg<"Value", "nullptr">:$clusterSizeX,
       CArg<"Value", "nullptr">:$clusterSizeY,
-      CArg<"Value", "nullptr">:$clusterSizeZ)>
+      CArg<"Value", "nullptr">:$clusterSizeZ,
+      CArg<"BoolAttr", "{}">:$nonPortableClusterSize)>
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 92b28ff9c5873..0e2fe70a9706d 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -971,7 +971,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
                       adaptor.getGridSizeZ()},
       gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
                       adaptor.getBlockSizeZ()},
-      adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
+      adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize,
+      adaptor.getNonPortableClusterSizeAttr());
   if (launchOp.getAsyncToken())
     rewriter.replaceOp(launchOp, {stream});
   else
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index d8e29da6512d4..bd2bf97343e77 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -649,7 +649,8 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
                      Type asyncTokenType, ValueRange asyncDependencies,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions, Value clusterSizeX,
-                     Value clusterSizeY, Value clusterSizeZ) {
+                     Value clusterSizeY, Value clusterSizeZ,
+                     BoolAttr nonPortableClusterSize) {
   OpBuilder::InsertionGuard g(builder);
 
   // Add a WorkGroup attribution attribute. This attribute is required to
@@ -674,6 +675,9 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
   if (dynamicSharedMemorySize)
     result.addOperands(dynamicSharedMemorySize);
 
+  Properties &prop = result.getOrAddProperties<Properties>();
+  prop.nonPortableClusterSize = nonPortableClusterSize;
+
   // Create a kernel body region with kNumConfigRegionAttributes + N memory
   // attributions, where the first kNumConfigRegionAttributes arguments have
   // `index` type and the rest have the same types as the data operands.
@@ -1085,7 +1089,8 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                          KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                          ValueRange kernelOperands, Type asyncTokenType,
                          ValueRange asyncDependencies,
-                         std::optional<KernelDim3> clusterSize) {
+                         std::optional<KernelDim3> clusterSize,
+                         BoolAttr nonPortableClusterSize) {
   result.addOperands(asyncDependencies);
   if (asyncTokenType)
     result.types.push_back(builder.getType<AsyncTokenType>());
@@ -1105,6 +1110,7 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
 
   Properties &prop = result.getOrAddProperties<Properties>();
   prop.kernel = kernelSymbol;
+  prop.nonPortableClusterSize = nonPortableClusterSize;
   size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
   // Initialize the segment sizes to 1.
   for (auto &sz : prop.operandSegmentSizes)
@@ -1126,7 +1132,8 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                          SymbolRefAttr kernel, KernelDim3 gridSize,
                          KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                          ValueRange kernelOperands, Value asyncObject,
-                         std::optional<KernelDim3> clusterSize) {
+                         std::optional<KernelDim3> clusterSize,
+                         BoolAttr nonPortableClusterSize) {
   // Add grid and block sizes as op operands, followed by the data operands.
   result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                       getBlockSize.y, getBlockSize.z});
@@ -1139,6 +1146,7 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
     result.addOperands(asyncObject);
   Properties &prop = result.getOrAddProperties<Properties>();
   prop.kernel = kernel;
+  prop.nonPortableClusterSize = nonPortableClusterSize;
   size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
   // Initialize the segment sizes to 1.
   for (auto &sz : prop.operandSegmentSizes)
diff --git a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
index 46b85db8b5431..1b09847e0eea6 100644
--- a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
@@ -92,7 +92,8 @@ void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
 
 void ClusterDimBlocksOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                            SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), getIndexRange(1, kMaxClusterDim));
+  uint64_t max = APInt::getMaxValue(64).getZExtValue();
+  setResultRange(getResult(), getIndexRange(1, max));
 }
 
 void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index f5e80553ae72a..dd33f0e842e5a 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -295,7 +295,8 @@ static void convertToLaunchFuncOp(gpu::LaunchOp launchOp,
       launchOp.getBlockSizeOperandValues(),
       launchOp.getDynamicSharedMemorySize(), operands,
       asyncToken ? asyncToken.getType() : nullptr,
-      launchOp.getAsyncDependencies(), clusterSize);
+      launchOp.getAsyncDependencies(), clusterSize,
+      launchOp.getNonPortableClusterSizeAttr());
   launchOp.replaceAllUsesWith(launchFunc);
   launchOp.erase();
 }
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index 09dc30365e37c..81dc7b983494a 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -335,11 +335,13 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) {
 
 #if (CUDA_VERSION >= 12000)
 
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuLaunchClusterKernel(
-    CUfunction function, intptr_t clusterX, intptr_t clusterY,
-    intptr_t clusterZ, intptr_t gridX, intptr_t gridY, intptr_t gridZ,
-    intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem,
-    CUstream stream, void **params, void **extra, size_t /*paramsCount*/) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuLaunchClusterKernel(CUfunction function, intptr_t clusterX,
+                        intptr_t clusterY, intptr_t clusterZ, intptr_t gridX,
+                        intptr_t gridY, intptr_t gridZ, intptr_t blockX,
+                        intptr_t blockY, intptr_t blockZ, int32_t smem,
+                        bool nonPortableClusterSize, CUstream stream,
+                        void **params, void **extra, size_t /*paramsCount*/) {
   ScopedContext scopedContext;
   if (smem > 0) {
     // Avoid checking driver as it's more expensive than if statement
@@ -358,6 +360,11 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuLaunchClusterKernel(
     CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
         function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem));
   }
+
+  if (nonPortableClusterSize)
+    CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
+        function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
+
   CUlaunchConfig config;
   config.gridDimX = gridX;
   config.gridDimY = gridY;
diff --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
index 0eb33287d608b..bea9f782888de 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
@@ -174,6 +174,7 @@ class LaunchKernel {
   Module &module;
   IRBuilderBase &builder;
   mlir::LLVM::ModuleTranslation &moduleTranslation;
+  Type *i1Ty{};
   Type *i32Ty{};
   Type *i64Ty{};
   Type *voidTy{};
@@ -216,6 +217,7 @@ llvm::LaunchKernel::LaunchKernel(
     Module &module, IRBuilderBase &builder,
     mlir::LLVM::ModuleTranslation &moduleTranslation)
     : module(module), builder(builder), moduleTranslation(moduleTranslation) {
+  i1Ty = builder.getInt1Ty();
   i32Ty = builder.getInt32Ty();
   i64Ty = builder.getInt64Ty();
   ptrTy = builder.getPtrTy(0);
@@ -240,7 +242,7 @@ llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
           voidTy,
           ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
                             intPtrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
-                            i32Ty, ptrTy, ptrTy, ptrTy}),
+                            i32Ty, i1Ty, ptrTy, ptrTy, ptrTy}),
           false));
 }
 
@@ -371,6 +373,10 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
   else
     dynamicMemorySize = ConstantInt::get(i32Ty, 0);
 
+  Value *nonPortableClusterSize = op.getNonPortableClusterSize()
+                                      ? ConstantInt::get(i1Ty, 1)
+                                      : ConstantInt::get(i1Ty, 0);
+
   // Create the argument array.
   Value *argArray = createKernelArgArray(op);
 
@@ -443,7 +449,8 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
     builder.CreateCall(
         getClusterKernelLaunchFn(),
         ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
-                           dynamicMemorySize, stream, argArray, nullPtr}));
+                           dynamicMemorySize, nonPortableClusterSize, stream,
+                           argArray, nullPtr}));
   } else {
     builder.CreateCall(getKernelLaunchFn(),
                        ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by,
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/non-portable-cluster-launch.mlir b/mlir/test/Integration/GPU/CUDA/sm90/non-portable-cluster-launch.mlir
new file mode 100644
index 0000000000000..ed72116563d6e
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/non-portable-cluster-launch.mlir
@@ -0,0 +1,124 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
+// RUN:  | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --shared-libs=%mlir_c_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+// CHECK: clusterIdx: (3, 3, 0) in Cluster Dimension: (4, 4, 1) blockIdx: (15, 15, 0)
+// CHECK: clusterIdx: (3, 3, 0) in Cluster Dimension: (4, 4, 1) blockIdx: (15, 15, 0)
+
+module attributes {gpu.container_module} {
+gpu.module @gpumodule {
+  gpu.func @kernel_cluster() kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 2, 2, 1>} {
+    %cidX = gpu.cluster_id  x
+    %cidY = gpu.cluster_id  y
+    %cidZ = gpu.cluster_id  z
+    %cdimX = gpu.cluster_dim_blocks  x
+    %cdimY = gpu.cluster_dim_blocks  y
+    %cdimZ = gpu.cluster_dim_blocks  z
+    %bidX = gpu.block_id  x
+    %bidY = gpu.block_id  y
+    %bidZ = gpu.block_id  z
+    %cidX_i32 = index.casts %cidX : index to i32
+    %cidY_i32 = index.casts %cidY : index to i32
+    %cidZ_i32 = index.casts %cidZ : index to i32
+    %cdimX_i32 = index.casts %cdimX : index to i32
+    %cdimY_i32 = index.casts %cdimY : index to i32
+    %cdimZ_i32 = index.casts %cdimZ : index to i32
+    %bidX_i32 = index.casts %bidX : index to i32
+    %bidY_i32 = index.casts %bidY : index to i32
+    %bidZ_i32 = index.casts %bidZ : index to i32
+
+    %c_1 = arith.constant -1 : index
+    %cBlocksX = gpu.grid_dim x
+    %cN_1 = arith.addi %cBlocksX, %c_1 : index
+    %cnd1 =  arith.cmpi eq, %bidX, %cN_1 : index
+    %cnd2 =  arith.cmpi eq, %bidY, %cN_1 : index
+    scf.if %cnd1 {
+      scf.if %cnd2 {
+        gpu.printf "clusterIdx: (%d, %d, %d) in Cluster Dimension: (%d, %d, %d) blockIdx: (%d, %d, %d) \n"
+          %cidX_i32,
+          %cidY_i32,
+          %cidZ_i32,
+          %cdimX_i32,
+          %cdimY_i32,
+          %cdimZ_i32,
+          %bidX_i32,
+          %bidY_i32,
+          %bidZ_i32
+          :
+          i32, i32, i32, i32, i32, i32, i32, i32, i32
+      }
+    }
+    gpu.return
+  }
+}
+
+func.func @main() {
+  %cDimX = arith.constant 4 : index
+  %cDimY = arith.constant 4 : index
+  %cDimZ = arith.constant 1 : index
+  %gDimX = arith.constant 16 : index
+  %gDimY = arith.constant 16 : index
+  %gDimZ = arith.constant 1 : index
+  %bDimX = arith.constant 1 : index
+  %bDimY = arith.constant 1 : index
+  %bDimZ = arith.constant 1 : index
+
+  gpu.launch clusters(%cx, %cy, %cz) in (%cluster_x = %cDimX, %cluster_y = %cDimY,
+                                       %cluster_z = %cDimZ)
+             blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY,
+                                       %grid_z = %gDimZ)
+             threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY,
+                                        %block_z = %bDimZ) {
+      %cidX = gpu.cluster_id  x
+      %cidY = gpu.cluster_id  y
+      %cidZ = gpu.cluster_id  z
+      %cdimX = gpu.cluster_dim_blocks  x
+      %cdimY = gpu.cluster_dim_blocks  y
+      %cdimZ = gpu.cluster_dim_blocks  z
+      %bidX = gpu.block_id  x
+      %bidY = gpu.block_id  y
+      %bidZ = gpu.block_id  z
+      %cidX_i32 = index.casts %cidX : index to i32
+      %cidY_i32 = index.casts %cidY : index to i32
+      %cidZ_i32 = index.casts %cidZ : index to i32
+      %cdimX_i32 = index.casts %cdimX : index to i32
+      %cdimY_i32 = index.casts %cdimY : index to i32
+      %cdimZ_i32 = index.casts %cdimZ : index to i32
+      %bidX_i32 = index.casts %bidX : index to i32
+      %bidY_i32 = index.casts %bidY : index to i32
+      %bidZ_i32 = index.casts %bidZ : index to i32
+
+      %c_1 = arith.constant -1 : index
+      %cBlocksX = gpu.grid_dim x
+      %cN_1 = arith.addi %cBlocksX, %c_1 : index
+      %cnd1 =  arith.cmpi eq, %bidX, %cN_1 : index
+      %cnd2 =  arith.cmpi eq, %bidY, %cN_1 : index
+      scf.if %cnd1 {
+        scf.if %cnd2 {
+          gpu.printf "clusterIdx: (%d, %d, %d) in Cluster Dimension: (%d, %d, %d) blockIdx: (%d, %d, %d) \n"
+            %cidX_i32,
+            %cidY_i32,
+            %cidZ_i32,
+            %cdimX_i32,
+            %cdimY_i32,
+            %cdimZ_i32,
+            %bidX_i32,
+            %bidY_i32,
+            %bidZ_i32
+            :
+            i32, i32, i32, i32, i32, i32, i32, i32, i32
+        }
+      }
+
+      gpu.terminator
+  } { nonPortableClusterSize = true}
+
+  gpu.launch_func  @gpumodule::@kernel_cluster clusters in (%cDimX,%cDimY,%cDimZ)  blocks in (%gDimX, %gDimY, %gDimZ) threads in (%bDimX, %bDimY, %bDimZ) { nonPortableClusterSize = true }
+  return
+}
+}
diff --git a/mlir/test/Target/LLVMIR/gpu.mlir b/mlir/test/Target/LLVMIR/gpu.mlir
index 88672bd231df8..205adfcd7b453 100644
--- a/mlir/test/Target/LLVMIR/gpu.mlir
+++ b/mlir/test/Target/LLVMIR/gpu.mlir
@@ -94,7 +94,7 @@ module attributes {gpu.container_module} {
   // CHECK: [[S3:%.*]] = call ptr @mgpuModuleLoad(ptr @kernel_module_bin_cst, i64 4)
   // CHECK: [[S4:%.*]] = call ptr @mgpuModuleGetFunction(ptr [[S3]], ptr @kernel_module_kernel_kernel_name)
   // CHECK: [[S5:%.*]] = call ptr @mgpuStreamCreate()
-  // CHECK: call void @mgpuLaunchClusterKernel(ptr [[S4]], i64 2, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i32 0, ptr [[S5]], ptr [[S2]], ptr null)
+  // CHECK: call void @mgpuLaunchClusterKernel(ptr [[S4]], i64 2, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i32 0, i1 false, ptr [[S5]], ptr [[S2]], ptr null)
     %0 = llvm.mlir.constant(1 : index) : i64
     %1 = llvm.mlir.constant(2 : index) : i64
     gpu.launch_func @kernel_module::@kernel clusters in (%1, %0, %0) blocks in (%0, %0, %0) threads in (%0, %0...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/95545


More information about the Mlir-commits mailing list