[flang-commits] [flang] [flang][OpenMP] Set REQUIRES flags on program unit symbol (PR #163448)
Krzysztof Parzyszek via flang-commits
flang-commits at lists.llvm.org
Tue Oct 14 13:48:32 PDT 2025
https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/163448
REQUIRES clauses apply to the compilation unit, which the OpenMP spec defines as the program unit in Fortran.
Don't set REQUIRES flags on all containing scopes, only on the containing program unit, where flags coming from different directives are gathered. If we wanted to set the flags on subprograms, we would need to first accumulate all of them, then propagate them down to all subprograms. That is not done as it is not necessary (the containing program unit is always available).
>From 4ccd96e07cc1fb910203c771ea5801208d99dde4 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Fri, 10 Oct 2025 10:55:48 -0500
Subject: [PATCH] [flang][OpenMP] Set REQUIRES flags on program unit symbol
REQUIRES clauses apply to the compilation unit, which the OpenMP spec
defines as the program unit in Fortran.
Don't set REQUIRES flags on all containing scopes, only on the containing
program unit, where flags coming from different directives are gathered.
If we wanted to set the flags on subprograms, we would need to first
accumulate all of them, then propagate them down to all subprograms.
That is not done as it is not necessary (the containing program unit is
always available).
---
flang/include/flang/Semantics/openmp-utils.h | 1 +
flang/lib/Semantics/openmp-utils.cpp | 23 +++++++-
flang/lib/Semantics/resolve-directives.cpp | 60 ++++++++++----------
3 files changed, 52 insertions(+), 32 deletions(-)
diff --git a/flang/include/flang/Semantics/openmp-utils.h b/flang/include/flang/Semantics/openmp-utils.h
index 2954a1c4769f7..0f851830edd46 100644
--- a/flang/include/flang/Semantics/openmp-utils.h
+++ b/flang/include/flang/Semantics/openmp-utils.h
@@ -38,6 +38,7 @@ template <typename T, typename U = std::remove_const_t<T>> U AsRvalue(T &t) {
template <typename T> T &&AsRvalue(T &&t) { return std::move(t); }
const Scope &GetScopingUnit(const Scope &scope);
+const Scope &GetProgramUnit(const Scope &scope);
// There is no consistent way to get the source of an ActionStmt, but there
// is "source" in Statement<T>. This structure keeps the ActionStmt with the
diff --git a/flang/lib/Semantics/openmp-utils.cpp b/flang/lib/Semantics/openmp-utils.cpp
index a8ec4d6c24beb..292e73b4899c0 100644
--- a/flang/lib/Semantics/openmp-utils.cpp
+++ b/flang/lib/Semantics/openmp-utils.cpp
@@ -13,6 +13,7 @@
#include "flang/Semantics/openmp-utils.h"
#include "flang/Common/Fortran-consts.h"
+#include "flang/Common/idioms.h"
#include "flang/Common/indirection.h"
#include "flang/Common/reference.h"
#include "flang/Common/visit.h"
@@ -59,6 +60,26 @@ const Scope &GetScopingUnit(const Scope &scope) {
return *iter;
}
+const Scope &GetProgramUnit(const Scope &scope) {
+ const Scope *unit{nullptr};
+ for (const Scope *iter{&scope}; !iter->IsTopLevel(); iter = &iter->parent()) {
+ switch (iter->kind()) {
+ case Scope::Kind::BlockData:
+ case Scope::Kind::MainProgram:
+ case Scope::Kind::Module:
+ return *iter;
+ case Scope::Kind::Subprogram:
+ // Keep walking outward so that the outermost subprogram is chosen.
+ unit = iter;
+ break;
+ default:
+ break;
+ }
+ }
+ assert(unit && "Scope not in a program unit");
+ return *unit;
+}
+
SourcedActionStmt GetActionStmt(const parser::ExecutionPartConstruct *x) {
if (x == nullptr) {
return SourcedActionStmt{};
@@ -202,7 +223,7 @@ std::optional<SomeExpr> GetEvaluateExpr(const parser::Expr &parserExpr) {
// ForwardOwningPointer typedExpr
// `- GenericExprWrapper ^.get()
// `- std::optional<Expr> ^->v
- return typedExpr.get()->v;
+ return DEREF(typedExpr.get()).v;
}
std::optional<evaluate::DynamicType> GetDynamicType(
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index 18fc63814d973..de680b41d1524 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -3549,40 +3549,38 @@ void OmpAttributeVisitor::CheckLabelContext(const parser::CharBlock source,
void OmpAttributeVisitor::AddOmpRequiresToScope(Scope &scope,
WithOmpDeclarative::RequiresFlags flags,
std::optional<common::OmpMemoryOrderType> memOrder) {
- Scope *scopeIter = &scope;
- do {
- if (Symbol * symbol{scopeIter->symbol()}) {
- common::visit(
- [&](auto &details) {
- // Store clauses information into the symbol for the parent and
- // enclosing modules, programs, functions and subroutines.
- if constexpr (std::is_convertible_v<decltype(&details),
- WithOmpDeclarative *>) {
- if (flags.any()) {
- if (const WithOmpDeclarative::RequiresFlags *
- otherFlags{details.ompRequires()}) {
- flags |= *otherFlags;
- }
- details.set_ompRequires(flags);
+ const Scope &programUnit{omp::GetProgramUnit(scope)};
+
+ if (auto *symbol{const_cast<Symbol *>(programUnit.symbol())}) {
+ common::visit(
+ [&](auto &details) {
+ // Store clauses information into the symbol of the containing
+ // program unit, as returned by omp::GetProgramUnit.
+ if constexpr (std::is_convertible_v<decltype(&details),
+ WithOmpDeclarative *>) {
+ if (flags.any()) {
+ if (const WithOmpDeclarative::RequiresFlags *otherFlags{
+ details.ompRequires()}) {
+ flags |= *otherFlags;
}
- if (memOrder) {
- if (details.has_ompAtomicDefaultMemOrder() &&
- *details.ompAtomicDefaultMemOrder() != *memOrder) {
- context_.Say(scopeIter->sourceRange(),
- "Conflicting '%s' REQUIRES clauses found in compilation "
- "unit"_err_en_US,
- parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(
- llvm::omp::Clause::OMPC_atomic_default_mem_order)
- .str()));
- }
- details.set_ompAtomicDefaultMemOrder(*memOrder);
+ details.set_ompRequires(flags);
+ }
+ if (memOrder) {
+ if (details.has_ompAtomicDefaultMemOrder() &&
+ *details.ompAtomicDefaultMemOrder() != *memOrder) {
+ context_.Say(programUnit.sourceRange(),
+ "Conflicting '%s' REQUIRES clauses found in compilation "
+ "unit"_err_en_US,
+ parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(
+ llvm::omp::Clause::OMPC_atomic_default_mem_order)
+ .str()));
}
+ details.set_ompAtomicDefaultMemOrder(*memOrder);
}
- },
- symbol->details());
- }
- scopeIter = &scopeIter->parent();
- } while (!scopeIter->IsGlobal());
+ }
+ },
+ symbol->details());
+ }
}
void OmpAttributeVisitor::IssueNonConformanceWarning(llvm::omp::Directive D,
More information about the flang-commits
mailing list