[llvm] [Offload] Allow CUDA Kernels to use arbitrarily large shared memory (PR #145963)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 26 13:26:58 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-offload

Author: Giorgi Gvalia (gvalson)

<details>
<summary>Changes</summary>

Previously, the user was not able to use more than 48 KB of dynamic shared memory on NVIDIA GPUs. Requesting more requires setting the function attribute `CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES`, which was not present in the code base. With this commit, we add the ability to set this attribute, allowing the user to utilize all of the shared memory their GPU provides.

To avoid resetting the function attribute on every launch of the same kernel, we keep track of the current limit (in the variable `MaxDynCGroupMemLimit`) and only set the attribute when the requested amount exceeds that limit. By default, the limit is 48 KB.

Feedback is greatly appreciated, especially on marking the new member variable `mutable`. I did this because `launchImpl` is a const method and I cannot modify the member otherwise.

---
Full diff: https://github.com/llvm/llvm-project/pull/145963.diff


3 Files Affected:

- (modified) offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp (+1) 
- (modified) offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h (+2) 
- (modified) offload/plugins-nextgen/cuda/src/rtl.cpp (+14) 


``````````diff
diff --git a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp
index e5332686fcffb..361a781e8f9b6 100644
--- a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp
+++ b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp
@@ -31,6 +31,7 @@ DLWRAP(cuDeviceGet, 2)
 DLWRAP(cuDeviceGetAttribute, 3)
 DLWRAP(cuDeviceGetCount, 1)
 DLWRAP(cuFuncGetAttribute, 3)
+DLWRAP(cuFuncSetAttribute, 3)
 
 // Device info
 DLWRAP(cuDeviceGetName, 3)
diff --git a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h
index 1c5b421768894..b6c022c8e7e8b 100644
--- a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h
+++ b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h
@@ -258,6 +258,7 @@ typedef enum CUdevice_attribute_enum {
 
 typedef enum CUfunction_attribute_enum {
   CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK = 0,
+  CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES = 8,
 } CUfunction_attribute;
 
 typedef enum CUctx_flags_enum {
@@ -295,6 +296,7 @@ CUresult cuDeviceGet(CUdevice *, int);
 CUresult cuDeviceGetAttribute(int *, CUdevice_attribute, CUdevice);
 CUresult cuDeviceGetCount(int *);
 CUresult cuFuncGetAttribute(int *, CUfunction_attribute, CUfunction);
+CUresult cuFuncSetAttribute(CUfunction, CUfunction_attribute, int);
 
 // Device info
 CUresult cuDeviceGetName(char *, int, CUdevice);
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index 0e662b038c363..fd9528061b55e 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -160,6 +160,9 @@ struct CUDAKernelTy : public GenericKernelTy {
 private:
   /// The CUDA kernel function to execute.
   CUfunction Func;
+  /// Current limit on dynamic shared memory per thread group for this
+  /// kernel. Defaults to the 48 KB that NVIDIA GPUs allow without opt-in.
+  mutable uint32_t MaxDynCGroupMemLimit = 49152;
 };
 
 /// Class wrapping a CUDA stream reference. These are the objects handled by the
@@ -1300,6 +1303,17 @@ Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
   if (GenericDevice.getRPCServer())
     GenericDevice.Plugin.getRPCServer().Thread->notify();
 
+  // If this launch requires more dynamic shared memory than the current
+  // limit, raise the limit before launching.
+  if (MaxDynCGroupMem > MaxDynCGroupMemLimit) {
+    CUresult AttrResult = cuFuncSetAttribute(
+        Func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
+        MaxDynCGroupMem);
+    if (auto Err = Plugin::check(
+            AttrResult, "Error in cuFuncSetAttribute: %s"))
+      return Err;
+    MaxDynCGroupMemLimit = MaxDynCGroupMem;
+  }
+
   CUresult Res = cuLaunchKernel(Func, NumBlocks[0], NumBlocks[1], NumBlocks[2],
                                 NumThreads[0], NumThreads[1], NumThreads[2],
                                 MaxDynCGroupMem, Stream, nullptr, Config);

``````````

</details>


https://github.com/llvm/llvm-project/pull/145963


More information about the llvm-commits mailing list