[flang-commits] [flang] [flang][OpenMP] Refactor interface of WithOmpDeclarative (PR #200876)
Krzysztof Parzyszek via flang-commits
flang-commits at lists.llvm.org
Wed Jun 3 05:12:16 PDT 2026
https://github.com/kparzysz updated https://github.com/llvm/llvm-project/pull/200876
>From 8f2e7890b13db3e5bb4516d0c2ba9376864b13f0 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Thu, 28 May 2026 12:29:06 -0500
Subject: [PATCH 1/2] [flang][OpenMP] Refactor interface of WithOmpDeclarative
The two major changes are:
1. The clause sets are not optional anymore. In the absence of any
declarative directives (REQUIRES in this case), the set will simply
be empty.
2. The optional memory order member will serve as the value of the
argument to the ATOMIC_DEFAULT_MEM_ORDER clause, and will only be
meaningful (and required) when the clause is a member of the clause set.
Additionally,
- Rename the RequiredClauses type alias to OmpClauseSet, since it will be
reused for other purposes in the future.
- Remove the has_* functions since they are not necessary, and when more
members are added these functions will only add to the clutter.
- Add a version_ member so that clause names are printed using the
spelling for the OpenMP version in effect.
---
flang/include/flang/Semantics/openmp-utils.h | 20 ++++++
flang/include/flang/Semantics/symbol.h | 30 ++++-----
flang/lib/Lower/OpenMP/Atomic.cpp | 4 +-
flang/lib/Lower/OpenMP/OpenMP.cpp | 5 +-
flang/lib/Semantics/check-omp-structure.cpp | 4 +-
flang/lib/Semantics/mod-file.cpp | 34 +++++-----
flang/lib/Semantics/resolve-directives.cpp | 63 +++++++------------
flang/lib/Semantics/symbol.cpp | 44 ++++++++-----
.../Semantics/OpenMP/requires-modfile.f90 | 6 +-
9 files changed, 110 insertions(+), 100 deletions(-)
diff --git a/flang/include/flang/Semantics/openmp-utils.h b/flang/include/flang/Semantics/openmp-utils.h
index 4001b337193a1..c2e89fe829ce0 100644
--- a/flang/include/flang/Semantics/openmp-utils.h
+++ b/flang/include/flang/Semantics/openmp-utils.h
@@ -127,6 +127,26 @@ std::optional<int64_t> GetIntValueFromExpr(
return std::nullopt;
}
+// There are several clauses that take an optional, compile-time
+// constant bool argument. Those clauses are stored as std::optional, e.g.
+// OmpClause::ReverseOffload -> std::optional<OmpReverseOffloadClause>.
+// Return the logical value, or std::nullopt if absent or not evaluable.
+template <typename ClauseTy>
+std::optional<bool> GetLogicalArgument(
+ const std::optional<ClauseTy> &maybeClause, SemanticsContext &semaCtx) {
+ if (maybeClause) {
+ // Scalar<Logical<Constant<common::Indirection<Expr>>>>
+ auto &parserExpr{parser::UnwrapRef<parser::Expr>(*maybeClause)};
+ evaluate::ExpressionAnalyzer ea{semaCtx};
+ if (auto &&maybeExpr{ea.Analyze(parserExpr)}) {
+ if (auto v{GetLogicalValue(*maybeExpr)}) {
+ return *v;
+ }
+ }
+ }
+ return std::nullopt; // the caller decides what a missing value means
+}
+
std::optional<bool> IsContiguous(
SemanticsContext &semaCtx, const parser::OmpObject &object);
diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index 775ac5ca3dcbc..a7647665973d2 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -52,22 +52,15 @@ using MutableSymbolVector = std::vector<MutableSymbolRef>;
// Mixin for details with OpenMP declarative constructs.
class WithOmpDeclarative {
public:
- // The set of requirements for any program unit include requirements
- // from any module used in the program unit.
- using RequiresClauses =
+ using OmpClauseSet =
common::EnumSet<llvm::omp::Clause, llvm::omp::Clause_enumSize>;
- bool has_ompRequires() const { return ompRequires_.has_value(); }
- const RequiresClauses *ompRequires() const {
- return ompRequires_ ? &*ompRequires_ : nullptr;
- }
- void set_ompRequires(RequiresClauses clauses) { ompRequires_ = clauses; }
+ const OmpClauseSet &ompRequires() const { return ompRequires_; }
+ void set_ompRequires(OmpClauseSet clauses) { ompRequires_ = clauses; }
- bool has_ompAtomicDefaultMemOrder() const {
- return ompAtomicDefaultMemOrder_.has_value();
- }
- const common::OmpMemoryOrderType *ompAtomicDefaultMemOrder() const {
- return ompAtomicDefaultMemOrder_ ? &*ompAtomicDefaultMemOrder_ : nullptr;
+ const std::optional<common::OmpMemoryOrderType> &
+ ompAtomicDefaultMemOrder() const {
+ return ompAtomicDefaultMemOrder_;
}
void set_ompAtomicDefaultMemOrder(common::OmpMemoryOrderType flags) {
ompAtomicDefaultMemOrder_ = flags;
@@ -76,8 +69,17 @@ class WithOmpDeclarative {
friend llvm::raw_ostream &operator<<(
llvm::raw_ostream &, const WithOmpDeclarative &);
+ void set_version(unsigned version) { version_ = version; }
+
private:
- std::optional<RequiresClauses> ompRequires_;
+ unsigned version_{0}; // OpenMP version for clause spellings; 0 until set
+ // The set of clauses from a REQUIRES directive. Only applicable
+ // to program unit symbols (i.e. scopes of the REQUIRES directive).
+ // The set of requirements for any program unit includes requirements
+ // from any module used in the program unit.
+ OmpClauseSet ompRequires_;
+ // The argument to ATOMIC_DEFAULT_MEM_ORDER. Only needed when the ADMO
+ // clause is present in the ompRequires_ set.
std::optional<common::OmpMemoryOrderType> ompAtomicDefaultMemOrder_;
};
diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp
index b80564fddd943..e7745ee2c8547 100644
--- a/flang/lib/Lower/OpenMP/Atomic.cpp
+++ b/flang/lib/Lower/OpenMP/Atomic.cpp
@@ -192,7 +192,9 @@ getMemoryOrderFromRequires(const semantics::Scope &scope) {
using WithOmpDeclarative = semantics::WithOmpDeclarative;
if constexpr (std::is_convertible_v<decltype(s),
const WithOmpDeclarative &>) {
- return s.ompAtomicDefaultMemOrder();
+ if (auto &admo{s.ompAtomicDefaultMemOrder()}) {
+ return &*admo; // points at the optional stored in the details
+ }
}
return static_cast<const common::OmpMemoryOrderType *>(nullptr);
},
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 620eeadbc0711..6114dfe392145 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -5208,14 +5208,13 @@ void Fortran::lower::genOpenMPRequires(mlir::Operation *mod,
if (auto offloadMod =
llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(mod)) {
- semantics::WithOmpDeclarative::RequiresClauses reqs;
+ semantics::WithOmpDeclarative::OmpClauseSet reqs;
if (symbol) {
common::visit(
[&](const auto &details) {
if constexpr (std::is_base_of_v<semantics::WithOmpDeclarative,
std::decay_t<decltype(details)>>) {
- if (details.has_ompRequires())
- reqs = *details.ompRequires();
+ reqs = details.ompRequires(); // may be an empty set
}
},
symbol->details());
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index c450c3fbfeb43..f44cdfd4e6768 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -614,9 +614,7 @@ bool OmpStructureChecker::HasRequires(llvm::omp::Clause req) {
[&](const auto &details) {
if constexpr (std::is_convertible_v<decltype(details),
const WithOmpDeclarative &>) {
- if (auto *reqs{details.ompRequires()}) {
- return reqs->test(req);
- }
+ return details.ompRequires().test(req); // the set always exists
}
return false;
},
diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp
index f5e66a04c3f11..6e89ae92b882a 100644
--- a/flang/lib/Semantics/mod-file.cpp
+++ b/flang/lib/Semantics/mod-file.cpp
@@ -364,34 +364,36 @@ void ModFileWriter::PrepareRenamings(const Scope &scope) {
}
}
-static void PutOpenMPRequirements(llvm::raw_ostream &os, const Symbol &symbol) {
- using RequiresClauses = WithOmpDeclarative::RequiresClauses;
+static void PutOpenMPRequirements(
+ llvm::raw_ostream &os, const Symbol &symbol, SemanticsContext &semaCtx) {
+ using OmpClauseSet = WithOmpDeclarative::OmpClauseSet;
using OmpMemoryOrderType = common::OmpMemoryOrderType;
+ unsigned version{semaCtx.langOptions().OpenMPVersion};
const auto [reqs, order]{common::visit(
[&](auto &&details)
- -> std::pair<const RequiresClauses *, const OmpMemoryOrderType *> {
+ -> std::pair<const OmpClauseSet *, const OmpMemoryOrderType *> {
if constexpr (std::is_convertible_v<decltype(details),
const WithOmpDeclarative &>) {
- return {details.ompRequires(), details.ompAtomicDefaultMemOrder()};
+ if (const auto &memOrder{details.ompAtomicDefaultMemOrder()}) {
+ return {&details.ompRequires(), &*memOrder};
+ }
+ return {&details.ompRequires(), nullptr};
} else {
return {nullptr, nullptr};
}
},
symbol.details())};
- if (order) {
- llvm::omp::Clause admo{llvm::omp::Clause::OMPC_atomic_default_mem_order};
- os << "!$omp requires "
- << parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(admo))
- << '(' << parser::ToLowerCaseLetters(EnumToString(*order)) << ")\n";
- }
- if (reqs) {
+ if (reqs && reqs->count()) { // reqs is null if details lack the mixin
os << "!$omp requires";
- reqs->IterateOverMembers([&](llvm::omp::Clause f) {
- if (f != llvm::omp::Clause::OMPC_atomic_default_mem_order) {
- os << ' '
- << parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(f));
+ reqs->IterateOverMembers([&, order = order](llvm::omp::Clause f) { // C++17 can't capture bindings
+ os << ' '
+ << parser::ToLowerCaseLetters(
+ llvm::omp::getOpenMPClauseName(f, version));
+ if (f == llvm::omp::Clause::OMPC_atomic_default_mem_order) {
+ os << '(' << parser::ToLowerCaseLetters(EnumToString(DEREF(order)))
+ << ')';
}
});
os << "\n";
@@ -435,7 +437,7 @@ void ModFileWriter::PutSymbols(
for (const Symbol &symbol : uses) {
PutUse(symbol);
}
- PutOpenMPRequirements(decls_, DEREF(scope.symbol()));
+ PutOpenMPRequirements(decls_, DEREF(scope.symbol()), context_);
for (const auto &set : scope.equivalenceSets()) {
if (!set.empty() &&
!set.front().symbol.test(Symbol::Flag::CompilerCreated)) {
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index 2fa59adf7f3af..b86bf64d18bd3 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -638,34 +638,19 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
void Post(const parser::OpenMPFlushConstruct &) { PopContext(); }
bool Pre(const parser::OmpRequiresDirective &x) {
- using RequiresClauses = WithOmpDeclarative::RequiresClauses;
+ using OmpClauseSet = WithOmpDeclarative::OmpClauseSet;
PushContext(x.source, llvm::omp::Directive::OMPD_requires);
- auto getArgument{[&](auto &&maybeClause) {
- if (maybeClause) {
- // Scalar<Logical<Constant<common::Indirection<Expr>>>>
- auto &parserExpr{parser::UnwrapRef<parser::Expr>(*maybeClause)};
- evaluate::ExpressionAnalyzer ea{context_};
- if (auto &&maybeExpr{ea.Analyze(parserExpr)}) {
- if (auto v{omp::GetLogicalValue(*maybeExpr)}) {
- return *v;
- }
- }
- }
- // If the argument is missing, it is assumed to be true.
- return true;
- }};
-
// Gather information from the clauses.
- RequiresClauses reqs;
- const common::OmpMemoryOrderType *memOrder{nullptr};
+ OmpClauseSet reqs;
+ std::optional<common::OmpMemoryOrderType> memOrder;
for (const parser::OmpClause &clause : x.v.Clauses().v) {
using OmpClause = parser::OmpClause;
reqs |= common::visit(
common::visitors{
- [&](const OmpClause::AtomicDefaultMemOrder &atomic) {
- memOrder = &atomic.v.v;
- return RequiresClauses{};
+ [&](const OmpClause::AtomicDefaultMemOrder &admo) {
+ memOrder = admo.v.v;
+ return OmpClauseSet{clause.Id()}; // ADMO is kept in the set too
},
[&](auto &&s) {
using TypeS = llvm::remove_cvref_t<decltype(s)>;
@@ -676,18 +661,18 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
std::is_same_v<TypeS, OmpClause::SelfMaps> ||
std::is_same_v<TypeS, OmpClause::UnifiedAddress> ||
std::is_same_v<TypeS, OmpClause::UnifiedSharedMemory>) {
- if (getArgument(s.v)) {
- return RequiresClauses{clause.Id()};
+ if (omp::GetLogicalArgument(s.v, context_).value_or(true)) { // missing argument means true
+ return OmpClauseSet{clause.Id()};
}
}
- return RequiresClauses{};
+ return OmpClauseSet{};
},
},
clause.u);
}
// Merge clauses into parents' symbols details.
- AddOmpRequiresToScope(currScope(), &reqs, memOrder);
+ AddOmpRequiresToScope(currScope(), reqs, memOrder);
return true;
}
void Post(const parser::OmpRequiresDirective &) { PopContext(); }
@@ -1055,9 +1040,8 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
void CheckObjectIsPrivatizable(
const parser::Name &, const Symbol &, Symbol::Flag);
- void AddOmpRequiresToScope(Scope &,
- const WithOmpDeclarative::RequiresClauses *,
- const common::OmpMemoryOrderType *);
+ void AddOmpRequiresToScope(Scope &, const WithOmpDeclarative::OmpClauseSet &,
+ const std::optional<common::OmpMemoryOrderType> &);
void CreateImplicitSymbols(const parser::Name &, const Symbol *symbol);
@@ -3213,30 +3197,25 @@ void OmpAttributeVisitor::CheckObjectIsPrivatizable(
}
void OmpAttributeVisitor::AddOmpRequiresToScope(Scope &scope,
- const WithOmpDeclarative::RequiresClauses *reqs,
- const common::OmpMemoryOrderType *memOrder) {
+ const WithOmpDeclarative::OmpClauseSet &reqs,
+ const std::optional<common::OmpMemoryOrderType> &memOrder) {
+ unsigned version{context_.langOptions().OpenMPVersion};
const Scope &programUnit{omp::GetProgramUnit(scope)};
- using RequiresClauses = WithOmpDeclarative::RequiresClauses;
- RequiresClauses combinedReqs{reqs ? *reqs : RequiresClauses{}};
if (auto *symbol{const_cast<Symbol *>(programUnit.symbol())}) {
common::visit(
[&](auto &details) {
if constexpr (std::is_convertible_v<decltype(&details),
WithOmpDeclarative *>) {
- if (combinedReqs.any()) {
- if (const RequiresClauses *otherReqs{details.ompRequires()}) {
- combinedReqs |= *otherReqs;
- }
- details.set_ompRequires(combinedReqs);
+ if (reqs.any()) {
+ details.set_ompRequires(reqs | details.ompRequires());
+ details.set_version(version);
}
if (memOrder) {
- if (details.has_ompAtomicDefaultMemOrder() &&
- *details.ompAtomicDefaultMemOrder() != *memOrder) {
- unsigned version{context_.langOptions().OpenMPVersion};
+ if (auto &admo{details.ompAtomicDefaultMemOrder()};
+ admo && *admo != *memOrder) {
context_.Say(programUnit.sourceRange(),
- "Conflicting '%s' REQUIRES clauses found in compilation "
- "unit"_err_en_US,
+ "Conflicting '%s' REQUIRES clauses found in compilation unit"_err_en_US,
parser::omp::GetUpperName(
llvm::omp::Clause::OMPC_atomic_default_mem_order,
version));
diff --git a/flang/lib/Semantics/symbol.cpp b/flang/lib/Semantics/symbol.cpp
index ed0715a422e78..7712091a03210 100644
--- a/flang/lib/Semantics/symbol.cpp
+++ b/flang/lib/Semantics/symbol.cpp
@@ -72,25 +72,35 @@ static void DumpList(llvm::raw_ostream &os, const char *label, const T &list) {
llvm::raw_ostream &operator<<(
llvm::raw_ostream &os, const WithOmpDeclarative &x) {
- if (x.has_ompRequires() || x.has_ompAtomicDefaultMemOrder()) {
- os << " OmpRequirements:(";
- if (const common::OmpMemoryOrderType *admo{x.ompAtomicDefaultMemOrder()}) {
- os << parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(
- llvm::omp::Clause::OMPC_atomic_default_mem_order))
- << '(' << parser::ToLowerCaseLetters(EnumToString(*admo)) << ')';
- if (x.has_ompRequires()) {
+ using OmpClauseSet = WithOmpDeclarative::OmpClauseSet;
+
+ auto toLower = [](std::string_view sv) {
+ return parser::ToLowerCaseLetters(sv);
+ };
+ auto getLowerName = [&](llvm::omp::Clause c) {
+ return toLower(llvm::omp::getOpenMPClauseName(c, x.version_));
+ };
+ auto printClauses = [&](const OmpClauseSet &cs) {
+ size_t idx{0}, size{cs.count()};
+ cs.IterateOverMembers([&](llvm::omp::Clause c) {
+ os << getLowerName(c);
+ switch (c) {
+ case llvm::omp::Clause::OMPC_atomic_default_mem_order:
+ os << '(' << toLower(EnumToString(*x.ompAtomicDefaultMemOrder()))
+ << ')';
+ break;
+ default:
+ break;
+ }
+ if (++idx < size) { // comma between names, none after the last
os << ',';
}
- }
- if (const WithOmpDeclarative::RequiresClauses *reqs{x.ompRequires()}) {
- size_t num{0}, size{reqs->count()};
- reqs->IterateOverMembers([&](llvm::omp::Clause f) {
- os << parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(f));
- if (++num < size) {
- os << ',';
- }
- });
- }
+ });
+ };
+
+ if (const OmpClauseSet &reqs{x.ompRequires()}; reqs.count()) { // print nothing for an empty set
+ os << " OmpRequirements:(";
+ printClauses(reqs);
os << ')';
}
return os;
diff --git a/flang/test/Semantics/OpenMP/requires-modfile.f90 b/flang/test/Semantics/OpenMP/requires-modfile.f90
index 52a43c2ef37ac..76d83458f4dc2 100644
--- a/flang/test/Semantics/OpenMP/requires-modfile.f90
+++ b/flang/test/Semantics/OpenMP/requires-modfile.f90
@@ -29,8 +29,7 @@ module fold
!Expect: req.mod
!module req
-!!$omp requires atomic_default_mem_order(seq_cst)
-!!$omp requires reverse_offload
+!!$omp requires atomic_default_mem_order(seq_cst) reverse_offload
!contains
!subroutine f00()
!end
@@ -42,8 +41,7 @@ module fold
!module user
!use req,only:f00
!use req,only:f01
-!!$omp requires atomic_default_mem_order(seq_cst)
-!!$omp requires reverse_offload
+!!$omp requires atomic_default_mem_order(seq_cst) reverse_offload
!end
!Expect: fold.mod
>From c85545ddf158323dad51cf454d3dcd410f9ca0d5 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Wed, 3 Jun 2026 07:11:34 -0500
Subject: [PATCH 2/2] Simplify function aliases
---
flang/lib/Semantics/symbol.cpp | 10 +++-------
1 file changed, 3 insertions(+), 7 deletions(-)
diff --git a/flang/lib/Semantics/symbol.cpp b/flang/lib/Semantics/symbol.cpp
index 7712091a03210..65bd6fc265d53 100644
--- a/flang/lib/Semantics/symbol.cpp
+++ b/flang/lib/Semantics/symbol.cpp
@@ -74,16 +74,12 @@ llvm::raw_ostream &operator<<(
llvm::raw_ostream &os, const WithOmpDeclarative &x) {
using OmpClauseSet = WithOmpDeclarative::OmpClauseSet;
- auto toLower = [](std::string_view sv) {
- return parser::ToLowerCaseLetters(sv);
- };
- auto getLowerName = [&](llvm::omp::Clause c) {
- return toLower(llvm::omp::getOpenMPClauseName(c, x.version_));
- };
+ auto toLower = parser::ToLowerCaseLetters; // decays to a function pointer
+
auto printClauses = [&](const OmpClauseSet &cs) {
size_t idx{0}, size{cs.count()};
cs.IterateOverMembers([&](llvm::omp::Clause c) {
- os << getLowerName(c);
+ os << toLower(llvm::omp::getOpenMPClauseName(c, x.version_));
switch (c) {
case llvm::omp::Clause::OMPC_atomic_default_mem_order:
os << '(' << toLower(EnumToString(*x.ompAtomicDefaultMemOrder()))
More information about the flang-commits
mailing list