[llvm] [OpenMP] Adjust 'printf' handling in the OpenMP runtime (PR #123670)
Joseph Huber via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 20 13:29:21 PST 2025
https://github.com/jhuber6 created https://github.com/llvm/llvm-project/pull/123670
Summary:
We used to avoid a lot of this stuff because we didn't properly handle
variadics in device code. That's been solved for now, so we can just
make an internal printf handler that forwards to the external `vprintf`
function. This is either provided by NVIDIA's SDK or by the GPU libc
implementation.
The main reason for doing this is that it prevents the stupid AMDGPU
printf pass from mangling our beautiful printfs!
>From e663bc98d19b80519460788bd5793316259cfc8c Mon Sep 17 00:00:00 2001
From: Joseph Huber <huberjn at outlook.com>
Date: Mon, 20 Jan 2025 15:20:59 -0600
Subject: [PATCH] [OpenMP] Adjust 'printf' handling in the OpenMP runtime
Summary:
We used to avoid a lot of this stuff because we didn't properly handle
variadics in device code. That's been solved for now, so we can just
make an internal printf handler that forwards to the external `vprintf`
function. This is either provided by NVIDIA's SDK or by the GPU libc
implementation.
The main reason for doing this is that it prevents the stupid AMDGPU
printf pass from mangling our beautiful printfs!
---
offload/DeviceRTL/include/Debug.h | 7 +----
offload/DeviceRTL/include/LibC.h | 9 +++---
offload/DeviceRTL/src/Debug.cpp | 4 +--
offload/DeviceRTL/src/LibC.cpp | 45 +++++++++++----------------
offload/DeviceRTL/src/Parallelism.cpp | 3 +-
offload/DeviceRTL/src/State.cpp | 8 ++---
6 files changed, 32 insertions(+), 44 deletions(-)
diff --git a/offload/DeviceRTL/include/Debug.h b/offload/DeviceRTL/include/Debug.h
index 22998f44a5bea5..98d0fa498d952b 100644
--- a/offload/DeviceRTL/include/Debug.h
+++ b/offload/DeviceRTL/include/Debug.h
@@ -35,15 +35,10 @@ void __assert_fail_internal(const char *expr, const char *msg, const char *file,
__assert_assume(expr); \
}
#define UNREACHABLE(msg) \
- PRINT(msg); \
+ printf(msg); \
__builtin_trap(); \
__builtin_unreachable();
///}
-#define PRINTF(fmt, ...) (void)printf(fmt, ##__VA_ARGS__);
-#define PRINT(str) PRINTF("%s", str)
-
-///}
-
#endif
diff --git a/offload/DeviceRTL/include/LibC.h b/offload/DeviceRTL/include/LibC.h
index 03febdb5083423..94b5e651960674 100644
--- a/offload/DeviceRTL/include/LibC.h
+++ b/offload/DeviceRTL/include/LibC.h
@@ -14,11 +14,10 @@
#include "DeviceTypes.h"
-extern "C" {
+namespace ompx {
-int memcmp(const void *lhs, const void *rhs, size_t count);
-void memset(void *dst, int C, size_t count);
-int printf(const char *format, ...);
-}
+int printf(const char *Format, ...);
+
+} // namespace ompx
#endif
diff --git a/offload/DeviceRTL/src/Debug.cpp b/offload/DeviceRTL/src/Debug.cpp
index b451f17c6bbd89..1d9c9628854222 100644
--- a/offload/DeviceRTL/src/Debug.cpp
+++ b/offload/DeviceRTL/src/Debug.cpp
@@ -36,10 +36,10 @@ void __assert_assume(bool condition) { __builtin_assume(condition); }
void __assert_fail_internal(const char *expr, const char *msg, const char *file,
unsigned line, const char *function) {
if (msg) {
- PRINTF("%s:%u: %s: Assertion %s (`%s`) failed.\n", file, line, function,
+ printf("%s:%u: %s: Assertion %s (`%s`) failed.\n", file, line, function,
msg, expr);
} else {
- PRINTF("%s:%u: %s: Assertion `%s` failed.\n", file, line, function, expr);
+ printf("%s:%u: %s: Assertion `%s` failed.\n", file, line, function, expr);
}
__builtin_trap();
}
diff --git a/offload/DeviceRTL/src/LibC.cpp b/offload/DeviceRTL/src/LibC.cpp
index 291ceb023a69c5..e55008f46269fe 100644
--- a/offload/DeviceRTL/src/LibC.cpp
+++ b/offload/DeviceRTL/src/LibC.cpp
@@ -10,32 +10,11 @@
#pragma omp begin declare target device_type(nohost)
-namespace impl {
-int32_t omp_vprintf(const char *Format, __builtin_va_list vlist);
-}
-
-#ifndef OMPTARGET_HAS_LIBC
-namespace impl {
-#pragma omp begin declare variant match( \
- device = {arch(nvptx, nvptx64)}, \
- implementation = {extension(match_any)})
-extern "C" int vprintf(const char *format, ...);
-int omp_vprintf(const char *Format, __builtin_va_list vlist) {
- return vprintf(Format, vlist);
-}
-#pragma omp end declare variant
-
-#pragma omp begin declare variant match(device = {arch(amdgcn)})
-int omp_vprintf(const char *Format, __builtin_va_list) { return -1; }
-#pragma omp end declare variant
-} // namespace impl
-
-extern "C" int printf(const char *Format, ...) {
- __builtin_va_list vlist;
- __builtin_va_start(vlist, Format);
- return impl::omp_vprintf(Format, vlist);
-}
-#endif // OMPTARGET_HAS_LIBC
+#if defined(__AMDGPU__) && !defined(OMPTARGET_HAS_LIBC)
+extern "C" int vprintf(const char *format, __builtin_va_list) { return -1; }
+#else
+extern "C" int vprintf(const char *format, __builtin_va_list);
+#endif
extern "C" {
[[gnu::weak]] int memcmp(const void *lhs, const void *rhs, size_t count) {
@@ -54,6 +33,20 @@ extern "C" {
for (size_t I = 0; I < count; ++I)
dstc[I] = C;
}
+
+[[gnu::weak]] int printf(const char *Format, ...) {
+ __builtin_va_list vlist;
+ __builtin_va_start(vlist, Format);
+ return ::vprintf(Format, vlist);
+}
+}
+
+namespace ompx {
+[[clang::no_builtin("printf")]] int printf(const char *Format, ...) {
+ __builtin_va_list vlist;
+ __builtin_va_start(vlist, Format);
+ return ::vprintf(Format, vlist);
}
+} // namespace ompx
#pragma omp end declare target
diff --git a/offload/DeviceRTL/src/Parallelism.cpp b/offload/DeviceRTL/src/Parallelism.cpp
index 5286d53b623f0a..a87e363349b1e5 100644
--- a/offload/DeviceRTL/src/Parallelism.cpp
+++ b/offload/DeviceRTL/src/Parallelism.cpp
@@ -36,6 +36,7 @@
#include "DeviceTypes.h"
#include "DeviceUtils.h"
#include "Interface.h"
+#include "LibC.h"
#include "Mapping.h"
#include "State.h"
#include "Synchronization.h"
@@ -74,7 +75,7 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
switch (nargs) {
#include "generated_microtask_cases.gen"
default:
- PRINT("Too many arguments in kmp_invoke_microtask, aborting execution.\n");
+ printf("Too many arguments in kmp_invoke_microtask, aborting execution.\n");
__builtin_trap();
}
}
diff --git a/offload/DeviceRTL/src/State.cpp b/offload/DeviceRTL/src/State.cpp
index 855c74fa58e0a5..100bc8ab47983c 100644
--- a/offload/DeviceRTL/src/State.cpp
+++ b/offload/DeviceRTL/src/State.cpp
@@ -138,8 +138,8 @@ void *SharedMemorySmartStackTy::push(uint64_t Bytes) {
}
if (config::isDebugMode(DeviceDebugKind::CommonIssues))
- PRINT("Shared memory stack full, fallback to dynamic allocation of global "
- "memory will negatively impact performance.\n");
+ printf("Shared memory stack full, fallback to dynamic allocation of global "
+ "memory will negatively impact performance.\n");
void *GlobalMemory = memory::allocGlobal(
AlignedBytes, "Slow path shared memory allocation, insufficient "
"shared memory stack memory!");
@@ -173,7 +173,7 @@ void memory::freeShared(void *Ptr, uint64_t Bytes, const char *Reason) {
void *memory::allocGlobal(uint64_t Bytes, const char *Reason) {
void *Ptr = malloc(Bytes);
if (config::isDebugMode(DeviceDebugKind::CommonIssues) && Ptr == nullptr)
- PRINT("nullptr returned by malloc!\n");
+ printf("nullptr returned by malloc!\n");
return Ptr;
}
@@ -277,7 +277,7 @@ void state::enterDataEnvironment(IdentTy *Ident) {
sizeof(ThreadStates[0]) * mapping::getNumberOfThreadsInBlock();
void *ThreadStatesPtr =
memory::allocGlobal(Bytes, "Thread state array allocation");
- memset(ThreadStatesPtr, 0, Bytes);
+ __builtin_memset(ThreadStatesPtr, 0, Bytes);
if (!atomic::cas(ThreadStatesBitsPtr, uintptr_t(0),
reinterpret_cast<uintptr_t>(ThreadStatesPtr),
atomic::seq_cst, atomic::seq_cst))
More information about the llvm-commits mailing list