[flang-commits] [flang] [flang][OpenMP] Expand GetOmpObjectList to all subclasses of OmpClause (PR #170351)
Krzysztof Parzyszek via flang-commits
flang-commits at lists.llvm.org
Tue Dec 2 11:08:52 PST 2025
https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/170351
Use GetOmpObjectList instead of extracting the object list by hand.
>From d31bdf1b33b13733968b91f2c8381dafe3d0cbc9 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Tue, 2 Dec 2025 09:04:38 -0600
Subject: [PATCH] [flang][OpenMP] Expand GetOmpObjectList to all subclasses of
OmpClause
Use GetOmpObjectList instead of extracting the object list by hand.
---
flang/include/flang/Parser/openmp-utils.h | 56 +++++++++++++++
flang/lib/Parser/openmp-utils.cpp | 45 +++----------
flang/lib/Semantics/check-omp-loop.cpp | 9 +--
flang/lib/Semantics/check-omp-structure.cpp | 75 +++++++++------------
flang/lib/Semantics/resolve-directives.cpp | 29 ++++----
5 files changed, 117 insertions(+), 97 deletions(-)
diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h
index b7d990c9e75d6..90dbe3f893130 100644
--- a/flang/include/flang/Parser/openmp-utils.h
+++ b/flang/include/flang/Parser/openmp-utils.h
@@ -14,6 +14,7 @@
#define FORTRAN_PARSER_OPENMP_UTILS_H
#include "flang/Common/indirection.h"
+#include "flang/Common/template.h"
#include "flang/Parser/parse-tree.h"
#include "llvm/Frontend/OpenMP/OMP.h"
@@ -127,7 +128,62 @@ template <typename T> struct IsStatement<Statement<T>> {
std::optional<Label> GetStatementLabel(const ExecutionPartConstruct &x);
std::optional<Label> GetFinalLabel(const OpenMPConstruct &x);
+namespace detail {
+// Clauses whose data member v is itself the OmpObjectList
+using MemberObjectListClauses =
+ std::tuple<OmpClause::Copyin, OmpClause::Copyprivate, OmpClause::Exclusive,
+ OmpClause::Firstprivate, OmpClause::HasDeviceAddr, OmpClause::Inclusive,
+ OmpClause::IsDevicePtr, OmpClause::Link, OmpClause::Private,
+ OmpClause::Shared, OmpClause::UseDeviceAddr, OmpClause::UseDevicePtr>;
+
+// Clauses whose wrapped value v holds an OmpObjectList in its tuple t
+using TupleObjectListClauses = std::tuple<OmpClause::AdjustArgs,
+ OmpClause::Affinity, OmpClause::Aligned, OmpClause::Allocate,
+ OmpClause::Enter, OmpClause::From, OmpClause::InReduction,
+ OmpClause::Lastprivate, OmpClause::Linear, OmpClause::Map,
+ OmpClause::Reduction, OmpClause::TaskReduction, OmpClause::To>;
+
+template <typename...> struct WrappedInType;
+
+template <typename T> struct WrappedInType<T> {
+ static constexpr bool value{false};
+};
+
+template <typename T, typename U, typename... Us>
+struct WrappedInType<T, U, Us...> {
+ static constexpr bool value{//
+ std::is_same_v<T, decltype(U::v)> || WrappedInType<T, Us...>::value};
+};
+
+template <typename...> struct WrappedInTuple {
+ static constexpr bool value{false};
+};
+template <typename T, typename... Us>
+struct WrappedInTuple<T, std::tuple<Us...>> {
+ static constexpr bool value{WrappedInType<T, Us...>::value};
+};
+template <typename T, typename U>
+constexpr bool WrappedInTupleV{WrappedInTuple<T, U>::value};
+} // namespace detail
+
+template <typename T> const OmpObjectList *GetOmpObjectList(const T &clause) {
+ using namespace detail;
+ // Any other OmpClause alternative has no object list: return nullptr.
+ if constexpr (common::HasMember<T, MemberObjectListClauses>) {
+ return &clause.v;
+ } else if constexpr (common::HasMember<T, TupleObjectListClauses>) {
+ return &std::get<OmpObjectList>(clause.v.t);
+ } else if constexpr (WrappedInTupleV<T, TupleObjectListClauses>) {
+ return &std::get<OmpObjectList>(clause.t);
+ } else {
+ static_assert(common::HasMember<T, decltype(OmpClause::u)>, "Not a clause");
+ return nullptr;
+ }
+}
+
const OmpObjectList *GetOmpObjectList(const OmpClause &clause);
+const OmpObjectList *GetOmpObjectList(const OmpClause::Depend &clause);
+const OmpObjectList *GetOmpObjectList(const OmpDependClause::TaskDep &x);
template <typename T>
const T *GetFirstArgument(const OmpDirectiveSpecification &spec) {
diff --git a/flang/lib/Parser/openmp-utils.cpp b/flang/lib/Parser/openmp-utils.cpp
index ab2ed0641f4c7..506442015aae2 100644
--- a/flang/lib/Parser/openmp-utils.cpp
+++ b/flang/lib/Parser/openmp-utils.cpp
@@ -117,43 +117,20 @@ std::optional<Label> GetFinalLabel(const OpenMPConstruct &x) {
}
const OmpObjectList *GetOmpObjectList(const OmpClause &clause) {
- // Clauses with OmpObjectList as its data member
- using MemberObjectListClauses = std::tuple<OmpClause::Copyin,
- OmpClause::Copyprivate, OmpClause::Exclusive, OmpClause::Firstprivate,
- OmpClause::HasDeviceAddr, OmpClause::Inclusive, OmpClause::IsDevicePtr,
- OmpClause::Link, OmpClause::Private, OmpClause::Shared,
- OmpClause::UseDeviceAddr, OmpClause::UseDevicePtr>;
-
- // Clauses with OmpObjectList in the tuple
- using TupleObjectListClauses = std::tuple<OmpClause::AdjustArgs,
- OmpClause::Affinity, OmpClause::Aligned, OmpClause::Allocate,
- OmpClause::Enter, OmpClause::From, OmpClause::InReduction,
- OmpClause::Lastprivate, OmpClause::Linear, OmpClause::Map,
- OmpClause::Reduction, OmpClause::TaskReduction, OmpClause::To>;
-
- // TODO:: Generate the tuples using TableGen.
+ return common::visit([](auto &&s) { return GetOmpObjectList(s); }, clause.u);
+}
+
+const OmpObjectList *GetOmpObjectList(const OmpClause::Depend &clause) {
return common::visit(
common::visitors{
- [&](const OmpClause::Depend &x) -> const OmpObjectList * {
- if (auto *taskDep{std::get_if<OmpDependClause::TaskDep>(&x.v.u)}) {
- return &std::get<OmpObjectList>(taskDep->t);
- } else {
- return nullptr;
- }
- },
- [&](const auto &x) -> const OmpObjectList * {
- using Ty = std::decay_t<decltype(x)>;
- if constexpr (common::HasMember<Ty, MemberObjectListClauses>) {
- return &x.v;
- } else if constexpr (common::HasMember<Ty,
- TupleObjectListClauses>) {
- return &std::get<OmpObjectList>(x.v.t);
- } else {
- return nullptr;
- }
- },
+ [](const OmpDoacross &) -> const OmpObjectList * { return nullptr; },
+ [](const OmpDependClause::TaskDep &x) { return GetOmpObjectList(x); },
},
- clause.u);
+ clause.v.u);
+}
+
+const OmpObjectList *GetOmpObjectList(const OmpDependClause::TaskDep &x) {
+ return &std::get<OmpObjectList>(x.t);
}
const BlockConstruct *GetFortranBlockConstruct(
diff --git a/flang/lib/Semantics/check-omp-loop.cpp b/flang/lib/Semantics/check-omp-loop.cpp
index 9a78209369949..b705a15efd4b1 100644
--- a/flang/lib/Semantics/check-omp-loop.cpp
+++ b/flang/lib/Semantics/check-omp-loop.cpp
@@ -480,9 +480,8 @@ void OmpStructureChecker::CheckDistLinear(
// Collect symbols of all the variables from linear clauses
for (auto &clause : clauses.v) {
- if (auto *linearClause{std::get_if<parser::OmpClause::Linear>(&clause.u)}) {
- auto &objects{std::get<parser::OmpObjectList>(linearClause->v.t)};
- GetSymbolsInObjectList(objects, indexVars);
+ if (std::get_if<parser::OmpClause::Linear>(&clause.u)) {
+ GetSymbolsInObjectList(*parser::omp::GetOmpObjectList(clause), indexVars);
}
}
@@ -604,8 +603,6 @@ void OmpStructureChecker::Leave(const parser::OpenMPLoopConstruct &x) {
auto *maybeModifier{OmpGetUniqueModifier<ReductionModifier>(modifiers)};
if (maybeModifier &&
maybeModifier->v == ReductionModifier::Value::Inscan) {
- const auto &objectList{
- std::get<parser::OmpObjectList>(reductionClause->v.t)};
auto checkReductionSymbolInScan = [&](const parser::Name *name) {
if (auto &symbol = name->symbol) {
if (!symbol->test(Symbol::Flag::OmpInclusiveScan) &&
@@ -618,7 +615,7 @@ void OmpStructureChecker::Leave(const parser::OpenMPLoopConstruct &x) {
}
}
};
- for (const auto &ompObj : objectList.v) {
+ for (const auto &ompObj : parser::omp::GetOmpObjectList(clause)->v) {
common::visit(
common::visitors{
[&](const parser::Designator &designator) {
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index f7778472f71f1..6e2662bb2a34f 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -624,11 +624,9 @@ void OmpStructureChecker::CheckMultListItems() {
// Linear clause
for (auto [_, clause] : FindClauses(llvm::omp::Clause::OMPC_linear)) {
- auto &linearClause{std::get<parser::OmpClause::Linear>(clause->u)};
std::list<parser::Name> nameList;
SymbolSourceMap symbols;
- GetSymbolsInObjectList(
- std::get<parser::OmpObjectList>(linearClause.v.t), symbols);
+ GetSymbolsInObjectList(*GetOmpObjectList(*clause), symbols);
llvm::transform(symbols, std::back_inserter(nameList), [&](auto &&pair) {
return parser::Name{pair.second, const_cast<Symbol *>(pair.first)};
});
@@ -2101,29 +2099,29 @@ void OmpStructureChecker::Leave(const parser::OpenMPDeclareTargetConstruct &x) {
}
}
- bool toClauseFound{false}, deviceTypeClauseFound{false},
- enterClauseFound{false};
+ bool toClauseFound{false};
+ bool deviceTypeClauseFound{false};
+ bool enterClauseFound{false};
for (const parser::OmpClause &clause : x.v.Clauses().v) {
common::visit(
common::visitors{
- [&](const parser::OmpClause::To &toClause) {
- toClauseFound = true;
- auto &objList{std::get<parser::OmpObjectList>(toClause.v.t)};
- CheckSymbolNames(dirName.source, objList);
- CheckVarIsNotPartOfAnotherVar(dirName.source, objList);
- CheckThreadprivateOrDeclareTargetVar(objList);
- },
- [&](const parser::OmpClause::Link &linkClause) {
- CheckSymbolNames(dirName.source, linkClause.v);
- CheckVarIsNotPartOfAnotherVar(dirName.source, linkClause.v);
- CheckThreadprivateOrDeclareTargetVar(linkClause.v);
- },
- [&](const parser::OmpClause::Enter &enterClause) {
- enterClauseFound = true;
- auto &objList{std::get<parser::OmpObjectList>(enterClause.v.t)};
- CheckSymbolNames(dirName.source, objList);
- CheckVarIsNotPartOfAnotherVar(dirName.source, objList);
- CheckThreadprivateOrDeclareTargetVar(objList);
+ [&](const auto &c) {
+ using TypeC = llvm::remove_cvref_t<decltype(c)>;
+ if constexpr ( //
+ std::is_same_v<TypeC, parser::OmpClause::Enter> ||
+ std::is_same_v<TypeC, parser::OmpClause::Link> ||
+ std::is_same_v<TypeC, parser::OmpClause::To>) {
+ auto &objList{*GetOmpObjectList(c)};
+ CheckSymbolNames(dirName.source, objList);
+ CheckVarIsNotPartOfAnotherVar(dirName.source, objList);
+ CheckThreadprivateOrDeclareTargetVar(objList);
+ }
+ if constexpr (std::is_same_v<TypeC, parser::OmpClause::Enter>) {
+ enterClauseFound = true;
+ }
+ if constexpr (std::is_same_v<TypeC, parser::OmpClause::To>) {
+ toClauseFound = true;
+ }
},
[&](const parser::OmpClause::DeviceType &deviceTypeClause) {
deviceTypeClauseFound = true;
@@ -2134,7 +2132,6 @@ void OmpStructureChecker::Leave(const parser::OpenMPDeclareTargetConstruct &x) {
deviceConstructFound_ = true;
}
},
- [&](const auto &) {},
},
clause.u);
@@ -2424,12 +2421,8 @@ void OmpStructureChecker::CheckTargetUpdate() {
}
if (toWrapper && fromWrapper) {
SymbolSourceMap toSymbols, fromSymbols;
- auto &fromClause{std::get<parser::OmpClause::From>(fromWrapper->u).v};
- auto &toClause{std::get<parser::OmpClause::To>(toWrapper->u).v};
- GetSymbolsInObjectList(
- std::get<parser::OmpObjectList>(fromClause.t), fromSymbols);
- GetSymbolsInObjectList(
- std::get<parser::OmpObjectList>(toClause.t), toSymbols);
+ GetSymbolsInObjectList(*GetOmpObjectList(*fromWrapper), fromSymbols);
+ GetSymbolsInObjectList(*GetOmpObjectList(*toWrapper), toSymbols);
for (auto &[symbol, source] : toSymbols) {
auto fromSymbol{fromSymbols.find(symbol)};
@@ -3269,7 +3262,7 @@ void OmpStructureChecker::Leave(const parser::OmpClauseList &) {
const auto &irClause{
std::get<parser::OmpClause::InReduction>(dataEnvClause->u)};
checkVarAppearsInDataEnvClause(
- std::get<parser::OmpObjectList>(irClause.v.t), "IN_REDUCTION");
+ *GetOmpObjectList(irClause), "IN_REDUCTION");
}
}
}
@@ -3436,7 +3429,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Destroy &x) {
void OmpStructureChecker::Enter(const parser::OmpClause::Reduction &x) {
CheckAllowedClause(llvm::omp::Clause::OMPC_reduction);
- auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
+ auto &objects{*GetOmpObjectList(x)};
if (OmpVerifyModifiers(x.v, llvm::omp::OMPC_reduction,
GetContext().clauseSource, context_)) {
@@ -3476,7 +3469,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Reduction &x) {
void OmpStructureChecker::Enter(const parser::OmpClause::InReduction &x) {
CheckAllowedClause(llvm::omp::Clause::OMPC_in_reduction);
- auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
+ auto &objects{*GetOmpObjectList(x)};
if (OmpVerifyModifiers(x.v, llvm::omp::OMPC_in_reduction,
GetContext().clauseSource, context_)) {
@@ -3494,7 +3487,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::InReduction &x) {
void OmpStructureChecker::Enter(const parser::OmpClause::TaskReduction &x) {
CheckAllowedClause(llvm::omp::Clause::OMPC_task_reduction);
- auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
+ auto &objects{*GetOmpObjectList(x)};
if (OmpVerifyModifiers(x.v, llvm::omp::OMPC_task_reduction,
GetContext().clauseSource, context_)) {
@@ -4347,8 +4340,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Map &x) {
}};
evaluate::ExpressionAnalyzer ea{context_};
- const auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
- for (auto &object : objects.v) {
+ for (auto &object : GetOmpObjectList(x)->v) {
if (const parser::Designator *d{GetDesignatorFromObj(object)}) {
if (auto &&expr{ea.Analyze(*d)}) {
if (hasBasePointer(*expr)) {
@@ -4501,7 +4493,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Depend &x) {
}
}
if (taskDep) {
- auto &objList{std::get<parser::OmpObjectList>(taskDep->t)};
+ auto &objList{*GetOmpObjectList(*taskDep)};
if (dir == llvm::omp::OMPD_depobj) {
// [5.0:255:13], [5.1:288:6], [5.2:322:26]
// A depend clause on a depobj construct must only specify one locator.
@@ -4647,7 +4639,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Copyprivate &x) {
void OmpStructureChecker::Enter(const parser::OmpClause::Lastprivate &x) {
CheckAllowedClause(llvm::omp::Clause::OMPC_lastprivate);
- const auto &objectList{std::get<parser::OmpObjectList>(x.v.t)};
+ const auto &objectList{*GetOmpObjectList(x)};
CheckVarIsNotPartOfAnotherVar(
GetContext().clauseSource, objectList, "LASTPRIVATE");
CheckCrayPointee(objectList, "LASTPRIVATE");
@@ -4889,9 +4881,8 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Enter &x) {
x.v, llvm::omp::OMPC_enter, GetContext().clauseSource, context_)) {
return;
}
- const parser::OmpObjectList &objList{std::get<parser::OmpObjectList>(x.v.t)};
SymbolSourceMap symbols;
- GetSymbolsInObjectList(objList, symbols);
+ GetSymbolsInObjectList(*GetOmpObjectList(x), symbols);
for (const auto &[symbol, source] : symbols) {
if (!IsExtendedListItem(*symbol)) {
context_.SayWithDecl(*symbol, source,
@@ -4914,7 +4905,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::From &x) {
CheckIteratorModifier(*iter);
}
- const auto &objList{std::get<parser::OmpObjectList>(x.v.t)};
+ const auto &objList{*GetOmpObjectList(x)};
SymbolSourceMap symbols;
GetSymbolsInObjectList(objList, symbols);
CheckVariableListItem(symbols);
@@ -4954,7 +4945,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::To &x) {
CheckIteratorModifier(*iter);
}
- const auto &objList{std::get<parser::OmpObjectList>(x.v.t)};
+ const auto &objList{*GetOmpObjectList(x)};
SymbolSourceMap symbols;
GetSymbolsInObjectList(objList, symbols);
CheckVariableListItem(symbols);
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index b992a4125ffcb..5d77eaa35fc08 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -717,8 +717,8 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
return false;
}
bool Pre(const parser::OmpAllocateClause &x) {
- const auto &objectList{std::get<parser::OmpObjectList>(x.t)};
- ResolveOmpObjectList(objectList, Symbol::Flag::OmpAllocate);
+ ResolveOmpObjectList(
+ *parser::omp::GetOmpObjectList(x), Symbol::Flag::OmpAllocate);
return false;
}
bool Pre(const parser::OmpClause::Firstprivate &x) {
@@ -726,8 +726,8 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
return false;
}
bool Pre(const parser::OmpClause::Lastprivate &x) {
- const auto &objList{std::get<parser::OmpObjectList>(x.v.t)};
- ResolveOmpObjectList(objList, Symbol::Flag::OmpLastPrivate);
+ ResolveOmpObjectList(
+ *parser::omp::GetOmpObjectList(x), Symbol::Flag::OmpLastPrivate);
return false;
}
bool Pre(const parser::OmpClause::Copyin &x) {
@@ -739,8 +739,8 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
return false;
}
bool Pre(const parser::OmpLinearClause &x) {
- auto &objects{std::get<parser::OmpObjectList>(x.t)};
- ResolveOmpObjectList(objects, Symbol::Flag::OmpLinear);
+ ResolveOmpObjectList(
+ *parser::omp::GetOmpObjectList(x), Symbol::Flag::OmpLinear);
return false;
}
@@ -750,13 +750,13 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
}
bool Pre(const parser::OmpInReductionClause &x) {
- auto &objects{std::get<parser::OmpObjectList>(x.t)};
- ResolveOmpObjectList(objects, Symbol::Flag::OmpInReduction);
+ ResolveOmpObjectList(
+ *parser::omp::GetOmpObjectList(x), Symbol::Flag::OmpInReduction);
return false;
}
bool Pre(const parser::OmpClause::Reduction &x) {
- const auto &objList{std::get<parser::OmpObjectList>(x.v.t)};
+ const auto &objList{*parser::omp::GetOmpObjectList(x)};
ResolveOmpObjectList(objList, Symbol::Flag::OmpReduction);
if (auto &modifiers{OmpGetModifiers(x.v)}) {
@@ -806,8 +806,8 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
}
bool Pre(const parser::OmpAlignedClause &x) {
- const auto &alignedNameList{std::get<parser::OmpObjectList>(x.t)};
- ResolveOmpObjectList(alignedNameList, Symbol::Flag::OmpAligned);
+ ResolveOmpObjectList(
+ *parser::omp::GetOmpObjectList(x), Symbol::Flag::OmpAligned);
return false;
}
@@ -920,7 +920,7 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
}
}
- const auto &ompObjList{std::get<parser::OmpObjectList>(x.t)};
+ const auto &ompObjList{*parser::omp::GetOmpObjectList(x)};
for (const auto &ompObj : ompObjList.v) {
common::visit(
common::visitors{
@@ -2566,9 +2566,8 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPAllocatorsConstruct &x) {
PushContext(x.source, dirSpec.DirId());
for (const auto &clause : dirSpec.Clauses().v) {
- if (const auto *allocClause{
- std::get_if<parser::OmpClause::Allocate>(&clause.u)}) {
- ResolveOmpObjectList(std::get<parser::OmpObjectList>(allocClause->v.t),
+ if (std::get_if<parser::OmpClause::Allocate>(&clause.u)) {
+ ResolveOmpObjectList(*parser::omp::GetOmpObjectList(clause),
Symbol::Flag::OmpExecutableAllocateDirective);
}
}
More information about the flang-commits
mailing list