[flang-commits] [flang] [flang][OpenMP] Implement GetGeneratedNestDepthWithReason (PR #191718)
Krzysztof Parzyszek via flang-commits
flang-commits at lists.llvm.org
Sun Apr 12 08:15:04 PDT 2026
https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/191718
For a loop-nest-generating construct, this function returns the number of loops in the generated loop nest.
A loop-nest-transformation construct can be thought of as replacing N nested loops with K nested loops, where
N = GetAffectedNestDepthWithReason
K = GetGeneratedNestDepthWithReason
>From 124455574867b7c26c13c8c60abf5af62482a2b3 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Sun, 12 Apr 2026 09:02:28 -0500
Subject: [PATCH] [flang][OpenMP] Implement GetGeneratedNestDepthWithReason
For a loop-nest-generating construct, this function returns the number of
loops in the generated loop nest.
A loop-nest-transformation construct can be thought of as replacing N nested
loops with K nested loops, where
N = GetAffectedNestDepthWithReason
K = GetGeneratedNestDepthWithReason
---
flang/include/flang/Semantics/openmp-utils.h | 5 ++
flang/lib/Semantics/openmp-utils.cpp | 91 ++++++++++++++------
2 files changed, 72 insertions(+), 24 deletions(-)
diff --git a/flang/include/flang/Semantics/openmp-utils.h b/flang/include/flang/Semantics/openmp-utils.h
index b449988abe9c8..252c88c967750 100644
--- a/flang/include/flang/Semantics/openmp-utils.h
+++ b/flang/include/flang/Semantics/openmp-utils.h
@@ -218,6 +218,11 @@ WithReason<int64_t> GetHeightWithReason(
std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
const parser::OmpDirectiveSpecification &spec, unsigned version,
SemanticsContext *semaCtx = nullptr);
+/// Return the depth of the generated nest(s), i.e. the number of loops that
+/// replace the affected nest: {generated-depth, is-perfect-nest}
+std::pair<WithReason<int64_t>, bool> GetGeneratedNestDepthWithReason(
+ const parser::OmpDirectiveSpecification &spec, unsigned version,
+ SemanticsContext *semaCtx = nullptr);
/// Return the range of the affected nests in the sequence:
/// {first, count}.
/// If the range is "the whole sequence", the return value will be {1, -1}.
diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp
index c901d00695093..7a5b31b1c805e 100644
--- a/flang/lib/Semantics/openmp-utils.cpp
+++ b/flang/lib/Semantics/openmp-utils.cpp
@@ -767,37 +767,27 @@ WithReason<int64_t> GetHeightWithReason(
"This construct is not a DO-loop or a loop-transformation construct"_because_en_US);
return {0, reason};
}
- switch (spec.DirName().v) {
- case llvm::omp::Directive::OMPD_flatten:
- if (auto &&value{GetArgumentValueWithReason(
- spec, llvm::omp::Clause::OMPC_depth, version, semaCtx)}) {
- // FLATTEN DEPTH(n) replaces n loops with 1.
- return {int64_t(1) - *value.value, std::move(value.reason)};
- } else {
- Reason reason;
- reason.Say(spec.DirName().source, MsgClauseAbsentAssume,
- GetUpperName(llvm::omp::Clause::OMPC_depth, version), "a depth of 2");
- return {-1, std::move(reason)};
- }
+
+ switch (spec.DirId()) {
+ // These generate loop sequences.
case llvm::omp::Directive::OMPD_fuse:
case llvm::omp::Directive::OMPD_split:
return {0, Reason()};
+ case llvm::omp::Directive::OMPD_flatten:
case llvm::omp::Directive::OMPD_interchange:
case llvm::omp::Directive::OMPD_nothing:
case llvm::omp::Directive::OMPD_reverse:
- return {0, Reason()};
case llvm::omp::Directive::OMPD_stripe:
case llvm::omp::Directive::OMPD_tile:
- return GetNumArgumentsWithReason(
- spec, llvm::omp::Clause::OMPC_sizes, version, semaCtx);
- case llvm::omp::Directive::OMPD_unroll:
- if (isFullUnroll) {
- Reason reason;
- reason.Say(spec.DirName().source, MsgConstructDoesNotResult,
- "Fully unrolled loop", "a loop nest");
- return {-1, std::move(reason)};
+ case llvm::omp::Directive::OMPD_unroll: {
+ auto [cons, _1]{GetAffectedNestDepthWithReason(spec, version, semaCtx)};
+ auto [prod, _2]{GetGeneratedNestDepthWithReason(spec, version, semaCtx)};
+ if (cons && prod) {
+ return WithReason<int64_t>{*prod.value - *cons.value,
+ Reason().Append(cons.reason).Append(prod.reason)};
}
- return {0, Reason()};
+ return {};
+ }
default:
llvm_unreachable("Expecting loop-transforming construct");
}
@@ -1000,6 +990,18 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
if (IsLoopTransforming(dir)) {
switch (dir) {
+ case llvm::omp::Directive::OMPD_flatten:
+ if (auto &&value{GetArgumentValueWithReason(
+ spec, llvm::omp::Clause::OMPC_depth, version, semaCtx)}) {
+ // FLATTEN DEPTH(n) replaces n loops with 1.
+ return {std::move(value), true};
+ } else {
+ Reason reason;
+ reason.Say(spec.DirName().source, MsgClauseAbsentAssume,
+ GetUpperName(llvm::omp::Clause::OMPC_depth, version), "a depth of 2");
+ return {{2, std::move(reason)}, true};
+ }
+ break;
case llvm::omp::Directive::OMPD_interchange: {
// Get the length of the argument list to PERMUTATION.
if (parser::omp::FindClause(spec, llvm::omp::Clause::OMPC_permutation)) {
@@ -1015,6 +1017,8 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
spec.source, MsgClauseAbsentAssume, name, "a permutation (2, 1)");
return {{2, std::move(reason)}, true};
}
+ case llvm::omp::Directive::OMPD_nothing:
+ return {WithReason<int64_t>(0), false};
case llvm::omp::Directive::OMPD_stripe:
case llvm::omp::Directive::OMPD_tile: {
// Get the length of the argument list to SIZES.
@@ -1035,10 +1039,9 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
return {{1, std::move(reason)}, true};
}
case llvm::omp::Directive::OMPD_reverse:
+ case llvm::omp::Directive::OMPD_split:
case llvm::omp::Directive::OMPD_unroll:
return {WithReason<int64_t>(1), false};
- // TODO: case llvm::omp::Directive::OMPD_flatten:
- // TODO: case llvm::omp::Directive::OMPD_split:
default:
break;
}
@@ -1047,6 +1050,46 @@ std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
return {{}, false};
}
+/// Return the depth of the generated nest(s):
+/// {generated-depth, is-perfect-nest}
+std::pair<WithReason<int64_t>, bool> GetGeneratedNestDepthWithReason(
+ const parser::OmpDirectiveSpecification &spec, unsigned version,
+ SemanticsContext *semaCtx) {
+ llvm::omp::Directive dir{spec.DirId()};
+ if (!IsLoopTransforming(dir)) {
+ return {{}, false};
+ }
+
+ auto [depth, _]{GetAffectedNestDepthWithReason(spec, version, semaCtx)};
+
+ switch (dir) {
+ case llvm::omp::Directive::OMPD_flatten:
+ return {WithReason<int64_t>(1), true}; // The nest becomes one loop.
+ case llvm::omp::Directive::OMPD_fuse:
+ case llvm::omp::Directive::OMPD_split:
+ // These result in loop sequences.
+ return {{}, false};
+ case llvm::omp::Directive::OMPD_interchange:
+ case llvm::omp::Directive::OMPD_nothing:
+ case llvm::omp::Directive::OMPD_reverse:
+ return {std::move(depth), true}; // Same depth as the affected nest.
+ case llvm::omp::Directive::OMPD_stripe:
+ case llvm::omp::Directive::OMPD_tile:
+ if (depth) { // Twice as many loops as are affected.
+ return {
+ WithReason<int64_t>(2 * *depth.value, std::move(depth.reason)), true};
+ }
+ return {{}, true};
+ case llvm::omp::Directive::OMPD_unroll:
+ if (IsFullUnroll(spec)) {
+ return {WithReason<int64_t>(0), false}; // No loop remains.
+ }
+ return {WithReason<int64_t>(1), true};
+ default:
+ return {{}, false};
+ }
+}
+
/// Return the range of the affected nests in the sequence:
/// {first, count}
WithReason<std::pair<int64_t, int64_t>> GetAffectedLoopRangeWithReason(
More information about the flang-commits
mailing list