[llvm-branch-commits] [llvm] [openmp] [OpenMP][clang] 6.0: num_threads strict (part 2: device runtime) (PR #146404)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Jun 30 12:02:39 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-offload
Author: Robert Imschweiler (ro-i)
<details>
<summary>Changes</summary>
OpenMP 6.0, Section 12.1.2, specifies the behavior of the strict modifier for the num_threads clause on parallel directives, together with the message and severity clauses. This commit implements the necessary device runtime changes.
---
Full diff: https://github.com/llvm/llvm-project/pull/146404.diff
3 Files Affected:
- (modified) offload/DeviceRTL/include/DeviceTypes.h (+6)
- (modified) offload/DeviceRTL/src/Parallelism.cpp (+60-18)
- (modified) openmp/runtime/src/kmp.h (+1)
``````````diff
diff --git a/offload/DeviceRTL/include/DeviceTypes.h b/offload/DeviceRTL/include/DeviceTypes.h
index 2e5d92380f040..111143a5578f1 100644
--- a/offload/DeviceRTL/include/DeviceTypes.h
+++ b/offload/DeviceRTL/include/DeviceTypes.h
@@ -136,6 +136,12 @@ struct omp_lock_t {
void *Lock;
};
+// Severity values for the OpenMP 6.0 'severity' clause, as used by the
+// num_threads 'strict' modifier handling in Parallelism.cpp.
+// Mirrors kmp_severity_t in openmp/runtime kmp.h; keep the values identical.
+typedef enum omp_severity_t {
+ severity_warning = 1, // diagnostic is printed and execution continues
+ severity_fatal = 2 // diagnostic is printed, then execution traps
+} omp_severity_t;
+
using InterWarpCopyFnTy = void (*)(void *src, int32_t warp_num);
using ShuffleReductFnTy = void (*)(void *rhsData, int16_t lane_id,
int16_t lane_offset, int16_t shortCircuit);
diff --git a/offload/DeviceRTL/src/Parallelism.cpp b/offload/DeviceRTL/src/Parallelism.cpp
index 08ce616aee1c4..78438a60454b8 100644
--- a/offload/DeviceRTL/src/Parallelism.cpp
+++ b/offload/DeviceRTL/src/Parallelism.cpp
@@ -45,7 +45,24 @@ using namespace ompx;
namespace {
-uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
+// Reports that the num_threads 'strict' modifier could not be honored:
+// prints nt_message if one was given, otherwise a default diagnostic that
+// names the requested and computed thread counts. If nt_severity is
+// severity_fatal, execution is then aborted with __builtin_trap(); any other
+// severity returns normally. Callers only invoke this when nt_strict is set,
+// so that parameter is currently unused.
+void num_threads_strict_error([[maybe_unused]] int32_t nt_strict,
+                              int32_t nt_severity, const char *nt_message,
+                              int32_t requested, int32_t actual) {
+  if (nt_message) {
+    printf("%s\n", nt_message);
+  } else {
+    // Both counts are declared int32_t, so format them with %d; %u would not
+    // match the argument type.
+    printf("The computed number of threads (%d) does not match the requested "
+           "number of threads (%d). Consider that it might not be supported "
+           "to select exactly %d threads on this target device.\n",
+           actual, requested, requested);
+  }
+  if (nt_severity == severity_fatal)
+    __builtin_trap();
+}
+
+uint32_t determineNumberOfThreads(int32_t NumThreadsClause,
+ int32_t nt_strict = false,
+ int32_t nt_severity = severity_fatal,
+ const char *nt_message = nullptr) {
uint32_t NThreadsICV =
NumThreadsClause != -1 ? NumThreadsClause : icv::NThreads;
uint32_t NumThreads = mapping::getMaxTeamThreads();
@@ -55,13 +72,17 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
// SPMD mode allows any number of threads, for generic mode we round down to a
// multiple of WARPSIZE since it is legal to do so in OpenMP.
- if (mapping::isSPMDMode())
- return NumThreads;
+ if (!mapping::isSPMDMode()) {
+ if (NumThreads < mapping::getWarpSize())
+ NumThreads = 1;
+ else
+ NumThreads = (NumThreads & ~((uint32_t)mapping::getWarpSize() - 1));
+ }
- if (NumThreads < mapping::getWarpSize())
- NumThreads = 1;
- else
- NumThreads = (NumThreads & ~((uint32_t)mapping::getWarpSize() - 1));
+ if (NumThreadsClause != -1 && nt_strict &&
+ NumThreads != static_cast<uint32_t>(NumThreadsClause))
+ num_threads_strict_error(nt_strict, nt_severity, nt_message,
+ NumThreadsClause, NumThreads);
return NumThreads;
}
@@ -82,12 +103,14 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
extern "C" {
-[[clang::always_inline]] void __kmpc_parallel_spmd(IdentTy *ident,
- int32_t num_threads,
- void *fn, void **args,
- const int64_t nargs) {
+[[clang::always_inline]] void
+__kmpc_parallel_spmd(IdentTy *ident, int32_t num_threads, void *fn, void **args,
+ const int64_t nargs, int32_t nt_strict = false,
+ int32_t nt_severity = severity_fatal,
+ const char *nt_message = nullptr) {
uint32_t TId = mapping::getThreadIdInBlock();
- uint32_t NumThreads = determineNumberOfThreads(num_threads);
+ uint32_t NumThreads =
+ determineNumberOfThreads(num_threads, nt_strict, nt_severity, nt_message);
uint32_t PTeamSize =
NumThreads == mapping::getMaxTeamThreads() ? 0 : NumThreads;
// Avoid the race between the read of the `icv::Level` above and the write
@@ -140,10 +163,11 @@ extern "C" {
return;
}
-[[clang::always_inline]] void
-__kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
- int32_t num_threads, int proc_bind, void *fn,
- void *wrapper_fn, void **args, int64_t nargs) {
+[[clang::always_inline]] void __kmpc_parallel_51(
+ IdentTy *ident, int32_t, int32_t if_expr, int32_t num_threads,
+ int proc_bind, void *fn, void *wrapper_fn, void **args, int64_t nargs,
+ int32_t nt_strict = false, int32_t nt_severity = severity_fatal,
+ const char *nt_message = nullptr) {
uint32_t TId = mapping::getThreadIdInBlock();
// Assert the parallelism level is zero if disabled by the user.
@@ -156,6 +180,12 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
// 3) nested parallel regions
if (OMP_UNLIKELY(!if_expr || state::HasThreadState ||
(config::mayUseNestedParallelism() && icv::Level))) {
+ // OpenMP 6.0 12.1.2 requires the num_threads 'strict' modifier to also have
+ // effect when parallel execution is disabled by a corresponding if clause
+ // attached to the parallel directive.
+ if (nt_strict && num_threads > 1)
+ num_threads_strict_error(nt_strict, nt_severity, nt_message, num_threads,
+ 1);
state::DateEnvironmentRAII DERAII(ident);
++icv::Level;
invokeMicrotask(TId, 0, fn, args, nargs);
@@ -169,12 +199,14 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
// This was moved to its own routine so it could be called directly
// in certain situations to avoid resource consumption of unused
// logic in parallel_51.
- __kmpc_parallel_spmd(ident, num_threads, fn, args, nargs);
+ __kmpc_parallel_spmd(ident, num_threads, fn, args, nargs, nt_strict,
+ nt_severity, nt_message);
return;
}
- uint32_t NumThreads = determineNumberOfThreads(num_threads);
+ uint32_t NumThreads =
+ determineNumberOfThreads(num_threads, nt_strict, nt_severity, nt_message);
uint32_t MaxTeamThreads = mapping::getMaxTeamThreads();
uint32_t PTeamSize = NumThreads == MaxTeamThreads ? 0 : NumThreads;
@@ -277,6 +309,16 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
__kmpc_end_sharing_variables();
}
+/// OpenMP 6.0 entry point for a parallel region. Behaves like
+/// __kmpc_parallel_51, but also receives the num_threads 'strict' modifier
+/// flag (nt_strict) and the 'severity' / 'message' clause values; every
+/// argument is forwarded unchanged to __kmpc_parallel_51.
+/// NOTE(review): this patch also appends these three parameters to
+/// __kmpc_parallel_51. C++ default arguments only help C++ call sites that see
+/// the new declaration; code already compiled against the old 9-argument
+/// __kmpc_parallel_51 does not pass them — confirm that ABI is preserved.
+[[clang::always_inline]] void
+__kmpc_parallel_60(IdentTy *ident, int32_t id, int32_t if_expr,
+                   int32_t num_threads, int proc_bind, void *fn,
+                   void *wrapper_fn, void **args, int64_t nargs,
+                   int32_t nt_strict = false,
+                   int32_t nt_severity = severity_fatal,
+                   const char *nt_message = nullptr) {
+  __kmpc_parallel_51(ident, id, if_expr, num_threads, proc_bind, fn,
+                     wrapper_fn, args, nargs, nt_strict, nt_severity,
+                     nt_message);
+}
+
[[clang::noinline]] bool __kmpc_kernel_parallel(ParallelRegionFnTy *WorkFn) {
// Work function and arguments for L1 parallel region.
*WorkFn = state::ParallelRegionFn;
diff --git a/openmp/runtime/src/kmp.h b/openmp/runtime/src/kmp.h
index a2cacc8792b15..983e1c34f76b8 100644
--- a/openmp/runtime/src/kmp.h
+++ b/openmp/runtime/src/kmp.h
@@ -4666,6 +4666,7 @@ static inline int __kmp_adjust_gtid_for_hidden_helpers(int gtid) {
}
// Support for error directive
+// See definition in offload/DeviceRTL DeviceTypes.h
typedef enum kmp_severity_t {
severity_warning = 1,
severity_fatal = 2
``````````
</details>
https://github.com/llvm/llvm-project/pull/146404
More information about the llvm-branch-commits
mailing list