[flang-commits] [flang] [flang][OpenMP] Implement checks of intervening code (PR #185295)
Krzysztof Parzyszek via flang-commits
flang-commits at lists.llvm.org
Mon Mar 9 07:37:31 PDT 2026
https://github.com/kparzysz updated https://github.com/llvm/llvm-project/pull/185295
>From 2d7d763b4cb55c3112454d5ad0022048125bc713 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Thu, 19 Feb 2026 13:16:34 -0600
Subject: [PATCH 1/5] [flang][OpenMP] Import ExecutionPartIterator et al into
semantics::omp, NFC
---
flang/include/flang/Semantics/openmp-utils.h | 13 +++++++++++++
flang/lib/Semantics/check-omp-loop.cpp | 3 ---
2 files changed, 13 insertions(+), 3 deletions(-)
diff --git a/flang/include/flang/Semantics/openmp-utils.h b/flang/include/flang/Semantics/openmp-utils.h
index 90fd36708de0e..221e4cb23eada 100644
--- a/flang/include/flang/Semantics/openmp-utils.h
+++ b/flang/include/flang/Semantics/openmp-utils.h
@@ -27,6 +27,14 @@
#include <type_traits>
#include <utility>
+// Forward declarations of the parser-side iteration helpers. Only the names
+// are needed here; they are re-exported into semantics::omp below.
+namespace Fortran::parser::omp {
+template <typename IterT> struct ExecutionPartRange;
+struct LoopNestIterator;
+struct ExecutionPartIterator;
+using LoopRange = ExecutionPartRange<LoopNestIterator>;
+using BlockRange = ExecutionPartRange<ExecutionPartIterator>;
+} // namespace Fortran::parser::omp
+
namespace Fortran::semantics {
class Scope;
class SemanticsContext;
@@ -34,6 +42,11 @@ class Symbol;
// Add this namespace to avoid potential conflicts
namespace omp {
+// Make the parser's iteration helpers available under semantics::omp.
+using Fortran::parser::omp::BlockRange;
+using Fortran::parser::omp::ExecutionPartIterator;
+using Fortran::parser::omp::LoopNestIterator;
+using Fortran::parser::omp::LoopRange;
+
template <typename T, typename U = std::remove_const_t<T>> U AsRvalue(T &t) {
return U(t);
}
diff --git a/flang/lib/Semantics/check-omp-loop.cpp b/flang/lib/Semantics/check-omp-loop.cpp
index f81bde981594d..e13ea820c7ef6 100644
--- a/flang/lib/Semantics/check-omp-loop.cpp
+++ b/flang/lib/Semantics/check-omp-loop.cpp
@@ -334,7 +334,6 @@ static std::optional<size_t> CountGeneratedNests(const parser::Block &block) {
// messages about a potentially incorrect loop count.
// In such cases reset the count to nullopt. Once it becomes nullopt,
// keep it that way.
- using LoopRange = parser::omp::LoopRange;
std::optional<size_t> numLoops{0};
for (auto &epc : LoopRange(block, LoopRange::Step::Over)) {
if (auto genCount{CountGeneratedNests(epc)}) {
@@ -371,7 +370,6 @@ void OmpStructureChecker::CheckNestedConstruct(
// Check constructs contained in the body of the loop construct.
auto &body{std::get<parser::Block>(x.t)};
- using BlockRange = parser::omp::BlockRange;
for (auto &stmt : BlockRange(body, BlockRange::Step::Over)) {
if (auto *dir{parser::Unwrap<parser::CompilerDirective>(stmt)}) {
context_.Say(dir->source,
@@ -495,7 +493,6 @@ void OmpStructureChecker::SetLoopInfo(const parser::OpenMPLoopConstruct &x) {
void OmpStructureChecker::CheckIterationVariableType(
const parser::OpenMPLoopConstruct &x) {
- using LoopRange = parser::omp::LoopRange;
auto &body{std::get<parser::Block>(x.t)};
for (auto &construct : LoopRange(body, LoopRange::Step::Into)) {
// 'construct' can also be OpenMPLoopConstruct
>From 77235406105e7e686f1ca19bbfb4956aeef9257e Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Wed, 18 Feb 2026 10:53:13 -0600
Subject: [PATCH 2/5] [flang][OpenMP] Remember original range in
ExecutionPartIterator
Storing the original range (instead of just the "remaining part")
will allow the iterator component to be reused.
---
flang/include/flang/Parser/openmp-utils.h | 25 ++++++++++++++++++-----
flang/lib/Parser/openmp-utils.cpp | 11 ++++------
2 files changed, 24 insertions(+), 12 deletions(-)
diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h
index f23e52585d567..20754ad28d26d 100644
--- a/flang/include/flang/Parser/openmp-utils.h
+++ b/flang/include/flang/Parser/openmp-utils.h
@@ -296,14 +296,28 @@ struct ExecutionPartIterator {
using IteratorType = Block::const_iterator;
using IteratorRange = llvm::iterator_range<IteratorType>;
+  // A [begin, end) iterator range that also carries a cursor, `at`, marking
+  // the current position inside the range. The cursor starts at begin().
+  struct IteratorGauge : public IteratorRange {
+    IteratorGauge(IteratorRange r) : IteratorRange(r), at(r.begin()) {}
+    IteratorGauge(IteratorType first, IteratorType last)
+      : IteratorRange(first, last), at(first) {}
+
+    // True once the cursor has reached end().
+    bool atEnd() const { return end() == at; }
+
+    IteratorType at;
+  };
+
+  // One level of the traversal stack: a range of constructs together with
+  // the current position in it, and the construct that owns that range.
struct Construct {
Construct(IteratorType b, IteratorType e, const ExecutionPartConstruct *c)
- : range(b, e), owner(c) {}
+ : location(b, e), owner(c) {}
template <typename R>
Construct(const R &r, const ExecutionPartConstruct *c)
- : range(r), owner(c) {}
+ : location(r), owner(c) {}
Construct(const Construct &c) = default;
- IteratorRange range;
+ // The original range of the construct with the current position in it.
+ // The location.at is the construct currently being pointed at, or
+ // stepped into.
+ IteratorGauge location;
+ // The construct whose Block this level iterates over (the `c` argument
+ // passed to the constructors).
const ExecutionPartConstruct *owner;
};
@@ -332,6 +346,7 @@ struct ExecutionPartIterator {
bool valid() const { return !stack_.empty(); }
+  // The traversal stack, innermost level last (at() reads stack_.back()).
+  const std::vector<Construct> &stack() const { return stack_; }
decltype(auto) operator*() const { return *at(); }
bool operator==(const ExecutionPartIterator &other) const {
if (valid() != other.valid()) {
@@ -339,7 +354,7 @@ struct ExecutionPartIterator {
}
// Invalid iterators are considered equal.
return !valid() ||
- stack_.back().range.begin() == other.stack_.back().range.begin();
+ stack_.back().location.at == other.stack_.back().location.at;
}
bool operator!=(const ExecutionPartIterator &other) const {
return !(*this == other);
@@ -368,7 +383,7 @@ struct ExecutionPartIterator {
using iterator_category = std::forward_iterator_tag;
private:
- IteratorType at() const { return stack_.back().range.begin(); };
+ // The current position within the innermost range on the stack.
+ IteratorType at() const { return stack_.back().location.at; }
// If the iterator is not at a legal location, keep advancing it until
// it lands at a legal location or becomes invalid.
diff --git a/flang/lib/Parser/openmp-utils.cpp b/flang/lib/Parser/openmp-utils.cpp
index 6d4326af78344..f83a658104396 100644
--- a/flang/lib/Parser/openmp-utils.cpp
+++ b/flang/lib/Parser/openmp-utils.cpp
@@ -243,8 +243,7 @@ void ExecutionPartIterator::step() {
} else if (auto *loop{GetDoConstruct(*where)}) {
stack_.emplace_back(std::get<Block>(loop->t), &*where);
} else {
- stack_.back().range =
- IteratorRange(std::next(where), stack_.back().range.end());
+ ++stack_.back().location.at;
}
adjust();
}
@@ -254,8 +253,7 @@ void ExecutionPartIterator::next() {
// Advance the iterator to the next legal position. If the current
// position is a DO-loop or a loop construct, step over it.
if (valid()) {
- stack_.back().range =
- IteratorRange(std::next(at()), stack_.back().range.end());
+ ++stack_.back().location.at;
adjust();
}
}
@@ -264,11 +262,10 @@ void ExecutionPartIterator::adjust() {
// If the iterator is not at a legal location, keep advancing it until
// it lands at a legal location or becomes invalid.
while (valid()) {
- if (stack_.back().range.empty()) {
+ if (stack_.back().location.atEnd()) {
stack_.pop_back();
if (valid()) {
- stack_.back().range =
- IteratorRange(std::next(at()), stack_.back().range.end());
+ ++stack_.back().location.at;
}
} else if (auto *block{GetFortranBlockConstruct(*at())}) {
stack_.emplace_back(std::get<Block>(block->t), &*at());
>From dfa96281667b3523448d1fc3afacf798d6283438 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Thu, 26 Feb 2026 11:28:32 -0600
Subject: [PATCH 3/5] [flang][OpenMP] Move two functions to openmp-utils.cpp,
NFC
Move `IsLoopTransforming` and `IsFullUnroll` from check-omp-loop.cpp
to openmp-utils.cpp.
---
flang/include/flang/Semantics/openmp-utils.h | 3 ++
flang/lib/Semantics/check-omp-loop.cpp | 29 --------------------
flang/lib/Semantics/openmp-utils.cpp | 27 ++++++++++++++++++
3 files changed, 30 insertions(+), 29 deletions(-)
diff --git a/flang/include/flang/Semantics/openmp-utils.h b/flang/include/flang/Semantics/openmp-utils.h
index 221e4cb23eada..a10d826e4050c 100644
--- a/flang/include/flang/Semantics/openmp-utils.h
+++ b/flang/include/flang/Semantics/openmp-utils.h
@@ -114,6 +114,9 @@ bool IsAssignment(const parser::ActionStmt *x);
bool IsPointerAssignment(const evaluate::Assignment &x);
MaybeExpr MakeEvaluateExpr(const parser::OmpStylizedInstance &inp);
+
+// Return true if `dir` is treated as a loop-transforming directive: FUSE,
+// INTERCHANGE, NOTHING, REVERSE, STRIPE, TILE, or UNROLL.
+bool IsLoopTransforming(llvm::omp::Directive dir);
+// Return true if `x` is an UNROLL construct without a PARTIAL clause.
+bool IsFullUnroll(const parser::OpenMPLoopConstruct &x);
} // namespace omp
} // namespace Fortran::semantics
diff --git a/flang/lib/Semantics/check-omp-loop.cpp b/flang/lib/Semantics/check-omp-loop.cpp
index e13ea820c7ef6..d6e5a3f0aa7fb 100644
--- a/flang/lib/Semantics/check-omp-loop.cpp
+++ b/flang/lib/Semantics/check-omp-loop.cpp
@@ -38,8 +38,6 @@
#include <variant>
namespace Fortran::semantics {
-static bool IsLoopTransforming(llvm::omp::Directive dir);
-static bool IsFullUnroll(const parser::OpenMPLoopConstruct &x);
static std::optional<size_t> CountGeneratedNests(
const parser::ExecutionPartConstruct &epc);
static std::optional<size_t> CountGeneratedNests(const parser::Block &block);
@@ -248,33 +246,6 @@ void OmpStructureChecker::CheckSIMDNest(const parser::OpenMPConstruct &c) {
}
}
-static bool IsLoopTransforming(llvm::omp::Directive dir) {
- switch (dir) {
- // TODO case llvm::omp::Directive::OMPD_flatten:
- case llvm::omp::Directive::OMPD_fuse:
- case llvm::omp::Directive::OMPD_interchange:
- case llvm::omp::Directive::OMPD_nothing:
- case llvm::omp::Directive::OMPD_reverse:
- // TODO case llvm::omp::Directive::OMPD_split:
- case llvm::omp::Directive::OMPD_stripe:
- case llvm::omp::Directive::OMPD_tile:
- case llvm::omp::Directive::OMPD_unroll:
- return true;
- default:
- return false;
- }
-}
-
-static bool IsFullUnroll(const parser::OpenMPLoopConstruct &x) {
- const parser::OmpDirectiveSpecification &beginSpec{x.BeginDir()};
-
- if (beginSpec.DirName().v == llvm::omp::Directive::OMPD_unroll) {
- return parser::omp::FindClause(
- beginSpec, llvm::omp::Clause::OMPC_partial) == nullptr;
- }
- return false;
-}
-
// Count the number of loop nests generated by `epc`. This is just a helper
// function for counting the number of loop nests in a parser::Block.
static std::optional<size_t> CountGeneratedNests(
diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp
index 937938a0d10ce..dbc7e216c4788 100644
--- a/flang/lib/Semantics/openmp-utils.cpp
+++ b/flang/lib/Semantics/openmp-utils.cpp
@@ -523,4 +523,31 @@ MaybeExpr MakeEvaluateExpr(const parser::OmpStylizedInstance &inp) {
},
instance.u);
}
+
+// Return true for the directives that are treated as loop transformations.
+// Any directive not listed below is considered non-transforming.
+bool IsLoopTransforming(llvm::omp::Directive dir) {
+  using llvm::omp::Directive;
+  switch (dir) {
+  case Directive::OMPD_fuse:
+  case Directive::OMPD_interchange:
+  case Directive::OMPD_nothing:
+  case Directive::OMPD_reverse:
+  case Directive::OMPD_stripe:
+  case Directive::OMPD_tile:
+  case Directive::OMPD_unroll:
+    return true;
+  // TODO: Add OMPD_flatten and OMPD_split.
+  default:
+    return false;
+  }
+}
+
+bool IsFullUnroll(const parser::OpenMPLoopConstruct &x) {
+  // A full unroll is an UNROLL directive that has no PARTIAL clause.
+  const auto &spec{x.BeginDir()};
+  if (spec.DirName().v == llvm::omp::Directive::OMPD_unroll) {
+    return !parser::omp::FindClause(spec, llvm::omp::Clause::OMPC_partial);
+  }
+  return false;
+}
} // namespace Fortran::semantics::omp
>From c6b0684b934357d1e4c93f3300f2e69fc7856880 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Thu, 26 Feb 2026 11:20:56 -0600
Subject: [PATCH 4/5] [flang][OpenMP] Refactor CountGeneratedNests, NFC
Extract handling of individual constructs into a helper function.
Change the base count type to `int64_t` to match the type used
in GetIntValue.
Rename the function to GetNumGeneratedNests.
---
flang/include/flang/Semantics/openmp-utils.h | 4 ++
flang/lib/Semantics/check-omp-loop.cpp | 63 ++++++--------------
flang/lib/Semantics/openmp-utils.cpp | 58 ++++++++++++++++++
3 files changed, 80 insertions(+), 45 deletions(-)
diff --git a/flang/include/flang/Semantics/openmp-utils.h b/flang/include/flang/Semantics/openmp-utils.h
index a10d826e4050c..7f6c2824a986a 100644
--- a/flang/include/flang/Semantics/openmp-utils.h
+++ b/flang/include/flang/Semantics/openmp-utils.h
@@ -117,6 +117,10 @@ MaybeExpr MakeEvaluateExpr(const parser::OmpStylizedInstance &inp);
bool IsLoopTransforming(llvm::omp::Directive dir);
bool IsFullUnroll(const parser::OpenMPLoopConstruct &x);
+
+// Return the number of loop nests generated by `epc`, which must be a DO
+// construct or an OpenMP loop construct. `nestedCount` is the number of loop
+// nests in the body of `epc` (if known). OpenMP loop constructs that are not
+// loop-transforming generate 0 nests; std::nullopt means that the number
+// could not be determined.
+std::optional<int64_t> GetNumGeneratedNestsFrom(
+    const parser::ExecutionPartConstruct &epc,
+    std::optional<int64_t> nestedCount);
} // namespace omp
} // namespace Fortran::semantics
diff --git a/flang/lib/Semantics/check-omp-loop.cpp b/flang/lib/Semantics/check-omp-loop.cpp
index d6e5a3f0aa7fb..45f4798d0c3c6 100644
--- a/flang/lib/Semantics/check-omp-loop.cpp
+++ b/flang/lib/Semantics/check-omp-loop.cpp
@@ -38,9 +38,9 @@
#include <variant>
namespace Fortran::semantics {
+// Forward declarations: the two overloads below call each other recursively.
-static std::optional<size_t> CountGeneratedNests(
+static std::optional<int64_t> GetNumGeneratedNests(
const parser::ExecutionPartConstruct &epc);
-static std::optional<size_t> CountGeneratedNests(const parser::Block &block);
+static std::optional<int64_t> GetNumGeneratedNests(const parser::Block &block);
} // namespace Fortran::semantics
namespace {
@@ -248,7 +248,7 @@ void OmpStructureChecker::CheckSIMDNest(const parser::OpenMPConstruct &c) {
// Count the number of loop nests generated by `epc`. This is just a helper
// function for counting the number of loop nests in a parser::Block.
-static std::optional<size_t> CountGeneratedNests(
+static std::optional<int64_t> GetNumGeneratedNests(
const parser::ExecutionPartConstruct &epc) {
if (parser::Unwrap<parser::DoConstruct>(epc)) {
return 1;
@@ -258,56 +258,34 @@ static std::optional<size_t> CountGeneratedNests(
const parser::OmpDirectiveSpecification &beginSpec{omp.BeginDir()};
llvm::omp::Directive dir{beginSpec.DirName().v};
-  // TODO: Handle split, apply.
-  if (IsFullUnroll(omp)) {
-    return std::nullopt;
-  }
+  // Keep reporting no count for a full UNROLL: GetNumGeneratedNestsFrom is
+  // only consulted for FUSE and NOTHING in the switch below.
+  if (IsFullUnroll(omp)) {
+    return std::nullopt;
+  }
- if (dir == llvm::omp::Directive::OMPD_fuse) {
- auto nestedCount{CountGeneratedNests(std::get<parser::Block>(omp.t))};
- // If there are no loops nested inside of FUSE, then the construct is
- // invalid. This case will be diagnosed when analyzing the body of the FUSE
- // construct itself, not when checking a construct in which the FUSE is
- // nested.
- // Returning std::nullopt prevents error messages caused by the same
- // problem from being emitted for every enclosing loop construct, for
- // example:
- // !$omp do ! error: this should contain a loop (superfluous)
- // !$omp fuse ! error: this should contain a loop
- // !$omp end fuse
- if (!nestedCount || *nestedCount == 0) {
- return std::nullopt;
- }
- auto *clause{
- parser::omp::FindClause(beginSpec, llvm::omp::Clause::OMPC_looprange)};
- if (!clause) {
- return 1;
- }
-
- auto *loopRange{parser::Unwrap<parser::OmpLooprangeClause>(*clause)};
- std::optional<int64_t> count{GetIntValue(std::get<1>(loopRange->t))};
- if (!count || *count <= 0) {
- return std::nullopt;
- }
- if (static_cast<size_t>(*count) <= *nestedCount) {
- return 1 + *nestedCount - static_cast<size_t>(*count);
- }
- return std::nullopt;
+ switch (dir) {
+ case llvm::omp::Directive::OMPD_fuse:
+ case llvm::omp::Directive::OMPD_nothing:
+ return GetNumGeneratedNestsFrom(
+ epc, GetNumGeneratedNests(std::get<parser::Block>(omp.t)));
+ default:
+ break;
}
// For every other loop construct return 1.
return 1;
}
-static std::optional<size_t> CountGeneratedNests(const parser::Block &block) {
+static std::optional<int64_t> GetNumGeneratedNests(const parser::Block &block) {
// Count the number of loops in the associated block. If there are any
// malformed construct in there, getting the number may be meaningless.
// These issues will be diagnosed elsewhere, and we should not emit any
// messages about a potentially incorrect loop count.
// In such cases reset the count to nullopt. Once it becomes nullopt,
// keep it that way.
- std::optional<size_t> numLoops{0};
+ std::optional<int64_t> numLoops{0};
for (auto &epc : LoopRange(block, LoopRange::Step::Over)) {
- if (auto genCount{CountGeneratedNests(epc)}) {
+ if (auto genCount{GetNumGeneratedNests(epc)}) {
*numLoops += *genCount;
} else {
numLoops = std::nullopt;
@@ -363,7 +336,7 @@ void OmpStructureChecker::CheckNestedConstruct(
// Check if a loop-nest-associated construct has only one top-level loop
// in it.
- if (std::optional<size_t> numLoops{CountGeneratedNests(body)}) {
+ if (std::optional<int64_t> numLoops{GetNumGeneratedNests(body)}) {
if (*numLoops == 0) {
context_.Say(beginSpec.DirName().source,
"This construct should contain a DO-loop or a loop-nest-generating OpenMP construct"_err_en_US);
@@ -371,7 +344,7 @@ void OmpStructureChecker::CheckNestedConstruct(
auto assoc{llvm::omp::getDirectiveAssociation(beginSpec.DirName().v)};
if (*numLoops > 1 && assoc == llvm::omp::Association::LoopNest) {
context_.Say(beginSpec.DirName().source,
-      "This construct applies to a loop nest, but has a loop sequence of length %zu"_err_en_US,
+      "This construct applies to a loop nest, but has a loop sequence of length %jd"_err_en_US,
*numLoops);
}
}
@@ -585,11 +558,11 @@ void OmpStructureChecker::CheckLooprangeBounds(
if (!first || !count || *first <= 0 || *count <= 0) {
return;
}
- auto requiredCount{static_cast<size_t>(*first + *count - 1)};
- if (auto loopCount{CountGeneratedNests(std::get<parser::Block>(x.t))}) {
+ int64_t requiredCount{*first + *count - 1};
+ if (auto loopCount{GetNumGeneratedNests(std::get<parser::Block>(x.t))}) {
if (*loopCount < requiredCount) {
context_.Say(clause->source,
-      "The specified loop range requires %zu loops, but the loop sequence has a length of %zu"_err_en_US,
+      "The specified loop range requires %jd loops, but the loop sequence has a length of %jd"_err_en_US,
requiredCount, *loopCount);
}
}
diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp
index dbc7e216c4788..533242287a667 100644
--- a/flang/lib/Semantics/openmp-utils.cpp
+++ b/flang/lib/Semantics/openmp-utils.cpp
@@ -550,4 +550,62 @@ bool IsFullUnroll(const parser::OpenMPLoopConstruct &x) {
}
return false;
}
+
+std::optional<int64_t> GetNumGeneratedNestsFrom(
+    const parser::ExecutionPartConstruct &epc,
+    std::optional<int64_t> nestedCount) {
+  // A DO construct counts as exactly one loop nest.
+  if (parser::Unwrap<parser::DoConstruct>(epc)) {
+    return 1;
+  }
+
+  // Anything else must be an OpenMP loop construct.
+  const parser::OpenMPLoopConstruct &omp{
+      DEREF(parser::Unwrap<parser::OpenMPLoopConstruct>(epc))};
+  const parser::OmpDirectiveSpecification &beginSpec{omp.BeginDir()};
+  const llvm::omp::Directive dir{beginSpec.DirId()};
+
+  // Constructs that do not transform loops contribute no generated nests.
+  if (!IsLoopTransforming(dir)) {
+    return 0;
+  }
+  // TODO: Handle split, apply.
+  // No count is reported for a full unroll.
+  if (IsFullUnroll(omp)) {
+    return std::nullopt;
+  }
+
+  switch (dir) {
+  case llvm::omp::Directive::OMPD_fuse: {
+    // A FUSE without nested loops is invalid. That is diagnosed when the
+    // FUSE construct itself is analyzed, not when checking a construct in
+    // which it is nested. Returning std::nullopt keeps the same problem from
+    // being reported again for every enclosing loop construct, for example:
+    //   !$omp do     ! error: this should contain a loop (superfluous)
+    //   !$omp fuse   ! error: this should contain a loop
+    //   !$omp end fuse
+    if (nestedCount.value_or(0) == 0) {
+      return std::nullopt;
+    }
+    auto *clause{
+        parser::omp::FindClause(beginSpec, llvm::omp::Clause::OMPC_looprange)};
+    if (clause == nullptr) {
+      // Without LOOPRANGE the construct yields a single nest.
+      return 1;
+    }
+    auto *loopRange{parser::Unwrap<parser::OmpLooprangeClause>(*clause)};
+    std::optional<int64_t> fused{GetIntValue(std::get<1>(loopRange->t))};
+    if (!fused || *fused <= 0 || *fused > *nestedCount) {
+      return std::nullopt;
+    }
+    // The `fused` loops count as one nest; the others are counted as they are.
+    return *nestedCount - *fused + 1;
+  }
+  case llvm::omp::Directive::OMPD_nothing:
+    // NOTHING generates as many nests as its body contains.
+    return nestedCount;
+  default:
+    // Every other loop-transforming construct generates one nest.
+    return 1;
+  }
+}
} // namespace Fortran::semantics::omp
>From 45b8c77bfcfa9a25590058853af1fb1974585c89 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Thu, 26 Feb 2026 15:06:12 -0600
Subject: [PATCH 5/5] [flang][OpenMP] Implement checks of intervening code
Invalid intervening code will cause the containing loop to be the final
loop in the loop nest. Transparent intervening code will not affect
perfect nesting if present. Currently compiler directives and CONTINUE
statements are considered transparent; compiler directives are tolerated so
that code mixing OpenMP and such directives continues to compile.
---
flang/lib/Semantics/openmp-utils.cpp | 147 ++++++++++++++++++++++++++-
1 file changed, 145 insertions(+), 2 deletions(-)
diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp
index 533242287a667..0b50160053012 100644
--- a/flang/lib/Semantics/openmp-utils.cpp
+++ b/flang/lib/Semantics/openmp-utils.cpp
@@ -19,6 +19,7 @@
#include "flang/Common/visit.h"
#include "flang/Evaluate/check-expression.h"
#include "flang/Evaluate/expression.h"
+#include "flang/Evaluate/match.h"
#include "flang/Evaluate/tools.h"
#include "flang/Evaluate/traverse.h"
#include "flang/Evaluate/type.h"
@@ -244,14 +245,17 @@ bool IsMapExitingType(parser::OmpMapType::Value type) {
}
}
-MaybeExpr GetEvaluateExpr(const parser::Expr &parserExpr) {
-  const parser::TypedExpr &typedExpr{parserExpr.typedExpr};
+// Return a copy of the evaluate::Expr held in `typedExpr`.
+// NOTE(review): DEREF requires the TypedExpr to have been populated by
+// expression analysis; confirm that callers only pass analyzed nodes.
+static MaybeExpr GetEvaluateExprFromTyped(const parser::TypedExpr &typedExpr) {
// ForwardOwningPointer typedExpr
// `- GenericExprWrapper ^.get()
// `- std::optional<Expr> ^->v
return DEREF(typedExpr.get()).v;
}
+// Return the evaluate::Expr for an analyzed parser::Expr.
+MaybeExpr GetEvaluateExpr(const parser::Expr &parserExpr) {
+  return GetEvaluateExprFromTyped(parserExpr.typedExpr);
+}
+
std::optional<evaluate::DynamicType> GetDynamicType(
const parser::Expr &parserExpr) {
if (auto maybeExpr{GetEvaluateExpr(parserExpr)}) {
@@ -551,6 +555,145 @@ bool IsFullUnroll(const parser::OpenMPLoopConstruct &x) {
return false;
}
+namespace {
+// Helper class to check if a given evaluate::Expr is an array expression.
+// This does not check any proper subexpressions of the expression (except
+// parentheses).
+struct ArrayExpressionRecognizer {
+  template <TypeCategory C>
+  static bool isArrayExpression(
+      const evaluate::Expr<evaluate::SomeKind<C>> &x) {
+    return common::visit([](auto &&s) { return isArrayExpression(s); }, x.u);
+  }
+
+  template <TypeCategory C, int K>
+  static bool isArrayExpression(const evaluate::Expr<evaluate::Type<C, K>> &x) {
+    return common::visit([](auto &&s) { return isArrayExpression(s); },
+        evaluate::match::deparen(x).u);
+  }
+
+  // A designator is an array expression exactly when it has nonzero rank.
+  // Using the rank of the whole designator covers whole arrays, sections,
+  // vector subscripts, and also array-valued component references such as
+  // `t%arr` or `t(:)%x`, which a SymbolRef/ArrayRef-only check would miss.
+  template <typename T>
+  static bool isArrayExpression(const evaluate::Designator<T> &x) {
+    return x.Rank() != 0;
+  }
+
+  // Everything else (constants, operations, function references, ...) is
+  // not considered an array expression by itself.
+  template <typename T> static bool isArrayExpression(const T &) {
+    return false;
+  }
+
+  static bool isArrayExpression(const evaluate::Expr<evaluate::SomeType> &x) {
+    return common::visit([](auto &&s) { return isArrayExpression(s); }, x.u);
+  }
+};
+
+// Helper class to check if a given evaluate::Expr contains a subexpression
+// (not necessarily proper) that is an array expression.
+struct ArrayExpressionFinder
+    : public evaluate::AnyTraverse<ArrayExpressionFinder> {
+  using Base = evaluate::AnyTraverse<ArrayExpressionFinder>;
+  using Base::operator();
+  ArrayExpressionFinder() : Base(*this) {}
+
+  // A scalar designator may still contain an array expression in its
+  // subscripts (e.g. `b` in `a(sum(b))`), so keep traversing into it.
+  template <typename T>
+  bool operator()(const evaluate::Designator<T> &x) const {
+    return ArrayExpressionRecognizer::isArrayExpression(x) ||
+        Base::operator()(x);
+  }
+};
+
+// Parse-tree visitor that sets `rejected` once an analyzed parser::Expr in
+// the walked tree contains an array expression, i.e. does not satisfy the
+// criteria for being in "intervening code".
+struct ArrayExpressionChecker {
+  template <typename T> bool Pre(const T &) { return true; }
+  template <typename T> void Post(const T &) {}
+
+  bool Pre(const parser::Expr &parserExpr) {
+    // Once a prohibited expression has been found, stop descending into
+    // further expressions.
+    if (!rejected) {
+      if (auto expr{GetEvaluateExpr(parserExpr)}) {
+        rejected = ArrayExpressionFinder{}(*expr);
+      }
+    }
+    return !rejected;
+  }
+
+  bool rejected{false};
+};
+} // namespace
+
+// Walk `x` and report whether any expression in it contains an array
+// expression.
+static bool ContainsInvalidArrayExpression(
+    const parser::ExecutionPartConstruct &x) {
+  ArrayExpressionChecker visitor;
+  parser::Walk(x, visitor);
+  return visitor.rejected;
+}
+
+// Return true if `x` is acceptable as intervening code in a canonical loop
+// nest. The following are rejected: CYCLE, EXIT, FORALL and WHERE statements;
+// assignments whose variable is not known to be scalar; DO loops (labeled or
+// not); FORALL and WHERE constructs; executable OpenMP directives; and any
+// construct that contains an array expression.
+bool IsValidInterveningCode(const parser::ExecutionPartConstruct &x) {
+  // A variable without an analyzed expression is conservatively treated as
+  // non-scalar.
+  auto isScalar{[](const parser::Variable &variable) {
+    if (auto expr{GetEvaluateExprFromTyped(variable.typedExpr)}) {
+      return expr->Rank() == 0;
+    }
+    return false;
+  }};
+
+  auto *exec{parser::Unwrap<parser::ExecutableConstruct>(x)};
+  if (!exec) {
+    // DATA, ENTRY, FORMAT, NAMELIST are not explicitly prohibited in a CLN
+    // although they are likely disallowed due to other requirements.
+    // Return true, they should be rejected elsewhere if necessary.
+    return true;
+  }
+
+  if (auto *action{parser::Unwrap<parser::ActionStmt>(exec->u)}) {
+    if (parser::Unwrap<parser::CycleStmt>(action->u) ||
+        parser::Unwrap<parser::ExitStmt>(action->u) ||
+        parser::Unwrap<parser::ForallStmt>(action->u) ||
+        parser::Unwrap<parser::WhereStmt>(action->u)) {
+      return false;
+    }
+    // Unwrap the variant itself, like the checks above.
+    if (auto *assign{parser::Unwrap<parser::AssignmentStmt>(action->u)}) {
+      if (!isScalar(std::get<parser::Variable>(assign->t))) {
+        return false;
+      }
+    }
+  } else { // Not ActionStmt
+    if (parser::Unwrap<parser::LabelDoStmt>(exec->u) ||
+        parser::Unwrap<parser::DoConstruct>(exec->u) ||
+        parser::Unwrap<parser::ForallConstruct>(exec->u) ||
+        parser::Unwrap<parser::WhereConstruct>(exec->u)) {
+      return false;
+    }
+    if (auto *omp{parser::Unwrap<parser::OpenMPConstruct>(exec->u)}) {
+      auto dirName{GetOmpDirectiveName(*omp)};
+      if (llvm::omp::getDirectiveCategory(dirName.v) ==
+          llvm::omp::Category::Executable) {
+        return false;
+      }
+    }
+  }
+
+  return !ContainsInvalidArrayExpression(x);
+}
+
+// Return true for intervening code that does not break perfect nesting.
+// Compiler directives are tolerated so that code mixing OpenMP with such
+// directives still compiles; CONTINUE statements are tolerated as well.
+bool IsTransparentInterveningCode(const parser::ExecutionPartConstruct &x) {
+  if (parser::Unwrap<parser::CompilerDirective>(x) != nullptr) {
+    return true;
+  }
+  return parser::Unwrap<parser::ContinueStmt>(x) != nullptr;
+}
+
std::optional<int64_t> GetNumGeneratedNestsFrom(
const parser::ExecutionPartConstruct &epc,
std::optional<int64_t> nestedCount) {
More information about the flang-commits
mailing list