[flang-commits] [flang] Revert "[flang][cuda] Use cuda runtime API" (PR #104232)
via flang-commits
flang-commits at lists.llvm.org
Wed Aug 14 13:45:24 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-runtime
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
Reverts llvm/llvm-project#103488
---
Full diff: https://github.com/llvm/llvm-project/pull/104232.diff
4 Files Affected:
- (modified) flang/include/flang/Runtime/CUDA/allocator.h (+4-3)
- (modified) flang/runtime/CUDA/CMakeLists.txt (+2-8)
- (modified) flang/runtime/CUDA/allocator.cpp (+14-10)
- (modified) flang/unittests/Runtime/CUDA/AllocatorCUF.cpp (+32-1)
``````````diff
diff --git a/flang/include/flang/Runtime/CUDA/allocator.h b/flang/include/flang/Runtime/CUDA/allocator.h
index 4527c9f18fa054..f0bfc1548e6458 100644
--- a/flang/include/flang/Runtime/CUDA/allocator.h
+++ b/flang/include/flang/Runtime/CUDA/allocator.h
@@ -13,10 +13,11 @@
#include "flang/Runtime/entry-names.h"
#define CUDA_REPORT_IF_ERROR(expr) \
- [](cudaError_t err) { \
- if (err == cudaSuccess) \
+ [](CUresult result) { \
+ if (!result) \
return; \
- const char *name = cudaGetErrorName(err); \
+ const char *name = nullptr; \
+ cuGetErrorName(result, &name); \
if (!name) \
name = "<unknown>"; \
Terminator terminator{__FILE__, __LINE__}; \
diff --git a/flang/runtime/CUDA/CMakeLists.txt b/flang/runtime/CUDA/CMakeLists.txt
index 53c5b8823c56b0..88243536139e46 100644
--- a/flang/runtime/CUDA/CMakeLists.txt
+++ b/flang/runtime/CUDA/CMakeLists.txt
@@ -7,20 +7,14 @@
#===------------------------------------------------------------------------===#
include_directories(${CUDAToolkit_INCLUDE_DIRS})
+find_library(CUDA_RUNTIME_LIBRARY cuda HINTS ${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES} REQUIRED)
add_flang_library(CufRuntime
allocator.cpp
descriptor.cpp
)
-
-if (BUILD_SHARED_LIBS)
- set(CUF_LIBRARY ${CUDA_LIBRARIES})
-else()
- set(CUF_LIBRARY ${CUDA_cudart_static_LIBRARY})
-endif()
-
target_link_libraries(CufRuntime
PRIVATE
FortranRuntime
- ${CUF_LIBRARY}
+ ${CUDA_RUNTIME_LIBRARY}
)
diff --git a/flang/runtime/CUDA/allocator.cpp b/flang/runtime/CUDA/allocator.cpp
index d4a473d58e86cd..bd657b800c61e8 100644
--- a/flang/runtime/CUDA/allocator.cpp
+++ b/flang/runtime/CUDA/allocator.cpp
@@ -15,7 +15,7 @@
#include "flang/ISO_Fortran_binding_wrapper.h"
#include "flang/Runtime/allocator-registry.h"
-#include "cuda_runtime.h"
+#include "cuda.h"
namespace Fortran::runtime::cuda {
extern "C" {
@@ -34,28 +34,32 @@ void RTDEF(CUFRegisterAllocator)() {
void *CUFAllocPinned(std::size_t sizeInBytes) {
void *p;
- CUDA_REPORT_IF_ERROR(cudaMallocHost((void **)&p, sizeInBytes));
+ CUDA_REPORT_IF_ERROR(cuMemAllocHost(&p, sizeInBytes));
return p;
}
-void CUFFreePinned(void *p) { CUDA_REPORT_IF_ERROR(cudaFreeHost(p)); }
+void CUFFreePinned(void *p) { CUDA_REPORT_IF_ERROR(cuMemFreeHost(p)); }
void *CUFAllocDevice(std::size_t sizeInBytes) {
- void *p;
- CUDA_REPORT_IF_ERROR(cudaMalloc(&p, sizeInBytes));
- return p;
+ CUdeviceptr p = 0;
+ CUDA_REPORT_IF_ERROR(cuMemAlloc(&p, sizeInBytes));
+ return reinterpret_cast<void *>(p);
}
-void CUFFreeDevice(void *p) { CUDA_REPORT_IF_ERROR(cudaFree(p)); }
+void CUFFreeDevice(void *p) {
+ CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(p)));
+}
void *CUFAllocManaged(std::size_t sizeInBytes) {
- void *p;
+ CUdeviceptr p = 0;
CUDA_REPORT_IF_ERROR(
- cudaMallocManaged((void **)&p, sizeInBytes, cudaMemAttachGlobal));
+ cuMemAllocManaged(&p, sizeInBytes, CU_MEM_ATTACH_GLOBAL));
return reinterpret_cast<void *>(p);
}
-void CUFFreeManaged(void *p) { CUDA_REPORT_IF_ERROR(cudaFree(p)); }
+void CUFFreeManaged(void *p) {
+ CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(p)));
+}
void *CUFAllocUnified(std::size_t sizeInBytes) {
// Call alloc managed for the time being.
diff --git a/flang/unittests/Runtime/CUDA/AllocatorCUF.cpp b/flang/unittests/Runtime/CUDA/AllocatorCUF.cpp
index b51ff0ac006cc6..9f5ec289ee8f74 100644
--- a/flang/unittests/Runtime/CUDA/AllocatorCUF.cpp
+++ b/flang/unittests/Runtime/CUDA/AllocatorCUF.cpp
@@ -14,7 +14,7 @@
#include "flang/Runtime/allocatable.h"
#include "flang/Runtime/allocator-registry.h"
-#include "cuda_runtime.h"
+#include "cuda.h"
using namespace Fortran::runtime;
using namespace Fortran::runtime::cuda;
@@ -25,9 +25,38 @@ static OwningPtr<Descriptor> createAllocatable(
CFI_attribute_allocatable);
}
+thread_local static int32_t defaultDevice = 0;
+
+CUdevice getDefaultCuDevice() {
+ CUdevice device;
+ CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
+ return device;
+}
+
+class ScopedContext {
+public:
+ ScopedContext() {
+ // Static reference to CUDA primary context for device ordinal
+ // defaultDevice.
+ static CUcontext context = [] {
+ CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
+ CUcontext ctx;
+ // Note: this does not affect the current context.
+ CUDA_REPORT_IF_ERROR(
+ cuDevicePrimaryCtxRetain(&ctx, getDefaultCuDevice()));
+ return ctx;
+ }();
+
+ CUDA_REPORT_IF_ERROR(cuCtxPushCurrent(context));
+ }
+
+ ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); }
+};
+
TEST(AllocatableCUFTest, SimpleDeviceAllocate) {
using Fortran::common::TypeCategory;
RTNAME(CUFRegisterAllocator)();
+ ScopedContext ctx;
// REAL(4), DEVICE, ALLOCATABLE :: a(:)
auto a{createAllocatable(TypeCategory::Real, 4)};
a->SetAllocIdx(kDeviceAllocatorPos);
@@ -45,6 +74,7 @@ TEST(AllocatableCUFTest, SimpleDeviceAllocate) {
TEST(AllocatableCUFTest, SimplePinnedAllocate) {
using Fortran::common::TypeCategory;
RTNAME(CUFRegisterAllocator)();
+ ScopedContext ctx;
// INTEGER(4), PINNED, ALLOCATABLE :: a(:)
auto a{createAllocatable(TypeCategory::Integer, 4)};
EXPECT_FALSE(a->HasAddendum());
@@ -63,6 +93,7 @@ TEST(AllocatableCUFTest, SimplePinnedAllocate) {
TEST(AllocatableCUFTest, DescriptorAllocationTest) {
using Fortran::common::TypeCategory;
RTNAME(CUFRegisterAllocator)();
+ ScopedContext ctx;
// REAL(4), DEVICE, ALLOCATABLE :: a(:)
auto a{createAllocatable(TypeCategory::Real, 4)};
Descriptor *desc = nullptr;
``````````
</details>
https://github.com/llvm/llvm-project/pull/104232
More information about the flang-commits mailing list