[flang-commits] [flang] [flang][cuda] Use cuda runtime API (PR #103488)
via flang-commits
flang-commits at lists.llvm.org
Tue Aug 13 16:53:57 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-runtime
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
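CUDA Fortran is meant to be the equivalent of the CUDA runtime API. Therefore, it makes more sense to use the CUDA runtime API, rather than the driver API, in the CUF allocators. Because the runtime API initializes and uses the device's primary context implicitly, the unit tests no longer need to create and push a CUDA context manually.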
---
Full diff: https://github.com/llvm/llvm-project/pull/103488.diff
4 Files Affected:
- (modified) flang/include/flang/Runtime/CUDA/allocator.h (+3-4)
- (modified) flang/runtime/CUDA/CMakeLists.txt (+8-2)
- (modified) flang/runtime/CUDA/allocator.cpp (+10-14)
- (modified) flang/unittests/Runtime/CUDA/AllocatorCUF.cpp (+1-32)
``````````diff
diff --git a/flang/include/flang/Runtime/CUDA/allocator.h b/flang/include/flang/Runtime/CUDA/allocator.h
index f0bfc1548e6458..4527c9f18fa054 100644
--- a/flang/include/flang/Runtime/CUDA/allocator.h
+++ b/flang/include/flang/Runtime/CUDA/allocator.h
@@ -13,11 +13,10 @@
#include "flang/Runtime/entry-names.h"
#define CUDA_REPORT_IF_ERROR(expr) \
- [](CUresult result) { \
- if (!result) \
+ [](cudaError_t err) { \
+ if (err == cudaSuccess) \
return; \
- const char *name = nullptr; \
- cuGetErrorName(result, &name); \
+ const char *name = cudaGetErrorName(err); \
if (!name) \
name = "<unknown>"; \
Terminator terminator{__FILE__, __LINE__}; \
diff --git a/flang/runtime/CUDA/CMakeLists.txt b/flang/runtime/CUDA/CMakeLists.txt
index 88243536139e46..53c5b8823c56b0 100644
--- a/flang/runtime/CUDA/CMakeLists.txt
+++ b/flang/runtime/CUDA/CMakeLists.txt
@@ -7,14 +7,20 @@
#===------------------------------------------------------------------------===#
include_directories(${CUDAToolkit_INCLUDE_DIRS})
-find_library(CUDA_RUNTIME_LIBRARY cuda HINTS ${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES} REQUIRED)
add_flang_library(CufRuntime
allocator.cpp
descriptor.cpp
)
+
+if (BUILD_SHARED_LIBS)
+ set(CUF_LIBRARY ${CUDA_LIBRARIES})
+else()
+ set(CUF_LIBRARY ${CUDA_cudart_static_LIBRARY})
+endif()
+
target_link_libraries(CufRuntime
PRIVATE
FortranRuntime
- ${CUDA_RUNTIME_LIBRARY}
+ ${CUF_LIBRARY}
)
diff --git a/flang/runtime/CUDA/allocator.cpp b/flang/runtime/CUDA/allocator.cpp
index bd657b800c61e8..d4a473d58e86cd 100644
--- a/flang/runtime/CUDA/allocator.cpp
+++ b/flang/runtime/CUDA/allocator.cpp
@@ -15,7 +15,7 @@
#include "flang/ISO_Fortran_binding_wrapper.h"
#include "flang/Runtime/allocator-registry.h"
-#include "cuda.h"
+#include "cuda_runtime.h"
namespace Fortran::runtime::cuda {
extern "C" {
@@ -34,32 +34,28 @@ void RTDEF(CUFRegisterAllocator)() {
void *CUFAllocPinned(std::size_t sizeInBytes) {
void *p;
- CUDA_REPORT_IF_ERROR(cuMemAllocHost(&p, sizeInBytes));
+ CUDA_REPORT_IF_ERROR(cudaMallocHost((void **)&p, sizeInBytes));
return p;
}
-void CUFFreePinned(void *p) { CUDA_REPORT_IF_ERROR(cuMemFreeHost(p)); }
+void CUFFreePinned(void *p) { CUDA_REPORT_IF_ERROR(cudaFreeHost(p)); }
void *CUFAllocDevice(std::size_t sizeInBytes) {
- CUdeviceptr p = 0;
- CUDA_REPORT_IF_ERROR(cuMemAlloc(&p, sizeInBytes));
- return reinterpret_cast<void *>(p);
+ void *p;
+ CUDA_REPORT_IF_ERROR(cudaMalloc(&p, sizeInBytes));
+ return p;
}
-void CUFFreeDevice(void *p) {
- CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(p)));
-}
+void CUFFreeDevice(void *p) { CUDA_REPORT_IF_ERROR(cudaFree(p)); }
void *CUFAllocManaged(std::size_t sizeInBytes) {
- CUdeviceptr p = 0;
+ void *p;
CUDA_REPORT_IF_ERROR(
- cuMemAllocManaged(&p, sizeInBytes, CU_MEM_ATTACH_GLOBAL));
+ cudaMallocManaged((void **)&p, sizeInBytes, cudaMemAttachGlobal));
return reinterpret_cast<void *>(p);
}
-void CUFFreeManaged(void *p) {
- CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(p)));
-}
+void CUFFreeManaged(void *p) { CUDA_REPORT_IF_ERROR(cudaFree(p)); }
void *CUFAllocUnified(std::size_t sizeInBytes) {
// Call alloc managed for the time being.
diff --git a/flang/unittests/Runtime/CUDA/AllocatorCUF.cpp b/flang/unittests/Runtime/CUDA/AllocatorCUF.cpp
index 9f5ec289ee8f74..b51ff0ac006cc6 100644
--- a/flang/unittests/Runtime/CUDA/AllocatorCUF.cpp
+++ b/flang/unittests/Runtime/CUDA/AllocatorCUF.cpp
@@ -14,7 +14,7 @@
#include "flang/Runtime/allocatable.h"
#include "flang/Runtime/allocator-registry.h"
-#include "cuda.h"
+#include "cuda_runtime.h"
using namespace Fortran::runtime;
using namespace Fortran::runtime::cuda;
@@ -25,38 +25,9 @@ static OwningPtr<Descriptor> createAllocatable(
CFI_attribute_allocatable);
}
-thread_local static int32_t defaultDevice = 0;
-
-CUdevice getDefaultCuDevice() {
- CUdevice device;
- CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
- return device;
-}
-
-class ScopedContext {
-public:
- ScopedContext() {
- // Static reference to CUDA primary context for device ordinal
- // defaultDevice.
- static CUcontext context = [] {
- CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
- CUcontext ctx;
- // Note: this does not affect the current context.
- CUDA_REPORT_IF_ERROR(
- cuDevicePrimaryCtxRetain(&ctx, getDefaultCuDevice()));
- return ctx;
- }();
-
- CUDA_REPORT_IF_ERROR(cuCtxPushCurrent(context));
- }
-
- ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); }
-};
-
TEST(AllocatableCUFTest, SimpleDeviceAllocate) {
using Fortran::common::TypeCategory;
RTNAME(CUFRegisterAllocator)();
- ScopedContext ctx;
// REAL(4), DEVICE, ALLOCATABLE :: a(:)
auto a{createAllocatable(TypeCategory::Real, 4)};
a->SetAllocIdx(kDeviceAllocatorPos);
@@ -74,7 +45,6 @@ TEST(AllocatableCUFTest, SimpleDeviceAllocate) {
TEST(AllocatableCUFTest, SimplePinnedAllocate) {
using Fortran::common::TypeCategory;
RTNAME(CUFRegisterAllocator)();
- ScopedContext ctx;
// INTEGER(4), PINNED, ALLOCATABLE :: a(:)
auto a{createAllocatable(TypeCategory::Integer, 4)};
EXPECT_FALSE(a->HasAddendum());
@@ -93,7 +63,6 @@ TEST(AllocatableCUFTest, SimplePinnedAllocate) {
TEST(AllocatableCUFTest, DescriptorAllocationTest) {
using Fortran::common::TypeCategory;
RTNAME(CUFRegisterAllocator)();
- ScopedContext ctx;
// REAL(4), DEVICE, ALLOCATABLE :: a(:)
auto a{createAllocatable(TypeCategory::Real, 4)};
Descriptor *desc = nullptr;
``````````
</details>
https://github.com/llvm/llvm-project/pull/103488
More information about the flang-commits
mailing list