[flang-commits] [flang] 6b21cf8 - [flang][cuda] Compute grid x when calling a kernel with <<<*, block>>> (#115538)
via flang-commits
flang-commits at lists.llvm.org
Fri Nov 8 14:34:29 PST 2024
Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-11-08T14:34:26-08:00
New Revision: 6b21cf8ccad84e2670e458d8bdaccbd0ae37b46b
URL: https://github.com/llvm/llvm-project/commit/6b21cf8ccad84e2670e458d8bdaccbd0ae37b46b
DIFF: https://github.com/llvm/llvm-project/commit/6b21cf8ccad84e2670e458d8bdaccbd0ae37b46b.diff
LOG: [flang][cuda] Compute grid x when calling a kernel with <<<*, block>>> (#115538)
When a kernel is launched with the `<<<*, block>>>` syntax, `-1, 1, 1` is passed to the runtime as the grid dimensions. Query the device's occupancy limits to compute the actual grid.x value instead of launching with the placeholder.
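As an illustration (not part of the patch), a CUDA Fortran program that exercises this path might look like the sketch below; the module, kernel, and variable names are hypothetical. Because the runtime now chooses grid.x, the kernel uses a grid-stride loop so it covers all n elements whatever grid size the occupancy query yields:

    module ops
      use cudafor
    contains
      attributes(global) subroutine saxpy(n, a, x, y)
        integer, value :: n
        real, value :: a
        real :: x(n), y(n)
        integer :: i
        ! Grid-stride loop: correct for whatever grid.x the runtime
        ! derives from the occupancy query.
        do i = (blockIdx%x - 1) * blockDim%x + threadIdx%x, n, &
               blockDim%x * gridDim%x
          y(i) = a * x(i) + y(i)
        end do
      end subroutine
    end module

    program demo
      use cudafor
      use ops
      implicit none
      integer, parameter :: n = 1000000
      real, device :: x_d(n), y_d(n)
      integer :: ierr
      x_d = 1.0
      y_d = 2.0
      ! '*' lowers to gridX = -1; with this patch the runtime replaces it
      ! with an occupancy-derived grid.x before launching.
      call saxpy<<<*, 256>>>(n, 2.0, x_d, y_d)
      ierr = cudaDeviceSynchronize()
    end program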
Added:
Modified:
flang/runtime/CUDA/kernel.cpp
Removed:
################################################################################
diff --git a/flang/runtime/CUDA/kernel.cpp b/flang/runtime/CUDA/kernel.cpp
index abb7ebb72e5923..88cdf3cf426229 100644
--- a/flang/runtime/CUDA/kernel.cpp
+++ b/flang/runtime/CUDA/kernel.cpp
@@ -25,6 +25,55 @@ void RTDEF(CUFLaunchKernel)(const void *kernel, intptr_t gridX, intptr_t gridY,
   blockDim.x = blockX;
   blockDim.y = blockY;
   blockDim.z = blockZ;
+  unsigned nbNegGridDim{0};
+  if (gridX < 0) {
+    ++nbNegGridDim;
+  }
+  if (gridY < 0) {
+    ++nbNegGridDim;
+  }
+  if (gridZ < 0) {
+    ++nbNegGridDim;
+  }
+  if (nbNegGridDim == 1) {
+    int maxBlocks{0}, nbBlocks, dev, multiProcCount;
+    cudaError_t err1, err2;
+    nbBlocks = blockDim.x * blockDim.y * blockDim.z;
+    cudaGetDevice(&dev);
+    err1 = cudaDeviceGetAttribute(
+        &multiProcCount, cudaDevAttrMultiProcessorCount, dev);
+    err2 = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &maxBlocks, kernel, nbBlocks, smem);
+    if (err1 == cudaSuccess && err2 == cudaSuccess) {
+      maxBlocks = multiProcCount * maxBlocks;
+    }
+    if (maxBlocks > 0) {
+      if (gridX > 0) {
+        maxBlocks = maxBlocks / gridDim.x;
+      }
+      if (gridY > 0) {
+        maxBlocks = maxBlocks / gridDim.y;
+      }
+      if (gridZ > 0) {
+        maxBlocks = maxBlocks / gridDim.z;
+      }
+      if (maxBlocks < 1) {
+        maxBlocks = 1;
+      }
+      if (gridX < 0) {
+        gridDim.x = maxBlocks;
+      }
+      if (gridY < 0) {
+        gridDim.y = maxBlocks;
+      }
+      if (gridZ < 0) {
+        gridDim.z = maxBlocks;
+      }
+    }
+  } else if (nbNegGridDim > 1) {
+    Fortran::runtime::Terminator terminator{__FILE__, __LINE__};
+    terminator.Crash("Only one grid dimension may be given as '*'");
+  }
   cudaStream_t stream = 0; // TODO stream management
   CUDA_REPORT_IF_ERROR(
       cudaLaunchKernel(kernel, gridDim, blockDim, params, smem, stream));
@@ -41,6 +90,55 @@ void RTDEF(CUFLaunchClusterKernel)(const void *kernel, intptr_t clusterX,
   config.blockDim.x = blockX;
   config.blockDim.y = blockY;
   config.blockDim.z = blockZ;
+  unsigned nbNegGridDim{0};
+  if (gridX < 0) {
+    ++nbNegGridDim;
+  }
+  if (gridY < 0) {
+    ++nbNegGridDim;
+  }
+  if (gridZ < 0) {
+    ++nbNegGridDim;
+  }
+  if (nbNegGridDim == 1) {
+    int maxBlocks{0}, nbBlocks, dev, multiProcCount;
+    cudaError_t err1, err2;
+    nbBlocks = config.blockDim.x * config.blockDim.y * config.blockDim.z;
+    cudaGetDevice(&dev);
+    err1 = cudaDeviceGetAttribute(
+        &multiProcCount, cudaDevAttrMultiProcessorCount, dev);
+    err2 = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &maxBlocks, kernel, nbBlocks, smem);
+    if (err1 == cudaSuccess && err2 == cudaSuccess) {
+      maxBlocks = multiProcCount * maxBlocks;
+    }
+    if (maxBlocks > 0) {
+      if (gridX > 0) {
+        maxBlocks = maxBlocks / config.gridDim.x;
+      }
+      if (gridY > 0) {
+        maxBlocks = maxBlocks / config.gridDim.y;
+      }
+      if (gridZ > 0) {
+        maxBlocks = maxBlocks / config.gridDim.z;
+      }
+      if (maxBlocks < 1) {
+        maxBlocks = 1;
+      }
+      if (gridX < 0) {
+        config.gridDim.x = maxBlocks;
+      }
+      if (gridY < 0) {
+        config.gridDim.y = maxBlocks;
+      }
+      if (gridZ < 0) {
+        config.gridDim.z = maxBlocks;
+      }
+    }
+  } else if (nbNegGridDim > 1) {
+    Fortran::runtime::Terminator terminator{__FILE__, __LINE__};
+    terminator.Crash("Only one grid dimension may be given as '*'");
+  }
   config.dynamicSmemBytes = smem;
   config.stream = 0; // TODO stream management
   cudaLaunchAttribute launchAttr[1];