[flang-commits] [flang] [flang][OpenMP] Expand GetOmpObjectList to all subclasses of OmpClause (PR #170351)

Krzysztof Parzyszek via flang-commits flang-commits at lists.llvm.org
Tue Dec 2 11:08:52 PST 2025


https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/170351

Use GetOmpObjectList instead of extracting the object list by hand.

>From d31bdf1b33b13733968b91f2c8381dafe3d0cbc9 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Tue, 2 Dec 2025 09:04:38 -0600
Subject: [PATCH] [flang][OpenMP] Expand GetOmpObjectList to all subclasses of
 OmpClause

Use GetOmpObjectList instead of extracting the object list by hand.
---
 flang/include/flang/Parser/openmp-utils.h   | 56 +++++++++++++++
 flang/lib/Parser/openmp-utils.cpp           | 45 +++----------
 flang/lib/Semantics/check-omp-loop.cpp      |  9 +--
 flang/lib/Semantics/check-omp-structure.cpp | 75 +++++++++------------
 flang/lib/Semantics/resolve-directives.cpp  | 29 ++++----
 5 files changed, 117 insertions(+), 97 deletions(-)

diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h
index b7d990c9e75d6..90dbe3f893130 100644
--- a/flang/include/flang/Parser/openmp-utils.h
+++ b/flang/include/flang/Parser/openmp-utils.h
@@ -14,6 +14,7 @@
 #define FORTRAN_PARSER_OPENMP_UTILS_H
 
 #include "flang/Common/indirection.h"
+#include "flang/Common/template.h"
 #include "flang/Parser/parse-tree.h"
 #include "llvm/Frontend/OpenMP/OMP.h"
 
@@ -127,7 +128,75 @@ template <typename T> struct IsStatement<Statement<T>> {
 std::optional<Label> GetStatementLabel(const ExecutionPartConstruct &x);
 std::optional<Label> GetFinalLabel(const OpenMPConstruct &x);
 
+namespace detail {
+// Clauses whose data member `v` is an OmpObjectList.
+using MemberObjectListClauses =
+    std::tuple<OmpClause::Copyin, OmpClause::Copyprivate, OmpClause::Exclusive,
+        OmpClause::Firstprivate, OmpClause::HasDeviceAddr, OmpClause::Inclusive,
+        OmpClause::IsDevicePtr, OmpClause::Link, OmpClause::Private,
+        OmpClause::Shared, OmpClause::UseDeviceAddr, OmpClause::UseDevicePtr>;
+
+// Clauses whose data member `v` is a class holding an OmpObjectList in its
+// tuple member `t`.
+using TupleObjectListClauses = std::tuple<OmpClause::AdjustArgs,
+    OmpClause::Affinity, OmpClause::Aligned, OmpClause::Allocate,
+    OmpClause::Enter, OmpClause::From, OmpClause::InReduction,
+    OmpClause::Lastprivate, OmpClause::Linear, OmpClause::Map,
+    OmpClause::Reduction, OmpClause::TaskReduction, OmpClause::To>;
+
+// TODO: Generate the tuples using TableGen.
+
+// WrappedInType<T, Us...>::value is true if T is the type of the member `v`
+// of at least one of the classes Us.
+template <typename...> struct WrappedInType;
+
+template <typename T> struct WrappedInType<T> {
+  static constexpr bool value{false};
+};
+
+template <typename T, typename U, typename... Us>
+struct WrappedInType<T, U, Us...> {
+  static constexpr bool value{//
+      std::is_same_v<T, decltype(U::v)> || WrappedInType<T, Us...>::value};
+};
+
+// WrappedInTuple<T, std::tuple<Us...>>::value is the same as
+// WrappedInType<T, Us...>::value; it is false for any other second argument.
+template <typename...> struct WrappedInTuple {
+  static constexpr bool value{false};
+};
+template <typename T, typename... Us>
+struct WrappedInTuple<T, std::tuple<Us...>> {
+  static constexpr bool value{WrappedInType<T, Us...>::value};
+};
+template <typename T, typename U>
+constexpr bool WrappedInTupleV{WrappedInTuple<T, U>::value};
+} // namespace detail
+
+// Return a pointer to the OmpObjectList held by `clause`, or nullptr if it
+// has none. T can be an OmpClause alternative (e.g. OmpClause::Linear), or,
+// for clauses in TupleObjectListClauses, the class stored in their `v`
+// member (e.g. OmpLinearClause).
+// NOTE: the fallback branch accepts any class type, so passing an object that
+// is not a clause compiles and silently returns nullptr.
+template <typename T> const OmpObjectList *GetOmpObjectList(const T &clause) {
+  using namespace detail;
+
+  if constexpr (common::HasMember<T, MemberObjectListClauses>) {
+    return &clause.v;
+  } else if constexpr (common::HasMember<T, TupleObjectListClauses>) {
+    return &std::get<OmpObjectList>(clause.v.t);
+  } else if constexpr (WrappedInTupleV<T, TupleObjectListClauses>) {
+    return &std::get<OmpObjectList>(clause.t);
+  } else {
+    static_assert(std::is_class_v<T>, "Unexpected argument type");
+    return nullptr;
+  }
+}
+
 const OmpObjectList *GetOmpObjectList(const OmpClause &clause);
+const OmpObjectList *GetOmpObjectList(const OmpClause::Depend &clause);
+const OmpObjectList *GetOmpObjectList(const OmpDependClause::TaskDep &x);
 
 template <typename T>
 const T *GetFirstArgument(const OmpDirectiveSpecification &spec) {
diff --git a/flang/lib/Parser/openmp-utils.cpp b/flang/lib/Parser/openmp-utils.cpp
index ab2ed0641f4c7..506442015aae2 100644
--- a/flang/lib/Parser/openmp-utils.cpp
+++ b/flang/lib/Parser/openmp-utils.cpp
@@ -117,43 +117,20 @@ std::optional<Label> GetFinalLabel(const OpenMPConstruct &x) {
 }
 
 const OmpObjectList *GetOmpObjectList(const OmpClause &clause) {
-  // Clauses with OmpObjectList as its data member
-  using MemberObjectListClauses = std::tuple<OmpClause::Copyin,
-      OmpClause::Copyprivate, OmpClause::Exclusive, OmpClause::Firstprivate,
-      OmpClause::HasDeviceAddr, OmpClause::Inclusive, OmpClause::IsDevicePtr,
-      OmpClause::Link, OmpClause::Private, OmpClause::Shared,
-      OmpClause::UseDeviceAddr, OmpClause::UseDevicePtr>;
-
-  // Clauses with OmpObjectList in the tuple
-  using TupleObjectListClauses = std::tuple<OmpClause::AdjustArgs,
-      OmpClause::Affinity, OmpClause::Aligned, OmpClause::Allocate,
-      OmpClause::Enter, OmpClause::From, OmpClause::InReduction,
-      OmpClause::Lastprivate, OmpClause::Linear, OmpClause::Map,
-      OmpClause::Reduction, OmpClause::TaskReduction, OmpClause::To>;
-
-  // TODO:: Generate the tuples using TableGen.
+  return common::visit([](auto &&s) { return GetOmpObjectList(s); }, clause.u);
+}
+
+const OmpObjectList *GetOmpObjectList(const OmpClause::Depend &clause) {
   return common::visit(
       common::visitors{
-          [&](const OmpClause::Depend &x) -> const OmpObjectList * {
-            if (auto *taskDep{std::get_if<OmpDependClause::TaskDep>(&x.v.u)}) {
-              return &std::get<OmpObjectList>(taskDep->t);
-            } else {
-              return nullptr;
-            }
-          },
-          [&](const auto &x) -> const OmpObjectList * {
-            using Ty = std::decay_t<decltype(x)>;
-            if constexpr (common::HasMember<Ty, MemberObjectListClauses>) {
-              return &x.v;
-            } else if constexpr (common::HasMember<Ty,
-                                     TupleObjectListClauses>) {
-              return &std::get<OmpObjectList>(x.v.t);
-            } else {
-              return nullptr;
-            }
-          },
+          [](const OmpDoacross &) -> const OmpObjectList * { return nullptr; },
+          [](const OmpDependClause::TaskDep &x) { return GetOmpObjectList(x); },
       },
-      clause.u);
+      clause.v.u);
+}
+
+const OmpObjectList *GetOmpObjectList(const OmpDependClause::TaskDep &x) {
+  return &std::get<OmpObjectList>(x.t);
 }
 
 const BlockConstruct *GetFortranBlockConstruct(
diff --git a/flang/lib/Semantics/check-omp-loop.cpp b/flang/lib/Semantics/check-omp-loop.cpp
index 9a78209369949..b705a15efd4b1 100644
--- a/flang/lib/Semantics/check-omp-loop.cpp
+++ b/flang/lib/Semantics/check-omp-loop.cpp
@@ -480,9 +480,8 @@ void OmpStructureChecker::CheckDistLinear(
 
   // Collect symbols of all the variables from linear clauses
   for (auto &clause : clauses.v) {
-    if (auto *linearClause{std::get_if<parser::OmpClause::Linear>(&clause.u)}) {
-      auto &objects{std::get<parser::OmpObjectList>(linearClause->v.t)};
-      GetSymbolsInObjectList(objects, indexVars);
+    if (std::get_if<parser::OmpClause::Linear>(&clause.u)) {
+      GetSymbolsInObjectList(*parser::omp::GetOmpObjectList(clause), indexVars);
     }
   }
 
@@ -604,8 +603,6 @@ void OmpStructureChecker::Leave(const parser::OpenMPLoopConstruct &x) {
       auto *maybeModifier{OmpGetUniqueModifier<ReductionModifier>(modifiers)};
       if (maybeModifier &&
           maybeModifier->v == ReductionModifier::Value::Inscan) {
-        const auto &objectList{
-            std::get<parser::OmpObjectList>(reductionClause->v.t)};
         auto checkReductionSymbolInScan = [&](const parser::Name *name) {
           if (auto &symbol = name->symbol) {
             if (!symbol->test(Symbol::Flag::OmpInclusiveScan) &&
@@ -618,7 +615,8 @@ void OmpStructureChecker::Leave(const parser::OpenMPLoopConstruct &x) {
             }
           }
         };
-        for (const auto &ompObj : objectList.v) {
+        for (const auto &ompObj :
+            parser::omp::GetOmpObjectList(*reductionClause)->v) {
           common::visit(
               common::visitors{
                   [&](const parser::Designator &designator) {
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index f7778472f71f1..6e2662bb2a34f 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -624,11 +624,9 @@ void OmpStructureChecker::CheckMultListItems() {
 
   // Linear clause
   for (auto [_, clause] : FindClauses(llvm::omp::Clause::OMPC_linear)) {
-    auto &linearClause{std::get<parser::OmpClause::Linear>(clause->u)};
     std::list<parser::Name> nameList;
     SymbolSourceMap symbols;
-    GetSymbolsInObjectList(
-        std::get<parser::OmpObjectList>(linearClause.v.t), symbols);
+    GetSymbolsInObjectList(*GetOmpObjectList(*clause), symbols);
     llvm::transform(symbols, std::back_inserter(nameList), [&](auto &&pair) {
       return parser::Name{pair.second, const_cast<Symbol *>(pair.first)};
     });
@@ -2101,29 +2099,29 @@ void OmpStructureChecker::Leave(const parser::OpenMPDeclareTargetConstruct &x) {
     }
   }
 
-  bool toClauseFound{false}, deviceTypeClauseFound{false},
-      enterClauseFound{false};
+  bool toClauseFound{false};
+  bool deviceTypeClauseFound{false};
+  bool enterClauseFound{false};
   for (const parser::OmpClause &clause : x.v.Clauses().v) {
     common::visit(
         common::visitors{
-            [&](const parser::OmpClause::To &toClause) {
-              toClauseFound = true;
-              auto &objList{std::get<parser::OmpObjectList>(toClause.v.t)};
-              CheckSymbolNames(dirName.source, objList);
-              CheckVarIsNotPartOfAnotherVar(dirName.source, objList);
-              CheckThreadprivateOrDeclareTargetVar(objList);
-            },
-            [&](const parser::OmpClause::Link &linkClause) {
-              CheckSymbolNames(dirName.source, linkClause.v);
-              CheckVarIsNotPartOfAnotherVar(dirName.source, linkClause.v);
-              CheckThreadprivateOrDeclareTargetVar(linkClause.v);
-            },
-            [&](const parser::OmpClause::Enter &enterClause) {
-              enterClauseFound = true;
-              auto &objList{std::get<parser::OmpObjectList>(enterClause.v.t)};
-              CheckSymbolNames(dirName.source, objList);
-              CheckVarIsNotPartOfAnotherVar(dirName.source, objList);
-              CheckThreadprivateOrDeclareTargetVar(objList);
+            [&](const auto &c) {
+              using TypeC = llvm::remove_cvref_t<decltype(c)>;
+              if constexpr ( //
+                  std::is_same_v<TypeC, parser::OmpClause::Enter> ||
+                  std::is_same_v<TypeC, parser::OmpClause::Link> ||
+                  std::is_same_v<TypeC, parser::OmpClause::To>) {
+                auto &objList{*GetOmpObjectList(c)};
+                CheckSymbolNames(dirName.source, objList);
+                CheckVarIsNotPartOfAnotherVar(dirName.source, objList);
+                CheckThreadprivateOrDeclareTargetVar(objList);
+              }
+              if constexpr (std::is_same_v<TypeC, parser::OmpClause::Enter>) {
+                enterClauseFound = true;
+              }
+              if constexpr (std::is_same_v<TypeC, parser::OmpClause::To>) {
+                toClauseFound = true;
+              }
             },
             [&](const parser::OmpClause::DeviceType &deviceTypeClause) {
               deviceTypeClauseFound = true;
@@ -2134,7 +2132,6 @@ void OmpStructureChecker::Leave(const parser::OpenMPDeclareTargetConstruct &x) {
                 deviceConstructFound_ = true;
               }
             },
-            [&](const auto &) {},
         },
         clause.u);
 
@@ -2424,12 +2421,8 @@ void OmpStructureChecker::CheckTargetUpdate() {
   }
   if (toWrapper && fromWrapper) {
     SymbolSourceMap toSymbols, fromSymbols;
-    auto &fromClause{std::get<parser::OmpClause::From>(fromWrapper->u).v};
-    auto &toClause{std::get<parser::OmpClause::To>(toWrapper->u).v};
-    GetSymbolsInObjectList(
-        std::get<parser::OmpObjectList>(fromClause.t), fromSymbols);
-    GetSymbolsInObjectList(
-        std::get<parser::OmpObjectList>(toClause.t), toSymbols);
+    GetSymbolsInObjectList(*GetOmpObjectList(*fromWrapper), fromSymbols);
+    GetSymbolsInObjectList(*GetOmpObjectList(*toWrapper), toSymbols);
 
     for (auto &[symbol, source] : toSymbols) {
       auto fromSymbol{fromSymbols.find(symbol)};
@@ -3269,7 +3262,7 @@ void OmpStructureChecker::Leave(const parser::OmpClauseList &) {
             const auto &irClause{
                 std::get<parser::OmpClause::InReduction>(dataEnvClause->u)};
             checkVarAppearsInDataEnvClause(
-                std::get<parser::OmpObjectList>(irClause.v.t), "IN_REDUCTION");
+                *GetOmpObjectList(irClause), "IN_REDUCTION");
           }
         }
       }
@@ -3436,7 +3429,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Destroy &x) {
 
 void OmpStructureChecker::Enter(const parser::OmpClause::Reduction &x) {
   CheckAllowedClause(llvm::omp::Clause::OMPC_reduction);
-  auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
+  auto &objects{*GetOmpObjectList(x)};
 
   if (OmpVerifyModifiers(x.v, llvm::omp::OMPC_reduction,
           GetContext().clauseSource, context_)) {
@@ -3476,7 +3469,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Reduction &x) {
 
 void OmpStructureChecker::Enter(const parser::OmpClause::InReduction &x) {
   CheckAllowedClause(llvm::omp::Clause::OMPC_in_reduction);
-  auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
+  auto &objects{*GetOmpObjectList(x)};
 
   if (OmpVerifyModifiers(x.v, llvm::omp::OMPC_in_reduction,
           GetContext().clauseSource, context_)) {
@@ -3494,7 +3487,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::InReduction &x) {
 
 void OmpStructureChecker::Enter(const parser::OmpClause::TaskReduction &x) {
   CheckAllowedClause(llvm::omp::Clause::OMPC_task_reduction);
-  auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
+  auto &objects{*GetOmpObjectList(x)};
 
   if (OmpVerifyModifiers(x.v, llvm::omp::OMPC_task_reduction,
           GetContext().clauseSource, context_)) {
@@ -4347,8 +4340,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Map &x) {
     }};
 
     evaluate::ExpressionAnalyzer ea{context_};
-    const auto &objects{std::get<parser::OmpObjectList>(x.v.t)};
-    for (auto &object : objects.v) {
+    for (auto &object : GetOmpObjectList(x)->v) {
       if (const parser::Designator *d{GetDesignatorFromObj(object)}) {
         if (auto &&expr{ea.Analyze(*d)}) {
           if (hasBasePointer(*expr)) {
@@ -4501,7 +4493,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Depend &x) {
     }
   }
   if (taskDep) {
-    auto &objList{std::get<parser::OmpObjectList>(taskDep->t)};
+    auto &objList{*GetOmpObjectList(*taskDep)};
     if (dir == llvm::omp::OMPD_depobj) {
       // [5.0:255:13], [5.1:288:6], [5.2:322:26]
       // A depend clause on a depobj construct must only specify one locator.
@@ -4647,7 +4639,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Copyprivate &x) {
 void OmpStructureChecker::Enter(const parser::OmpClause::Lastprivate &x) {
   CheckAllowedClause(llvm::omp::Clause::OMPC_lastprivate);
 
-  const auto &objectList{std::get<parser::OmpObjectList>(x.v.t)};
+  const auto &objectList{*GetOmpObjectList(x)};
   CheckVarIsNotPartOfAnotherVar(
       GetContext().clauseSource, objectList, "LASTPRIVATE");
   CheckCrayPointee(objectList, "LASTPRIVATE");
@@ -4889,9 +4881,8 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Enter &x) {
           x.v, llvm::omp::OMPC_enter, GetContext().clauseSource, context_)) {
     return;
   }
-  const parser::OmpObjectList &objList{std::get<parser::OmpObjectList>(x.v.t)};
   SymbolSourceMap symbols;
-  GetSymbolsInObjectList(objList, symbols);
+  GetSymbolsInObjectList(*GetOmpObjectList(x), symbols);
   for (const auto &[symbol, source] : symbols) {
     if (!IsExtendedListItem(*symbol)) {
       context_.SayWithDecl(*symbol, source,
@@ -4914,7 +4905,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::From &x) {
     CheckIteratorModifier(*iter);
   }
 
-  const auto &objList{std::get<parser::OmpObjectList>(x.v.t)};
+  const auto &objList{*GetOmpObjectList(x)};
   SymbolSourceMap symbols;
   GetSymbolsInObjectList(objList, symbols);
   CheckVariableListItem(symbols);
@@ -4954,7 +4945,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::To &x) {
     CheckIteratorModifier(*iter);
   }
 
-  const auto &objList{std::get<parser::OmpObjectList>(x.v.t)};
+  const auto &objList{*GetOmpObjectList(x)};
   SymbolSourceMap symbols;
   GetSymbolsInObjectList(objList, symbols);
   CheckVariableListItem(symbols);
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index b992a4125ffcb..5d77eaa35fc08 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -717,8 +717,8 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
     return false;
   }
   bool Pre(const parser::OmpAllocateClause &x) {
-    const auto &objectList{std::get<parser::OmpObjectList>(x.t)};
-    ResolveOmpObjectList(objectList, Symbol::Flag::OmpAllocate);
+    ResolveOmpObjectList(
+        *parser::omp::GetOmpObjectList(x), Symbol::Flag::OmpAllocate);
     return false;
   }
   bool Pre(const parser::OmpClause::Firstprivate &x) {
@@ -726,8 +726,8 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
     return false;
   }
   bool Pre(const parser::OmpClause::Lastprivate &x) {
-    const auto &objList{std::get<parser::OmpObjectList>(x.v.t)};
-    ResolveOmpObjectList(objList, Symbol::Flag::OmpLastPrivate);
+    ResolveOmpObjectList(
+        *parser::omp::GetOmpObjectList(x), Symbol::Flag::OmpLastPrivate);
     return false;
   }
   bool Pre(const parser::OmpClause::Copyin &x) {
@@ -739,8 +739,8 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
     return false;
   }
   bool Pre(const parser::OmpLinearClause &x) {
-    auto &objects{std::get<parser::OmpObjectList>(x.t)};
-    ResolveOmpObjectList(objects, Symbol::Flag::OmpLinear);
+    ResolveOmpObjectList(
+        *parser::omp::GetOmpObjectList(x), Symbol::Flag::OmpLinear);
     return false;
   }
 
@@ -750,13 +750,13 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
   }
 
   bool Pre(const parser::OmpInReductionClause &x) {
-    auto &objects{std::get<parser::OmpObjectList>(x.t)};
-    ResolveOmpObjectList(objects, Symbol::Flag::OmpInReduction);
+    ResolveOmpObjectList(
+        *parser::omp::GetOmpObjectList(x), Symbol::Flag::OmpInReduction);
     return false;
   }
 
   bool Pre(const parser::OmpClause::Reduction &x) {
-    const auto &objList{std::get<parser::OmpObjectList>(x.v.t)};
+    const auto &objList{*parser::omp::GetOmpObjectList(x)};
     ResolveOmpObjectList(objList, Symbol::Flag::OmpReduction);
 
     if (auto &modifiers{OmpGetModifiers(x.v)}) {
@@ -806,8 +806,8 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
   }
 
   bool Pre(const parser::OmpAlignedClause &x) {
-    const auto &alignedNameList{std::get<parser::OmpObjectList>(x.t)};
-    ResolveOmpObjectList(alignedNameList, Symbol::Flag::OmpAligned);
+    ResolveOmpObjectList(
+        *parser::omp::GetOmpObjectList(x), Symbol::Flag::OmpAligned);
     return false;
   }
 
@@ -920,7 +920,7 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
       }
     }
 
-    const auto &ompObjList{std::get<parser::OmpObjectList>(x.t)};
+    const auto &ompObjList{*parser::omp::GetOmpObjectList(x)};
     for (const auto &ompObj : ompObjList.v) {
       common::visit(
           common::visitors{
@@ -2566,9 +2566,8 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPAllocatorsConstruct &x) {
   PushContext(x.source, dirSpec.DirId());
 
   for (const auto &clause : dirSpec.Clauses().v) {
-    if (const auto *allocClause{
-            std::get_if<parser::OmpClause::Allocate>(&clause.u)}) {
-      ResolveOmpObjectList(std::get<parser::OmpObjectList>(allocClause->v.t),
+    if (std::get_if<parser::OmpClause::Allocate>(&clause.u)) {
+      ResolveOmpObjectList(*parser::omp::GetOmpObjectList(clause),
           Symbol::Flag::OmpExecutableAllocateDirective);
     }
   }



More information about the flang-commits mailing list