[Mlir-commits] [mlir] [MLIR][GPU] Add support for non-portable cluster size attribute (PR #95545)
Pradeep Kumar
llvmlistbot at llvm.org
Fri Jun 14 06:32:43 PDT 2024
https://github.com/schwarzschild-radius created https://github.com/llvm/llvm-project/pull/95545
This commit adds support for the `nonPortableClusterSize` attribute on the `gpu.launch` and `gpu.launch_func` ops. This is required to enable cluster sizes greater than 8 on Hopper GPUs. This commit also relaxes the constraint on the `gpu.cluster_dim_blocks` op to allow larger cluster sizes. A test case is added under `non-portable-cluster-launch.mlir`.
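For reference, a minimal MLIR sketch of the intended usage, mirroring the test case added in this patch (the kernel symbol and constants are illustrative):

```mlir
// Launch an outlined kernel with a 4x4x1 cluster shape. The attribute opts
// the kernel into cluster sizes beyond the portable limit of 8 blocks.
gpu.launch_func @gpumodule::@kernel_cluster
    clusters in (%c4, %c4, %c1)
    blocks in (%c16, %c16, %c1)
    threads in (%c1, %c1, %c1)
    { nonPortableClusterSize = true }
```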
>From 314de0d83bbb963912214163c63de24991494663 Mon Sep 17 00:00:00 2001
From: pradeepku <pradeepku at nvidia.com>
Date: Fri, 14 Jun 2024 12:09:41 +0000
Subject: [PATCH] [MLIR][GPU] Add support for non-portable cluster size
attribute
This commit adds support for the `nonPortableClusterSize` attribute on the
`gpu.launch` and `gpu.launch_func` ops. This is required to enable cluster
sizes greater than 8 on Hopper GPUs. This commit also relaxes the
constraint on the `gpu.cluster_dim_blocks` op to allow larger cluster
sizes. A test case is added under `non-portable-cluster-launch.mlir`.
---
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 33 ++++-
.../GPUCommon/GPUToLLVMConversion.cpp | 3 +-
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 14 +-
.../GPU/IR/InferIntRangeInterfaceImpls.cpp | 3 +-
.../GPU/Transforms/KernelOutlining.cpp | 3 +-
.../ExecutionEngine/CudaRuntimeWrappers.cpp | 17 ++-
.../LLVMIR/Dialect/GPU/SelectObjectAttr.cpp | 11 +-
.../sm90/non-portable-cluster-launch.mlir | 124 ++++++++++++++++++
mlir/test/Target/LLVMIR/gpu.mlir | 2 +-
9 files changed, 189 insertions(+), 21 deletions(-)
create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/non-portable-cluster-launch.mlir
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 9c5f7ecd8cbe8..1d97f2c03b2cb 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -544,7 +544,8 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
Optional<LaunchIndx>:$clusterSizeZ,
Optional<I32>:$dynamicSharedMemorySize,
Variadic<AnyType>:$kernelOperands,
- Optional<AnyType>:$asyncObject)>,
+ Optional<AnyType>:$asyncObject,
+ OptionalAttr<BoolAttr>:$nonPortableClusterSize)>,
Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
let summary = "Launches a function as a GPU kernel";
@@ -585,6 +586,10 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
arguments are present, the Op launches a kernel that clusters the given
thread blocks. This feature is exclusive to certain architectures.
+ The `gpu.launch_func` op also supports the following optional runtime attributes:
+ - nonPortableClusterSize - launch the kernel with a non-portable cluster size
+ (only supported on certain architectures)
+
Example:
```mlir
@@ -640,7 +645,7 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
// memory to allocate for a workgroup.
args(%arg0 : f32, // (Optional) Kernel arguments.
%arg1 : memref<?xf32, 1>)
- }
+ } { nonPortableClusterSize = true } // Attributes
```
}];
@@ -652,12 +657,14 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
"ValueRange":$kernelOperands,
CArg<"Type", "nullptr">:$asyncTokenType,
CArg<"ValueRange", "{}">:$asyncDependencies,
- CArg<"std::optional<KernelDim3>", "std::nullopt">:$clusterSize)>,
+ CArg<"std::optional<KernelDim3>", "std::nullopt">:$clusterSize,
+ CArg<"BoolAttr", "{}">:$nonPortableClusterSize)>,
OpBuilder<(ins "SymbolRefAttr":$kernel, "KernelDim3":$gridSize,
"KernelDim3":$blockSize, "Value":$dynamicSharedMemorySize,
"ValueRange":$kernelOperands,
CArg<"Value", "nullptr">:$asyncObject,
- CArg<"std::optional<KernelDim3>", "std::nullopt">:$clusterSize)>
+ CArg<"std::optional<KernelDim3>", "std::nullopt">:$clusterSize,
+ CArg<"BoolAttr", "{}">:$nonPortableClusterSize)>
];
let extraClassDeclaration = [{
@@ -720,7 +727,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [
Optional<Index>:$clusterSizeX,
Optional<Index>:$clusterSizeY,
Optional<Index>:$clusterSizeZ,
- Optional<I32>:$dynamicSharedMemorySize)>,
+ Optional<I32>:$dynamicSharedMemorySize,
+ OptionalAttr<BoolAttr>:$nonPortableClusterSize)>,
Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
let summary = "GPU kernel launch operation";
@@ -815,10 +823,20 @@ def GPU_LaunchOp : GPU_Op<"launch", [
blocks(%bx, %by, %bz) in (%sz_bx = %3, %sz_by = %4, %sz_bz = %5)
threads(%tx, %ty, %tz) in (%sz_tx = %6, %sz_ty = %7, %sz_tz = %8)
{
- // Cluster, block and thread identifiers, as well as cluster/block/grid
+ // Cluster, block and thread identifiers, as well as cluster/block/grid
// sizes are immediately usable inside body region.
"some_op"(%cx, %bx, %tx) : (index, index, index) -> ()
}
+
+ // Launch with non-portable cluster size attribute.
+ gpu.launch clusters(%cx, %cy, %cz) in (%sz_cx = %0, %sz_cy = %1, %sz_cz = %2)
+ blocks(%bx, %by, %bz) in (%sz_bx = %3, %sz_by = %4, %sz_bz = %5)
+ threads(%tx, %ty, %tz) in (%sz_tx = %6, %sz_ty = %7, %sz_tz = %8)
+ {
+ // Cluster, block and thread identifiers, as well as cluster/block/grid
+ // sizes are immediately usable inside body region.
+ "some_op"(%cx, %bx, %tx) : (index, index, index) -> ()
+ } { nonPortableClusterSize = true }
```
Rationale: using operation/block arguments gives analyses a clear way of
@@ -843,7 +861,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [
CArg<"TypeRange", "{}">:$privateAttributions,
CArg<"Value", "nullptr">:$clusterSizeX,
CArg<"Value", "nullptr">:$clusterSizeY,
- CArg<"Value", "nullptr">:$clusterSizeZ)>
+ CArg<"Value", "nullptr">:$clusterSizeZ,
+ CArg<"BoolAttr", "{}">:$nonPortableClusterSize)>
];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 92b28ff9c5873..0e2fe70a9706d 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -971,7 +971,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite(
adaptor.getGridSizeZ()},
gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
adaptor.getBlockSizeZ()},
- adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
+ adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize,
+ adaptor.getNonPortableClusterSizeAttr());
if (launchOp.getAsyncToken())
rewriter.replaceOp(launchOp, {stream});
else
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index d8e29da6512d4..bd2bf97343e77 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -649,7 +649,8 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
Type asyncTokenType, ValueRange asyncDependencies,
TypeRange workgroupAttributions,
TypeRange privateAttributions, Value clusterSizeX,
- Value clusterSizeY, Value clusterSizeZ) {
+ Value clusterSizeY, Value clusterSizeZ,
+ BoolAttr nonPortableClusterSize) {
OpBuilder::InsertionGuard g(builder);
// Add a WorkGroup attribution attribute. This attribute is required to
@@ -674,6 +675,9 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
if (dynamicSharedMemorySize)
result.addOperands(dynamicSharedMemorySize);
+ Properties &prop = result.getOrAddProperties<Properties>();
+ prop.nonPortableClusterSize = nonPortableClusterSize;
+
// Create a kernel body region with kNumConfigRegionAttributes + N memory
// attributions, where the first kNumConfigRegionAttributes arguments have
// `index` type and the rest have the same types as the data operands.
@@ -1085,7 +1089,8 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
ValueRange kernelOperands, Type asyncTokenType,
ValueRange asyncDependencies,
- std::optional<KernelDim3> clusterSize) {
+ std::optional<KernelDim3> clusterSize,
+ BoolAttr nonPortableClusterSize) {
result.addOperands(asyncDependencies);
if (asyncTokenType)
result.types.push_back(builder.getType<AsyncTokenType>());
@@ -1105,6 +1110,7 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
Properties &prop = result.getOrAddProperties<Properties>();
prop.kernel = kernelSymbol;
+ prop.nonPortableClusterSize = nonPortableClusterSize;
size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
// Initialize the segment sizes to 1.
for (auto &sz : prop.operandSegmentSizes)
@@ -1126,7 +1132,8 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
SymbolRefAttr kernel, KernelDim3 gridSize,
KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
ValueRange kernelOperands, Value asyncObject,
- std::optional<KernelDim3> clusterSize) {
+ std::optional<KernelDim3> clusterSize,
+ BoolAttr nonPortableClusterSize) {
// Add grid and block sizes as op operands, followed by the data operands.
result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
getBlockSize.y, getBlockSize.z});
@@ -1139,6 +1146,7 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
result.addOperands(asyncObject);
Properties &prop = result.getOrAddProperties<Properties>();
prop.kernel = kernel;
+ prop.nonPortableClusterSize = nonPortableClusterSize;
size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
// Initialize the segment sizes to 1.
for (auto &sz : prop.operandSegmentSizes)
diff --git a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
index 46b85db8b5431..1b09847e0eea6 100644
--- a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
@@ -92,7 +92,8 @@ void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
void ClusterDimBlocksOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), getIndexRange(1, kMaxClusterDim));
+ uint64_t max = APInt::getMaxValue(64).getZExtValue();
+ setResultRange(getResult(), getIndexRange(1, max));
}
void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index f5e80553ae72a..dd33f0e842e5a 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -295,7 +295,8 @@ static void convertToLaunchFuncOp(gpu::LaunchOp launchOp,
launchOp.getBlockSizeOperandValues(),
launchOp.getDynamicSharedMemorySize(), operands,
asyncToken ? asyncToken.getType() : nullptr,
- launchOp.getAsyncDependencies(), clusterSize);
+ launchOp.getAsyncDependencies(), clusterSize,
+ launchOp.getNonPortableClusterSizeAttr());
launchOp.replaceAllUsesWith(launchFunc);
launchOp.erase();
}
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index 09dc30365e37c..81dc7b983494a 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -335,11 +335,13 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) {
#if (CUDA_VERSION >= 12000)
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuLaunchClusterKernel(
- CUfunction function, intptr_t clusterX, intptr_t clusterY,
- intptr_t clusterZ, intptr_t gridX, intptr_t gridY, intptr_t gridZ,
- intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem,
- CUstream stream, void **params, void **extra, size_t /*paramsCount*/) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuLaunchClusterKernel(CUfunction function, intptr_t clusterX,
+ intptr_t clusterY, intptr_t clusterZ, intptr_t gridX,
+ intptr_t gridY, intptr_t gridZ, intptr_t blockX,
+ intptr_t blockY, intptr_t blockZ, int32_t smem,
+ bool nonPortableClusterSize, CUstream stream,
+ void **params, void **extra, size_t /*paramsCount*/) {
ScopedContext scopedContext;
if (smem > 0) {
// Avoid checking driver as it's more expensive than if statement
@@ -358,6 +360,11 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuLaunchClusterKernel(
CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem));
}
+
+ if (nonPortableClusterSize)
+ CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
+ function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
+
CUlaunchConfig config;
config.gridDimX = gridX;
config.gridDimY = gridY;
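At the driver level, the new flag maps to a `cuFuncSetAttribute` call with `CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED` before the cluster launch. A minimal standalone C++ sketch of the equivalent CUDA driver calls, assuming CUDA 12+ (the function name, dimensions, and stream are illustrative; error handling elided):

```c++
#include "cuda.h"

// Sketch: launch "function" on a 16x16x1 grid with a 4x4x1 cluster shape,
// optionally allowing a non-portable cluster size first.
void launchWithCluster(CUfunction function, CUstream stream, void **params,
                       bool nonPortableClusterSize) {
  if (nonPortableClusterSize) {
    // Opt this kernel into cluster shapes beyond the portable 8-block limit.
    cuFuncSetAttribute(function,
                       CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1);
  }
  // Attach the cluster dimension as a launch attribute.
  CUlaunchAttribute attr;
  attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
  attr.value.clusterDim.x = 4;
  attr.value.clusterDim.y = 4;
  attr.value.clusterDim.z = 1;
  CUlaunchConfig config;
  config.gridDimX = 16;
  config.gridDimY = 16;
  config.gridDimZ = 1;
  config.blockDimX = 1;
  config.blockDimY = 1;
  config.blockDimZ = 1;
  config.sharedMemBytes = 0;
  config.hStream = stream;
  config.attrs = &attr;
  config.numAttrs = 1;
  cuLaunchKernelEx(&config, function, params, /*extra=*/nullptr);
}
```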
diff --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
index 0eb33287d608b..bea9f782888de 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
@@ -174,6 +174,7 @@ class LaunchKernel {
Module &module;
IRBuilderBase &builder;
mlir::LLVM::ModuleTranslation &moduleTranslation;
+ Type *i1Ty{};
Type *i32Ty{};
Type *i64Ty{};
Type *voidTy{};
@@ -216,6 +217,7 @@ llvm::LaunchKernel::LaunchKernel(
Module &module, IRBuilderBase &builder,
mlir::LLVM::ModuleTranslation &moduleTranslation)
: module(module), builder(builder), moduleTranslation(moduleTranslation) {
+ i1Ty = builder.getInt1Ty();
i32Ty = builder.getInt32Ty();
i64Ty = builder.getInt64Ty();
ptrTy = builder.getPtrTy(0);
@@ -240,7 +242,7 @@ llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
voidTy,
ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
intPtrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
- i32Ty, ptrTy, ptrTy, ptrTy}),
+ i32Ty, i1Ty, ptrTy, ptrTy, ptrTy}),
false));
}
@@ -371,6 +373,10 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
else
dynamicMemorySize = ConstantInt::get(i32Ty, 0);
+ Value *nonPortableClusterSize = op.getNonPortableClusterSize()
+ ? ConstantInt::get(i1Ty, 1)
+ : ConstantInt::get(i1Ty, 0);
+
// Create the argument array.
Value *argArray = createKernelArgArray(op);
@@ -443,7 +449,8 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
builder.CreateCall(
getClusterKernelLaunchFn(),
ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
- dynamicMemorySize, stream, argArray, nullPtr}));
+ dynamicMemorySize, nonPortableClusterSize, stream,
+ argArray, nullPtr}));
} else {
builder.CreateCall(getKernelLaunchFn(),
ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by,
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/non-portable-cluster-launch.mlir b/mlir/test/Integration/GPU/CUDA/sm90/non-portable-cluster-launch.mlir
new file mode 100644
index 0000000000000..ed72116563d6e
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/non-portable-cluster-launch.mlir
@@ -0,0 +1,124 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
+// RUN: | mlir-cpu-runner \
+// RUN: --shared-libs=%mlir_cuda_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --shared-libs=%mlir_c_runner_utils \
+// RUN: --entry-point-result=void \
+// RUN: | FileCheck %s
+
+// CHECK: clusterIdx: (3, 3, 0) in Cluster Dimension: (4, 4, 1) blockIdx: (15, 15, 0)
+// CHECK: clusterIdx: (3, 3, 0) in Cluster Dimension: (4, 4, 1) blockIdx: (15, 15, 0)
+
+module attributes {gpu.container_module} {
+gpu.module @gpumodule {
+ gpu.func @kernel_cluster() kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 2, 2, 1>} {
+ %cidX = gpu.cluster_id x
+ %cidY = gpu.cluster_id y
+ %cidZ = gpu.cluster_id z
+ %cdimX = gpu.cluster_dim_blocks x
+ %cdimY = gpu.cluster_dim_blocks y
+ %cdimZ = gpu.cluster_dim_blocks z
+ %bidX = gpu.block_id x
+ %bidY = gpu.block_id y
+ %bidZ = gpu.block_id z
+ %cidX_i32 = index.casts %cidX : index to i32
+ %cidY_i32 = index.casts %cidY : index to i32
+ %cidZ_i32 = index.casts %cidZ : index to i32
+ %cdimX_i32 = index.casts %cdimX : index to i32
+ %cdimY_i32 = index.casts %cdimY : index to i32
+ %cdimZ_i32 = index.casts %cdimZ : index to i32
+ %bidX_i32 = index.casts %bidX : index to i32
+ %bidY_i32 = index.casts %bidY : index to i32
+ %bidZ_i32 = index.casts %bidZ : index to i32
+
+ %c_1 = arith.constant -1 : index
+ %cBlocksX = gpu.grid_dim x
+ %cN_1 = arith.addi %cBlocksX, %c_1 : index
+ %cnd1 = arith.cmpi eq, %bidX, %cN_1 : index
+ %cnd2 = arith.cmpi eq, %bidY, %cN_1 : index
+ scf.if %cnd1 {
+ scf.if %cnd2 {
+ gpu.printf "clusterIdx: (%d, %d, %d) in Cluster Dimension: (%d, %d, %d) blockIdx: (%d, %d, %d) \n"
+ %cidX_i32,
+ %cidY_i32,
+ %cidZ_i32,
+ %cdimX_i32,
+ %cdimY_i32,
+ %cdimZ_i32,
+ %bidX_i32,
+ %bidY_i32,
+ %bidZ_i32
+ :
+ i32, i32, i32, i32, i32, i32, i32, i32, i32
+ }
+ }
+ gpu.return
+ }
+}
+
+func.func @main() {
+ %cDimX = arith.constant 4 : index
+ %cDimY = arith.constant 4 : index
+ %cDimZ = arith.constant 1 : index
+ %gDimX = arith.constant 16 : index
+ %gDimY = arith.constant 16 : index
+ %gDimZ = arith.constant 1 : index
+ %bDimX = arith.constant 1 : index
+ %bDimY = arith.constant 1 : index
+ %bDimZ = arith.constant 1 : index
+
+ gpu.launch clusters(%cx, %cy, %cz) in (%cluster_x = %cDimX, %cluster_y = %cDimY,
+ %cluster_z = %cDimZ)
+ blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY,
+ %grid_z = %gDimZ)
+ threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY,
+ %block_z = %bDimZ) {
+ %cidX = gpu.cluster_id x
+ %cidY = gpu.cluster_id y
+ %cidZ = gpu.cluster_id z
+ %cdimX = gpu.cluster_dim_blocks x
+ %cdimY = gpu.cluster_dim_blocks y
+ %cdimZ = gpu.cluster_dim_blocks z
+ %bidX = gpu.block_id x
+ %bidY = gpu.block_id y
+ %bidZ = gpu.block_id z
+ %cidX_i32 = index.casts %cidX : index to i32
+ %cidY_i32 = index.casts %cidY : index to i32
+ %cidZ_i32 = index.casts %cidZ : index to i32
+ %cdimX_i32 = index.casts %cdimX : index to i32
+ %cdimY_i32 = index.casts %cdimY : index to i32
+ %cdimZ_i32 = index.casts %cdimZ : index to i32
+ %bidX_i32 = index.casts %bidX : index to i32
+ %bidY_i32 = index.casts %bidY : index to i32
+ %bidZ_i32 = index.casts %bidZ : index to i32
+
+ %c_1 = arith.constant -1 : index
+ %cBlocksX = gpu.grid_dim x
+ %cN_1 = arith.addi %cBlocksX, %c_1 : index
+ %cnd1 = arith.cmpi eq, %bidX, %cN_1 : index
+ %cnd2 = arith.cmpi eq, %bidY, %cN_1 : index
+ scf.if %cnd1 {
+ scf.if %cnd2 {
+ gpu.printf "clusterIdx: (%d, %d, %d) in Cluster Dimension: (%d, %d, %d) blockIdx: (%d, %d, %d) \n"
+ %cidX_i32,
+ %cidY_i32,
+ %cidZ_i32,
+ %cdimX_i32,
+ %cdimY_i32,
+ %cdimZ_i32,
+ %bidX_i32,
+ %bidY_i32,
+ %bidZ_i32
+ :
+ i32, i32, i32, i32, i32, i32, i32, i32, i32
+ }
+ }
+
+ gpu.terminator
+ } { nonPortableClusterSize = true }
+
+ gpu.launch_func @gpumodule::@kernel_cluster clusters in (%cDimX,%cDimY,%cDimZ) blocks in (%gDimX, %gDimY, %gDimZ) threads in (%bDimX, %bDimY, %bDimZ) { nonPortableClusterSize = true }
+ return
+}
+}
diff --git a/mlir/test/Target/LLVMIR/gpu.mlir b/mlir/test/Target/LLVMIR/gpu.mlir
index 88672bd231df8..205adfcd7b453 100644
--- a/mlir/test/Target/LLVMIR/gpu.mlir
+++ b/mlir/test/Target/LLVMIR/gpu.mlir
@@ -94,7 +94,7 @@ module attributes {gpu.container_module} {
// CHECK: [[S3:%.*]] = call ptr @mgpuModuleLoad(ptr @kernel_module_bin_cst, i64 4)
// CHECK: [[S4:%.*]] = call ptr @mgpuModuleGetFunction(ptr [[S3]], ptr @kernel_module_kernel_kernel_name)
// CHECK: [[S5:%.*]] = call ptr @mgpuStreamCreate()
- // CHECK: call void @mgpuLaunchClusterKernel(ptr [[S4]], i64 2, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i32 0, ptr [[S5]], ptr [[S2]], ptr null)
+ // CHECK: call void @mgpuLaunchClusterKernel(ptr [[S4]], i64 2, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i32 0, i1 false, ptr [[S5]], ptr [[S2]], ptr null)
%0 = llvm.mlir.constant(1 : index) : i64
%1 = llvm.mlir.constant(2 : index) : i64
gpu.launch_func @kernel_module::@kernel clusters in (%1, %0, %0) blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64