[Mlir-commits] [mlir] [mlir][gpu] Support Cluster of Thread Blocks in `gpu.launch_func` (PR #72871)
llvmlistbot at llvm.org
Mon Nov 20 06:06:42 PST 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir-gpu
Author: Guray Ozen (grypp)
The NVIDIA Hopper architecture introduced the Cooperative Group Array (CGA), a new level of parallelism that allows Cooperative Thread Arrays (CTAs) to be clustered so they can synchronize and communicate through shared memory while running concurrently.
This PR adds CGA support to `gpu.launch_func` in the GPU dialect, extending the Op to accommodate this functionality.
The GPU dialect remains architecture-agnostic, so the CGA functionality is added as optional parameters. This lets us leverage the mechanisms the GPU dialect already provides, such as kernel outlining and kernel launching, making it a practical and convenient choice.
An example is shown below:
```
gpu.launch_func @kernel_module::@kernel
clusters in (%1, %0, %0) // <-- Optional
blocks in (%0, %0, %0)
threads in (%0, %0, %0)
```
The PR also introduces cluster-specific index and dimension Ops and binds them to the corresponding NVVM Ops:
```
%cidX = gpu.cluster_id x
%cidY = gpu.cluster_id y
%cidZ = gpu.cluster_id z
%cdimX = gpu.cluster_dim x
%cdimY = gpu.cluster_dim y
%cdimZ = gpu.cluster_dim z
```
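Putting the two together, a minimal hand-written sketch of a module that launches a clustered kernel could look like the following (the module name, kernel name, and constant sizes are illustrative, not taken from the patch):
```mlir
// Hand-written sketch; names and sizes are illustrative.
module attributes {gpu.container_module} {
  gpu.module @kernel_module {
    gpu.func @kernel() kernel {
      // Index of the current cluster within the grid, and the number of
      // thread blocks per cluster, along the x dimension.
      %cid_x  = gpu.cluster_id x
      %cdim_x = gpu.cluster_dim x
      gpu.return
    }
  }

  func.func @main() {
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    // Every 2x1x1 group of thread blocks forms one cluster; omitting
    // `clusters in` gives the existing, cluster-free launch.
    gpu.launch_func @kernel_module::@kernel
        clusters in (%c2, %c1, %c1)
        blocks in (%c2, %c1, %c1)
        threads in (%c1, %c1, %c1)
    return
  }
}
```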
We will introduce cluster support in the `gpu.launch` Op in an upcoming PR.
See [the documentation](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cluster-of-cooperative-thread-arrays) provided by NVIDIA for details.
---
Patch is 29.02 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72871.diff
11 Files Affected:
- (modified) mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (+69-6)
- (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+27-1)
- (modified) mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp (+14-11)
- (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+40-5)
- (modified) mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp (+13)
- (modified) mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp (+54)
- (modified) mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp (+30-4)
- (modified) mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir (+38)
- (modified) mlir/test/Dialect/GPU/invalid.mlir (+1-1)
- (added) mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir (+65)
- (modified) mlir/test/Target/LLVMIR/gpu.mlir (+19)
``````````diff
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 632cdd96c6d4c2b..f093e4392520263 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -53,6 +53,32 @@ class GPU_IndexOp<string mnemonic, list<Trait> traits = []> :
let assemblyFormat = "$dimension attr-dict";
}
+def GPU_ClusterDimOp : GPU_IndexOp<"cluster_dim"> {
+ let description = [{
+ Returns the number of thread blocks in the cluster along
+ the x, y, or z `dimension`.
+
+ Example:
+
+ ```mlir
+ %cDimX = gpu.cluster_dim x
+ ```
+ }];
+}
+
+def GPU_ClusterIdOp : GPU_IndexOp<"cluster_id"> {
+ let description = [{
+ Returns the cluster id, i.e. the index of the current cluster within the
+ grid along the x, y, or z `dimension`.
+
+ Example:
+
+ ```mlir
+ %cIdY = gpu.cluster_id y
+ ```
+ }];
+}
+
def GPU_BlockDimOp : GPU_IndexOp<"block_dim"> {
let description = [{
Returns the number of threads in the thread block (aka the block size) along
@@ -441,8 +467,15 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
"blockSizeY", "blockSizeZ"]>]>,
Arguments<(ins Variadic<GPU_AsyncToken>:$asyncDependencies,
SymbolRefAttr:$kernel,
- LaunchIndx:$gridSizeX, LaunchIndx:$gridSizeY, LaunchIndx:$gridSizeZ,
- LaunchIndx:$blockSizeX, LaunchIndx:$blockSizeY, LaunchIndx:$blockSizeZ,
+ LaunchIndx:$gridSizeX,
+ LaunchIndx:$gridSizeY,
+ LaunchIndx:$gridSizeZ,
+ LaunchIndx:$blockSizeX,
+ LaunchIndx:$blockSizeY,
+ LaunchIndx:$blockSizeZ,
+ Optional<LaunchIndx>:$clusterSizeX,
+ Optional<LaunchIndx>:$clusterSizeY,
+ Optional<LaunchIndx>:$clusterSizeZ,
Optional<I32>:$dynamicSharedMemorySize,
Variadic<AnyType>:$kernelOperands,
Optional<AnyType>:$asyncObject)>,
@@ -480,6 +513,12 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
The remaining operands if present are passed as arguments to the kernel
function.
+ The `gpu.launch_func` also supports kernel launching with clusters if
+ supported by the target architecture. The cluster size can be set by
+ `clusterSizeX`, `clusterSizeY`, and `clusterSizeZ` arguments. When these
+ arguments are present, the Op launches a kernel that clusters the given
+ thread blocks. This feature is exclusive to certain architectures.
+
Example:
```mlir
@@ -509,6 +548,15 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
%gDimY = gpu.grid_dim y
%gDimZ = gpu.grid_dim z
+ // (Optional) Cluster size only for supported architectures
+ %cIdX = gpu.cluster_id x
+ %cIdY = gpu.cluster_id y
+ %cIdZ = gpu.cluster_id z
+
+ %cDimX = gpu.cluster_dim x
+ %cDimY = gpu.cluster_dim y
+ %cDimZ = gpu.cluster_dim z
+
"some_op"(%bx, %tx) : (index, index) -> ()
%42 = load %arg1[%bx] : memref<?xf32, 1>
}
@@ -519,6 +567,7 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
async // (Optional) Don't block host, return token.
[%t0] // (Optional) Execute only after %t0 has completed.
@kernels::@kernel_1 // Kernel function.
+ clusters in (%cst, %cst, %cst) // (Optional) Cluster size only for supported architectures.
blocks in (%cst, %cst, %cst) // Grid size.
threads in (%cst, %cst, %cst) // Block size.
dynamic_shared_memory_size %s // (Optional) Amount of dynamic shared
@@ -536,11 +585,13 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
"KernelDim3":$blockSize, "Value":$dynamicSharedMemorySize,
"ValueRange":$kernelOperands,
CArg<"Type", "nullptr">:$asyncTokenType,
- CArg<"ValueRange", "{}">:$asyncDependencies)>,
+ CArg<"ValueRange", "{}">:$asyncDependencies,
+ CArg<"std::optional<KernelDim3>", "std::nullopt">:$clusterSize)>,
OpBuilder<(ins "SymbolRefAttr":$kernel, "KernelDim3":$gridSize,
"KernelDim3":$blockSize, "Value":$dynamicSharedMemorySize,
"ValueRange":$kernelOperands,
- CArg<"Value", "nullptr">:$asyncObject)>
+ CArg<"Value", "nullptr">:$asyncObject,
+ CArg<"std::optional<KernelDim3>", "std::nullopt">:$clusterSize)>
];
let extraClassDeclaration = [{
@@ -550,12 +601,23 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
/// The name of the kernel.
StringAttr getKernelName();
+ /// Returns true if a cluster size was specified.
+ bool hasClusterSize() {
+ auto totalSize = getOperands().size();
+ totalSize -= getKernelOperands().size();
+ totalSize -= getAsyncDependencies().size();
+ return totalSize > 7;
+ }
+
/// The number of operands passed to the kernel function.
unsigned getNumKernelOperands();
/// The i-th operand passed to the kernel function.
Value getKernelOperand(unsigned i);
+ /// Get the SSA values passed as operands to specify the cluster size.
+ KernelDim3 getClusterSizeOperandValues();
+
/// Get the SSA values passed as operands to specify the grid size.
KernelDim3 getGridSizeOperandValues();
@@ -571,10 +633,11 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
let assemblyFormat = [{
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
(`<` $asyncObject^ `:` type($asyncObject) `>`)?
- $kernel
+ $kernel
+ ( `clusters` `in` ` ` `(` $clusterSizeX^ `,` $clusterSizeY `,` $clusterSizeZ `)` )?
`blocks` `in` ` ` `(` $gridSizeX `,` $gridSizeY `,` $gridSizeZ `)`
`threads` `in` ` ` `(` $blockSizeX `,` $blockSizeY `,` $blockSizeZ `)`
- custom<LaunchDimType>(type($gridSizeX))
+ custom<LaunchDimType>(type($gridSizeX), ref($clusterSizeX), type($clusterSizeX), type($clusterSizeY), type($clusterSizeZ))
(`dynamic_shared_memory_size` $dynamicSharedMemorySize^)?
custom<LaunchFuncOperands>($kernelOperands, type($kernelOperands)) attr-dict
}];
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 7bac8f5a8f0e03b..381d5100a7a3fdc 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -121,6 +121,26 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
llvmPointerType, /* void **extra */
llvmInt64Type /* size_t paramsCount */
}};
+ FunctionCallBuilder launchClusterKernelCallBuilder = {
+ "mgpuLaunchClusterKernel",
+ llvmVoidType,
+ {
+ llvmPointerType, /* void* f */
+ llvmIntPtrType, /* intptr_t clusterXDim */
+ llvmIntPtrType, /* intptr_t clusterYDim */
+ llvmIntPtrType, /* intptr_t clusterZDim */
+ llvmIntPtrType, /* intptr_t gridXDim */
+ llvmIntPtrType, /* intptr_t gridYDim */
+ llvmIntPtrType, /* intptr_t gridZDim */
+ llvmIntPtrType, /* intptr_t blockXDim */
+ llvmIntPtrType, /* intptr_t blockYDim */
+ llvmIntPtrType, /* intptr_t blockZDim */
+ llvmInt32Type, /* unsigned int sharedMemBytes */
+ llvmPointerType, /* void *hstream */
+ llvmPointerType, /* void **kernelParams */
+ llvmPointerType, /* void **extra */
+ llvmInt64Type /* size_t paramsCount */
+ }};
FunctionCallBuilder streamCreateCallBuilder = {
"mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
FunctionCallBuilder streamDestroyCallBuilder = {
@@ -1128,13 +1148,19 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(),
rewriter, /*useBarePtrCallConv=*/kernelBarePtrCallConv);
+ std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
+ if (launchOp.hasClusterSize()) {
+ clusterSize =
+ gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
+ adaptor.getClusterSizeZ()};
+ }
rewriter.create<gpu::LaunchFuncOp>(
launchOp.getLoc(), launchOp.getKernelAttr(),
gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
adaptor.getGridSizeZ()},
gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
adaptor.getBlockSizeZ()},
- adaptor.getDynamicSharedMemorySize(), arguments, stream);
+ adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
if (launchOp.getAsyncToken())
rewriter.replaceOp(launchOp, {stream});
else
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 935e3d2a4095003..5b353b0c3bcbd76 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -313,17 +313,20 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
populateWithGenerated(patterns);
patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
- patterns
- .add<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
- NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
- GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
- NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
- GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
- NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
- GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
- NVVM::GridDimYOp, NVVM::GridDimZOp>,
- GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
- converter);
+ patterns.add<
+ GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
+ NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
+ GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
+ NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
+ GPUIndexIntrinsicOpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
+ NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>,
+ GPUIndexIntrinsicOpLowering<gpu::ClusterDimOp, NVVM::ClusterDimXOp,
+ NVVM::ClusterDimYOp, NVVM::ClusterDimZOp>,
+ GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
+ NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
+ GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
+ NVVM::GridDimYOp, NVVM::GridDimZOp>,
+ GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(converter);
// Explicitly drop memory space when lowering private memory
// attributions since NVVM models it as `alloca`s in the default
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index e0a2b93df3d1fd6..156a993131ad379 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -983,7 +983,8 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
GPUFuncOp kernelFunc, KernelDim3 gridSize,
KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
ValueRange kernelOperands, Type asyncTokenType,
- ValueRange asyncDependencies) {
+ ValueRange asyncDependencies,
+ std::optional<KernelDim3> clusterSize) {
result.addOperands(asyncDependencies);
if (asyncTokenType)
result.types.push_back(builder.getType<AsyncTokenType>());
@@ -991,6 +992,8 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
// Add grid and block sizes as op operands, followed by the data operands.
result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
getBlockSize.y, getBlockSize.z});
+ if (clusterSize.has_value())
+ result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
if (dynamicSharedMemorySize)
result.addOperands(dynamicSharedMemorySize);
result.addOperands(kernelOperands);
@@ -1006,6 +1009,11 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
for (auto &sz : prop.operandSegmentSizes)
sz = 1;
prop.operandSegmentSizes[0] = asyncDependencies.size();
+ if (!clusterSize.has_value()) {
+ prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
+ prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
+ prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
+ }
prop.operandSegmentSizes[segmentSizesLen - 3] =
dynamicSharedMemorySize ? 1 : 0;
prop.operandSegmentSizes[segmentSizesLen - 2] =
@@ -1016,10 +1024,13 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
SymbolRefAttr kernel, KernelDim3 gridSize,
KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
- ValueRange kernelOperands, Value asyncObject) {
+ ValueRange kernelOperands, Value asyncObject,
+ std::optional<KernelDim3> clusterSize) {
// Add grid and block sizes as op operands, followed by the data operands.
result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
getBlockSize.y, getBlockSize.z});
+ if (clusterSize.has_value())
+ result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
if (dynamicSharedMemorySize)
result.addOperands(dynamicSharedMemorySize);
result.addOperands(kernelOperands);
@@ -1032,6 +1043,11 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
for (auto &sz : prop.operandSegmentSizes)
sz = 1;
prop.operandSegmentSizes[0] = 0;
+ if (!clusterSize.has_value()) {
+ prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
+ prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
+ prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
+ }
prop.operandSegmentSizes[segmentSizesLen - 3] =
dynamicSharedMemorySize ? 1 : 0;
prop.operandSegmentSizes[segmentSizesLen - 2] =
@@ -1065,6 +1081,11 @@ KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
return KernelDim3{operands[3], operands[4], operands[5]};
}
+KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
+ auto operands = getOperands().drop_front(getAsyncDependencies().size());
+ return KernelDim3{operands[6], operands[7], operands[8]};
+}
+
LogicalResult LaunchFuncOp::verify() {
auto module = (*this)->getParentOfType<ModuleOp>();
if (!module)
@@ -1076,21 +1097,35 @@ LogicalResult LaunchFuncOp::verify() {
GPUDialect::getContainerModuleAttrName() +
"' attribute");
+ if (getClusterSizeX()) {
+ if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
+ getClusterSizeZ().getType() != getClusterSizeX().getType())
+ return emitOpError()
+ << "expects types of the cluster dimensions must be the same";
+ }
+
return success();
}
-static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy) {
+static ParseResult
+parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
+ std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
+ Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
if (succeeded(parser.parseOptionalColon())) {
if (parser.parseType(dimTy))
return failure();
} else {
dimTy = IndexType::get(parser.getContext());
}
+ if (clusterValue.has_value()) {
+ clusterXTy = clusterYTy = clusterZTy = dimTy;
+ }
return success();
}
-static void printLaunchDimType(OpAsmPrinter &printer, Operation *op,
- Type dimTy) {
+static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
+ Value clusterValue, Type clusterXTy,
+ Type clusterYTy, Type clusterZTy) {
if (!dimTy.isIndex())
printer << ": " << dimTy;
}
diff --git a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
index cb2d66d5b0d32da..69017efb9a0e67c 100644
--- a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
@@ -19,6 +19,8 @@ using namespace mlir::gpu;
// Maximum grid and block dimensions of all known GPUs are less than 2^32.
static constexpr uint64_t kMaxDim = std::numeric_limits<uint32_t>::max();
+// Maximum cluster size
+static constexpr uint64_t kMaxClusterDim = 8;
// Maximum subgroups are no larger than 128.
static constexpr uint64_t kMaxSubgroupSize = 128;
@@ -82,6 +84,17 @@ static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
return std::nullopt;
}
+void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+ SetIntRangeFn setResultRange) {
+ setResultRange(getResult(), getIndexRange(1, kMaxClusterDim));
+}
+
+void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+ SetIntRangeFn setResultRange) {
+ uint64_t max = kMaxClusterDim;
+ setResultRange(getResult(), getIndexRange(0, max - 1ULL));
+}
+
void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) {
std::optional<uint64_t> knownVal =
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index a8e743c519135f7..9b63d2a22a7a31f 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -194,6 +194,60 @@ mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY,
extra));
}
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuLaunchClusterKernel(
+ CUfunction function, intptr_t clusterX, intptr_t clusterY,
+ intptr_t clusterZ, intptr_t gridX, intptr_t gridY, intptr_t gridZ,
+ intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem,
+ CUstream stream, void **params, void **extra, size_t /*paramsCount*/) {
+ ScopedContext scopedContext;
+ if (smem > 0) {
+ // Avoid checking driver as it's more expensive than if statement
+ int32_t maxShmem = 0;
+ CUdevice device = getDefaultCuDevice();
+ CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
+ CUDA_REPORT_IF_ERROR(cuDeviceGetAttribute(
+ &maxShmem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
+ device));
+ if (maxShmem < smem) {
+ fprintf(stderr,
+ "Requested shared memory (%dkb) is larger than maximum allowed "
+ "shared memory (%dkb) for this device\n",
+ smem, maxShmem);
+ }
+ CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
+ function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem));
+ }
+ CUlaunchConfig config;
+ config.gridDimX = gridX;
+ config.gridDimY = gridY;
+ config.gridDimZ = gridZ;
+ config.blockDimX = blockX;
+ config.blockDimY = blockY;
+ config.blockDimZ = blockZ;
+ config.sharedMemBytes = smem;
+ config.hStream = stream;
+ CUlaunchAttribute launchAttr[2];
+ launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
+ launchAttr[0].value.clusterDim.x = clusterX;
+ launchAttr[0].value.clusterDim.y = clusterY;
+ launchAttr[0].value.clusterDim.z = clusterZ;
+ launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
+ launchAttr[1].value.clusterSchedulingPolicyPreference =
+ CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
+ config.numAttrs = 2;
+ config.attrs = launchAttr;
+
+ debug_print("Launching kernel,"
+ "cluster: %ld, %ld, %ld, "
+ "grid=%ld,%ld,%ld, "
+ "threads: %ld, %ld, %ld, "
+ "smem: %dkb\n...
[truncated]
``````````
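The diff is cut off before the updated tests. For orientation, a clustered launch is expected to lower through the GPU-to-LLVM conversion to a call to the new `mgpuLaunchClusterKernel` wrapper declared above. The sketch below is hand-written (placeholder SSA names, 64-bit index type assumed), not copied from the patch:
```mlir
// Hypothetical lowered form; the actual CHECK lines live in
// mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir.
llvm.call @mgpuLaunchClusterKernel(
    %function, %cX, %cY, %cZ, %gX, %gY, %gZ, %bX, %bY, %bZ,
    %smem, %stream, %params, %extra, %paramsCount)
    : (!llvm.ptr, i64, i64, i64, i64, i64, i64, i64, i64, i64,
       i32, !llvm.ptr, !llvm.ptr, !llvm.ptr, i64) -> ()
```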
https://github.com/llvm/llvm-project/pull/72871