[llvm-branch-commits] [flang] [flang][OpenMP] Convert repeatable clauses (except Map) in ClauseProcessor (PR #81623)

Krzysztof Parzyszek via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Feb 23 06:01:20 PST 2024


https://github.com/kparzysz updated https://github.com/llvm/llvm-project/pull/81623

>From 655dce519efb87f8d3babf3b7a5d6132bb82e2a6 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Wed, 21 Feb 2024 15:51:38 -0600
Subject: [PATCH] [flang][OpenMP] Convert repeatable clauses (except Map) in
 ClauseProcessor

Rename the existing `findRepeatableClause` to `findRepeatableClause2`,
and add a new `findRepeatableClause` that operates on the new
`omp::Clause` objects instead of parser clauses.

Leave `Map` (and the motion clauses handled by `processMotionClauses`)
on `findRepeatableClause2` for now, because converting them will require
further changes.
---
 flang/include/flang/Evaluate/tools.h          |  23 ++
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp    | 218 ++++++++----------
 flang/lib/Lower/OpenMP/ClauseProcessor.h      |  29 ++-
 flang/lib/Lower/OpenMP/Clauses.cpp            |   6 -
 flang/lib/Lower/OpenMP/Clauses.h              |   6 +
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 182 +++++++--------
 flang/lib/Lower/OpenMP/ReductionProcessor.cpp | 155 ++++++-------
 flang/lib/Lower/OpenMP/ReductionProcessor.h   |  23 +-
 flang/lib/Lower/OpenMP/Utils.cpp              |  41 ++--
 flang/lib/Lower/OpenMP/Utils.h                |  10 +-
 10 files changed, 348 insertions(+), 345 deletions(-)

diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index d257da1a709642..e9999974944e88 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -430,6 +430,29 @@ template <typename A> std::optional<CoarrayRef> ExtractCoarrayRef(const A &x) {
   }
 }
 
+struct ExtractSubstringHelper {
+  template <typename T> static std::optional<Substring> visit(T &&) {
+    return std::nullopt;
+  }
+
+  static std::optional<Substring> visit(const Substring &e) { return e; }
+
+  template <typename T>
+  static std::optional<Substring> visit(const Designator<T> &e) {
+    return std::visit([](auto &&s) { return visit(s); }, e.u);
+  }
+
+  template <typename T>
+  static std::optional<Substring> visit(const Expr<T> &e) {
+    return std::visit([](auto &&s) { return visit(s); }, e.u);
+  }
+};
+
+template <typename A>
+std::optional<Substring> ExtractSubstring(const A &x) {
+  return ExtractSubstringHelper::visit(x);
+}
+
 // If an expression is simply a whole symbol data designator,
 // extract and return that symbol, else null.
 template <typename A> const Symbol *UnwrapWholeSymbolDataRef(const A &x) {
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 9987cd73fc7670..6e45a939333d62 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -87,7 +87,7 @@ getSimdModifier(const omp::clause::Schedule &clause) {
 
 static void
 genAllocateClause(Fortran::lower::AbstractConverter &converter,
-                  const Fortran::parser::OmpAllocateClause &ompAllocateClause,
+                  const omp::clause::Allocate &clause,
                   llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
                   llvm::SmallVectorImpl<mlir::Value> &allocateOperands) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -95,21 +95,18 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter,
   Fortran::lower::StatementContext stmtCtx;
 
   mlir::Value allocatorOperand;
-  const Fortran::parser::OmpObjectList &ompObjectList =
-      std::get<Fortran::parser::OmpObjectList>(ompAllocateClause.t);
-  const auto &allocateModifier = std::get<
-      std::optional<Fortran::parser::OmpAllocateClause::AllocateModifier>>(
-      ompAllocateClause.t);
+  const omp::ObjectList &objectList = std::get<omp::ObjectList>(clause.t);
+  const auto &modifier =
+      std::get<std::optional<omp::clause::Allocate::Modifier>>(clause.t);
 
   // If the allocate modifier is present, check if we only use the allocator
   // submodifier.  ALIGN in this context is unimplemented
   const bool onlyAllocator =
-      allocateModifier &&
-      std::holds_alternative<
-          Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>(
-          allocateModifier->u);
+      modifier &&
+      std::holds_alternative<omp::clause::Allocate::Modifier::Allocator>(
+          modifier->u);
 
-  if (allocateModifier && !onlyAllocator) {
+  if (modifier && !onlyAllocator) {
     TODO(currentLocation, "OmpAllocateClause ALIGN modifier");
   }
 
@@ -117,20 +114,17 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter,
   // to list of allocators, otherwise, add default allocator to
   // list of allocators.
   if (onlyAllocator) {
-    const auto &allocatorValue = std::get<
-        Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>(
-        allocateModifier->u);
-    allocatorOperand = fir::getBase(converter.genExprValue(
-        *Fortran::semantics::GetExpr(allocatorValue.v), stmtCtx));
-    allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(),
-                             allocatorOperand);
+    const auto &value =
+        std::get<omp::clause::Allocate::Modifier::Allocator>(modifier->u);
+    mlir::Value operand =
+        fir::getBase(converter.genExprValue(value.v, stmtCtx));
+    allocatorOperands.append(objectList.size(), operand);
   } else {
-    allocatorOperand = firOpBuilder.createIntegerConstant(
+    mlir::Value operand = firOpBuilder.createIntegerConstant(
         currentLocation, firOpBuilder.getI32Type(), 1);
-    allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(),
-                             allocatorOperand);
+    allocatorOperands.append(objectList.size(), operand);
   }
-  genObjectList(ompObjectList, converter, allocateOperands);
+  genObjectList(objectList, converter, allocateOperands);
 }
 
 static mlir::omp::ClauseProcBindKindAttr
@@ -157,20 +151,17 @@ genProcBindKindAttr(fir::FirOpBuilder &firOpBuilder,
 
 static mlir::omp::ClauseTaskDependAttr
 genDependKindAttr(fir::FirOpBuilder &firOpBuilder,
-                  const Fortran::parser::OmpClause::Depend *dependClause) {
+                  const omp::clause::Depend &clause) {
   mlir::omp::ClauseTaskDepend pbKind;
-  switch (
-      std::get<Fortran::parser::OmpDependenceType>(
-          std::get<Fortran::parser::OmpDependClause::InOut>(dependClause->v.u)
-              .t)
-          .v) {
-  case Fortran::parser::OmpDependenceType::Type::In:
+  const auto &inOut = std::get<omp::clause::Depend::InOut>(clause.u);
+  switch (std::get<omp::clause::Depend::Type>(inOut.t)) {
+  case omp::clause::Depend::Type::In:
     pbKind = mlir::omp::ClauseTaskDepend::taskdependin;
     break;
-  case Fortran::parser::OmpDependenceType::Type::Out:
+  case omp::clause::Depend::Type::Out:
     pbKind = mlir::omp::ClauseTaskDepend::taskdependout;
     break;
-  case Fortran::parser::OmpDependenceType::Type::Inout:
+  case omp::clause::Depend::Type::Inout:
     pbKind = mlir::omp::ClauseTaskDepend::taskdependinout;
     break;
   default:
@@ -181,45 +172,41 @@ genDependKindAttr(fir::FirOpBuilder &firOpBuilder,
                                               pbKind);
 }
 
-static mlir::Value getIfClauseOperand(
-    Fortran::lower::AbstractConverter &converter,
-    const Fortran::parser::OmpClause::If *ifClause,
-    Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
-    mlir::Location clauseLocation) {
+static mlir::Value
+getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
+                   const omp::clause::If &clause,
+                   omp::clause::If::DirectiveNameModifier directiveName,
+                   mlir::Location clauseLocation) {
   // Only consider the clause if it's intended for the given directive.
-  auto &directive = std::get<
-      std::optional<Fortran::parser::OmpIfClause::DirectiveNameModifier>>(
-      ifClause->v.t);
+  auto &directive =
+      std::get<std::optional<omp::clause::If::DirectiveNameModifier>>(clause.t);
   if (directive && directive.value() != directiveName)
     return nullptr;
 
   Fortran::lower::StatementContext stmtCtx;
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
   mlir::Value ifVal = fir::getBase(
-      converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
+      converter.genExprValue(std::get<omp::SomeExpr>(clause.t), stmtCtx));
   return firOpBuilder.createConvert(clauseLocation, firOpBuilder.getI1Type(),
                                     ifVal);
 }
 
 static void
 addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
-                   const Fortran::parser::OmpObjectList &useDeviceClause,
+                   const omp::ObjectList &objects,
                    llvm::SmallVectorImpl<mlir::Value> &operands,
                    llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
                    llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
                    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
                        &useDeviceSymbols) {
-  genObjectList(useDeviceClause, converter, operands);
+  genObjectList(objects, converter, operands);
   for (mlir::Value &operand : operands) {
     checkMapType(operand.getLoc(), operand.getType());
     useDeviceTypes.push_back(operand.getType());
     useDeviceLocs.push_back(operand.getLoc());
   }
-  for (const Fortran::parser::OmpObject &ompObject : useDeviceClause.v) {
-    Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
-    useDeviceSymbols.push_back(sym);
-  }
+  for (const omp::Object &object : objects)
+    useDeviceSymbols.push_back(object.id());
 }
 
 //===----------------------------------------------------------------------===//
@@ -527,10 +514,10 @@ bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const {
 bool ClauseProcessor::processAllocate(
     llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
     llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const {
-  return findRepeatableClause<ClauseTy::Allocate>(
-      [&](const ClauseTy::Allocate *allocateClause,
+  return findRepeatableClause<omp::clause::Allocate>(
+      [&](const omp::clause::Allocate &clause,
           const Fortran::parser::CharBlock &) {
-        genAllocateClause(converter, allocateClause->v, allocatorOperands,
+        genAllocateClause(converter, clause, allocatorOperands,
                           allocateOperands);
       });
 }
@@ -547,12 +534,12 @@ bool ClauseProcessor::processCopyin() const {
         if (converter.isPresentShallowLookup(*sym))
           converter.copyHostAssociateVar(*sym, copyAssignIP);
       };
-  bool hasCopyin = findRepeatableClause<ClauseTy::Copyin>(
-      [&](const ClauseTy::Copyin *copyinClause,
+  bool hasCopyin = findRepeatableClause<omp::clause::Copyin>(
+      [&](const omp::clause::Copyin &clause,
           const Fortran::parser::CharBlock &) {
-        const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v;
-        for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) {
-          Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
+        for (const omp::Object &object : clause.v) {
+          Fortran::semantics::Symbol *sym = object.id();
+          assert(sym && "Expecting symbol");
           if (const auto *commonDetails =
                   sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
             for (const auto &mem : commonDetails->objects())
@@ -716,13 +703,11 @@ bool ClauseProcessor::processCopyPrivate(
     copyPrivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp));
   };
 
-  bool hasCopyPrivate = findRepeatableClause<ClauseTy::Copyprivate>(
-      [&](const ClauseTy::Copyprivate *copyPrivateClause,
+  bool hasCopyPrivate = findRepeatableClause<clause::Copyprivate>(
+      [&](const clause::Copyprivate &clause,
           const Fortran::parser::CharBlock &) {
-        const Fortran::parser::OmpObjectList &ompObjectList =
-            copyPrivateClause->v;
-        for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) {
-          Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
+        for (const Object &object : clause.v) {
+          Fortran::semantics::Symbol *sym = object.id();
           if (const auto *commonDetails =
                   sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
             for (const auto &mem : commonDetails->objects())
@@ -741,38 +726,30 @@ bool ClauseProcessor::processDepend(
     llvm::SmallVectorImpl<mlir::Value> &dependOperands) const {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
 
-  return findRepeatableClause<ClauseTy::Depend>(
-      [&](const ClauseTy::Depend *dependClause,
+  return findRepeatableClause<omp::clause::Depend>(
+      [&](const omp::clause::Depend &clause,
           const Fortran::parser::CharBlock &) {
-        const std::list<Fortran::parser::Designator> &depVal =
-            std::get<std::list<Fortran::parser::Designator>>(
-                std::get<Fortran::parser::OmpDependClause::InOut>(
-                    dependClause->v.u)
-                    .t);
+        assert(std::holds_alternative<omp::clause::Depend::InOut>(clause.u) &&
+               "Only InOut is handled at the moment");
+        const auto &inOut = std::get<omp::clause::Depend::InOut>(clause.u);
+        const auto &objects = std::get<omp::ObjectList>(inOut.t);
+
         mlir::omp::ClauseTaskDependAttr dependTypeOperand =
-            genDependKindAttr(firOpBuilder, dependClause);
-        dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(),
-                                  dependTypeOperand);
-        for (const Fortran::parser::Designator &ompObject : depVal) {
-          Fortran::semantics::Symbol *sym = nullptr;
-          std::visit(
-              Fortran::common::visitors{
-                  [&](const Fortran::parser::DataRef &designator) {
-                    if (const Fortran::parser::Name *name =
-                            std::get_if<Fortran::parser::Name>(&designator.u)) {
-                      sym = name->symbol;
-                    } else if (std::get_if<Fortran::common::Indirection<
-                                   Fortran::parser::ArrayElement>>(
-                                   &designator.u)) {
-                      TODO(converter.getCurrentLocation(),
-                           "array sections not supported for task depend");
-                    }
-                  },
-                  [&](const Fortran::parser::Substring &designator) {
-                    TODO(converter.getCurrentLocation(),
-                         "substring not supported for task depend");
-                  }},
-              (ompObject).u);
+            genDependKindAttr(firOpBuilder, clause);
+        dependTypeOperands.append(objects.size(), dependTypeOperand);
+
+        for (const omp::Object &object : objects) {
+          assert(object.ref() && "Expecting designator");
+
+          if (Fortran::evaluate::ExtractSubstring(*object.ref())) {
+            TODO(converter.getCurrentLocation(),
+                 "substring not supported for task depend");
+          } else if (Fortran::evaluate::IsArrayElement(*object.ref())) {
+            TODO(converter.getCurrentLocation(),
+                 "array sections not supported for task depend");
+          }
+
+          Fortran::semantics::Symbol *sym = object.id();
           const mlir::Value variable = converter.getSymbolAddress(*sym);
           dependOperands.push_back(variable);
         }
@@ -780,14 +757,14 @@ bool ClauseProcessor::processDepend(
 }
 
 bool ClauseProcessor::processIf(
-    Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
+    omp::clause::If::DirectiveNameModifier directiveName,
     mlir::Value &result) const {
   bool found = false;
-  findRepeatableClause<ClauseTy::If>(
-      [&](const ClauseTy::If *ifClause,
+  findRepeatableClause<omp::clause::If>(
+      [&](const omp::clause::If &clause,
           const Fortran::parser::CharBlock &source) {
         mlir::Location clauseLocation = converter.genLocation(source);
-        mlir::Value operand = getIfClauseOperand(converter, ifClause,
+        mlir::Value operand = getIfClauseOperand(converter, clause,
                                                  directiveName, clauseLocation);
         // Assume that, at most, a single 'if' clause will be applicable to the
         // given directive.
@@ -801,12 +778,11 @@ bool ClauseProcessor::processIf(
 
 bool ClauseProcessor::processLink(
     llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
-  return findRepeatableClause<ClauseTy::Link>(
-      [&](const ClauseTy::Link *linkClause,
-          const Fortran::parser::CharBlock &) {
+  return findRepeatableClause<omp::clause::Link>(
+      [&](const omp::clause::Link &clause, const Fortran::parser::CharBlock &) {
         // Case: declare target link(var1, var2)...
         gatherFuncAndVarSyms(
-            linkClause->v, mlir::omp::DeclareTargetCaptureClause::link, result);
+            clause.v, mlir::omp::DeclareTargetCaptureClause::link, result);
       });
 }
 
@@ -843,7 +819,7 @@ bool ClauseProcessor::processMap(
     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
     const {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  return findRepeatableClause<ClauseTy::Map>(
+  return findRepeatableClause2<ClauseTy::Map>(
       [&](const ClauseTy::Map *mapClause,
           const Fortran::parser::CharBlock &source) {
         mlir::Location clauseLocation = converter.genLocation(source);
@@ -935,43 +911,41 @@ bool ClauseProcessor::processReduction(
     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols)
     const {
-  return findRepeatableClause<ClauseTy::Reduction>(
-      [&](const ClauseTy::Reduction *reductionClause,
+  return findRepeatableClause<omp::clause::Reduction>(
+      [&](const omp::clause::Reduction &clause,
           const Fortran::parser::CharBlock &) {
         ReductionProcessor rp;
-        rp.addReductionDecl(currentLocation, converter, reductionClause->v,
-                            reductionVars, reductionDeclSymbols,
-                            reductionSymbols);
+        rp.addReductionDecl(currentLocation, converter, clause, reductionVars,
+                            reductionDeclSymbols, reductionSymbols);
       });
 }
 
 bool ClauseProcessor::processSectionsReduction(
     mlir::Location currentLocation) const {
-  return findRepeatableClause<ClauseTy::Reduction>(
-      [&](const ClauseTy::Reduction *, const Fortran::parser::CharBlock &) {
+  return findRepeatableClause<omp::clause::Reduction>(
+      [&](const omp::clause::Reduction &, const Fortran::parser::CharBlock &) {
         TODO(currentLocation, "OMPC_Reduction");
       });
 }
 
 bool ClauseProcessor::processTo(
     llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
-  return findRepeatableClause<ClauseTy::To>(
-      [&](const ClauseTy::To *toClause, const Fortran::parser::CharBlock &) {
+  return findRepeatableClause<omp::clause::To>(
+      [&](const omp::clause::To &clause, const Fortran::parser::CharBlock &) {
         // Case: declare target to(func, var1, var2)...
-        gatherFuncAndVarSyms(toClause->v,
+        gatherFuncAndVarSyms(clause.v,
                              mlir::omp::DeclareTargetCaptureClause::to, result);
       });
 }
 
 bool ClauseProcessor::processEnter(
     llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
-  return findRepeatableClause<ClauseTy::Enter>(
-      [&](const ClauseTy::Enter *enterClause,
+  return findRepeatableClause<omp::clause::Enter>(
+      [&](const omp::clause::Enter &clause,
           const Fortran::parser::CharBlock &) {
         // Case: declare target enter(func, var1, var2)...
-        gatherFuncAndVarSyms(enterClause->v,
-                             mlir::omp::DeclareTargetCaptureClause::enter,
-                             result);
+        gatherFuncAndVarSyms(
+            clause.v, mlir::omp::DeclareTargetCaptureClause::enter, result);
       });
 }
 
@@ -981,11 +955,11 @@ bool ClauseProcessor::processUseDeviceAddr(
     llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
     const {
-  return findRepeatableClause<ClauseTy::UseDeviceAddr>(
-      [&](const ClauseTy::UseDeviceAddr *devAddrClause,
+  return findRepeatableClause<omp::clause::UseDeviceAddr>(
+      [&](const omp::clause::UseDeviceAddr &clause,
           const Fortran::parser::CharBlock &) {
-        addUseDeviceClause(converter, devAddrClause->v, operands,
-                           useDeviceTypes, useDeviceLocs, useDeviceSymbols);
+        addUseDeviceClause(converter, clause.v, operands, useDeviceTypes,
+                           useDeviceLocs, useDeviceSymbols);
       });
 }
 
@@ -995,10 +969,10 @@ bool ClauseProcessor::processUseDevicePtr(
     llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
     const {
-  return findRepeatableClause<ClauseTy::UseDevicePtr>(
-      [&](const ClauseTy::UseDevicePtr *devPtrClause,
+  return findRepeatableClause<omp::clause::UseDevicePtr>(
+      [&](const omp::clause::UseDevicePtr &clause,
           const Fortran::parser::CharBlock &) {
-        addUseDeviceClause(converter, devPtrClause->v, operands, useDeviceTypes,
+        addUseDeviceClause(converter, clause.v, operands, useDeviceTypes,
                            useDeviceLocs, useDeviceSymbols);
       });
 }
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index c87fc30c88bb93..3f6adcce8ae877 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -105,9 +105,8 @@ class ClauseProcessor {
                      llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
   bool
   processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
-  bool
-  processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
-            mlir::Value &result) const;
+  bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
+                 mlir::Value &result) const;
   bool
   processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
 
@@ -178,6 +177,10 @@ class ClauseProcessor {
   /// if at least one instance was found.
   template <typename T>
   bool findRepeatableClause(
+      std::function<void(const T &, const Fortran::parser::CharBlock &source)>
+          callbackFn) const;
+  template <typename T>
+  bool findRepeatableClause2(
       std::function<void(const T *, const Fortran::parser::CharBlock &source)>
           callbackFn) const;
 
@@ -195,7 +198,7 @@ template <typename T>
 bool ClauseProcessor::processMotionClauses(
     Fortran::lower::StatementContext &stmtCtx,
     llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
-  return findRepeatableClause<T>(
+  return findRepeatableClause2<T>(
       [&](const T *motionClause, const Fortran::parser::CharBlock &source) {
         mlir::Location clauseLocation = converter.genLocation(source);
         fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -295,6 +298,24 @@ const T *ClauseProcessor::findUniqueClause(
 
 template <typename T>
 bool ClauseProcessor::findRepeatableClause(
+    std::function<void(const T &, const Fortran::parser::CharBlock &source)>
+        callbackFn) const {
+  bool found = false;
+  ClauseIterator nextIt, endIt = clauses.end();
+  for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) {
+    nextIt = findClause<T>(it, endIt);
+
+    if (nextIt != endIt) {
+      callbackFn(std::get<T>(nextIt->u), nextIt->source);
+      found = true;
+      ++nextIt;
+    }
+  }
+  return found;
+}
+
+template <typename T>
+bool ClauseProcessor::findRepeatableClause2(
     std::function<void(const T *, const Fortran::parser::CharBlock &source)>
         callbackFn) const {
   bool found = false;
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index 0b90b705b9e406..a3aa3d4de3cdc9 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -205,12 +205,6 @@ namespace clause {
 #undef EMPTY_CLASS
 #undef WRAPPER_CLASS
 
-using DefinedOperator = tomp::clause::DefinedOperatorT<SymIdent, SymReference>;
-using ProcedureDesignator =
-    tomp::clause::ProcedureDesignatorT<SymIdent, SymReference>;
-using ReductionOperator =
-    tomp::clause::ReductionOperatorT<SymIdent, SymReference>;
-
 DefinedOperator makeDefOp(const parser::DefinedOperator &inp,
                           semantics::SemanticsContext &semaCtx) {
   return DefinedOperator{
diff --git a/flang/lib/Lower/OpenMP/Clauses.h b/flang/lib/Lower/OpenMP/Clauses.h
index a7e563f4b0f90b..c167e34637d500 100644
--- a/flang/lib/Lower/OpenMP/Clauses.h
+++ b/flang/lib/Lower/OpenMP/Clauses.h
@@ -106,6 +106,12 @@ getBaseObject(const Object &object,
               Fortran::semantics::SemanticsContext &semaCtx);
 
 namespace clause {
+using DefinedOperator = tomp::clause::DefinedOperatorT<SymIdent, SymReference>;
+using ProcedureDesignator =
+    tomp::clause::ProcedureDesignatorT<SymIdent, SymReference>;
+using ReductionOperator =
+    tomp::clause::ReductionOperatorT<SymIdent, SymReference>;
+
 #ifdef EMPTY_CLASS
 #undef EMPTY_CLASS
 #endif
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 7953bf83cba0fe..7445c0f13526f7 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -572,8 +572,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
 
   ClauseProcessor cp(converter, semaCtx, clauseList);
-  cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
-               ifClauseOperand);
+  cp.processIf(clause::If::DirectiveNameModifier::Parallel, ifClauseOperand);
   cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
   cp.processProcBind(procBindKindAttr);
   cp.processDefault();
@@ -676,8 +675,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
       dependOperands;
 
   ClauseProcessor cp(converter, semaCtx, clauseList);
-  cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Task,
-               ifClauseOperand);
+  cp.processIf(clause::If::DirectiveNameModifier::Task, ifClauseOperand);
   cp.processAllocate(allocatorOperands, allocateOperands);
   cp.processDefault();
   cp.processFinal(stmtCtx, finalClauseOperand);
@@ -738,7 +736,7 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
 
   ClauseProcessor cp(converter, semaCtx, clauseList);
-  cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData,
+  cp.processIf(clause::If::DirectiveNameModifier::TargetData,
                ifClauseOperand);
   cp.processDevice(stmtCtx, deviceOperand);
   cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
@@ -770,19 +768,16 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
   llvm::SmallVector<mlir::Attribute> dependTypeOperands;
 
-  Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName;
+  clause::If::DirectiveNameModifier directiveName;
   llvm::omp::Directive directive;
   if constexpr (std::is_same_v<OpTy, mlir::omp::EnterDataOp>) {
-    directiveName =
-        Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData;
+    directiveName = clause::If::DirectiveNameModifier::TargetEnterData;
     directive = llvm::omp::Directive::OMPD_target_enter_data;
   } else if constexpr (std::is_same_v<OpTy, mlir::omp::ExitDataOp>) {
-    directiveName =
-        Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData;
+    directiveName = clause::If::DirectiveNameModifier::TargetExitData;
     directive = llvm::omp::Directive::OMPD_target_exit_data;
   } else if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) {
-    directiveName =
-        Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetUpdate;
+    directiveName = clause::If::DirectiveNameModifier::TargetUpdate;
     directive = llvm::omp::Directive::OMPD_target_update;
   } else {
     return nullptr;
@@ -984,8 +979,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
 
   ClauseProcessor cp(converter, semaCtx, clauseList);
-  cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
-               ifClauseOperand);
+  cp.processIf(clause::If::DirectiveNameModifier::Target, ifClauseOperand);
   cp.processDevice(stmtCtx, deviceOperand);
   cp.processThreadLimit(stmtCtx, threadLimitOperand);
   cp.processDepend(dependTypeOperands, dependOperands);
@@ -1102,8 +1096,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
 
   ClauseProcessor cp(converter, semaCtx, clauseList);
-  cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams,
-               ifClauseOperand);
+  cp.processIf(clause::If::DirectiveNameModifier::Teams, ifClauseOperand);
   cp.processAllocate(allocatorOperands, allocateOperands);
   cp.processDefault();
   cp.processNumTeams(stmtCtx, numTeamsClauseOperand);
@@ -1142,8 +1135,9 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
 
   if (const auto *objectList{
           Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
+    ObjectList objects{makeList(*objectList, semaCtx)};
     // Case: declare target(func, var1, var2)
-    gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to,
+    gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to,
                          symbolAndClause);
   } else if (const auto *clauseList{
                  Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
@@ -1257,7 +1251,7 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter,
   if (const auto &ompObjectList =
           std::get<std::optional<Fortran::parser::OmpObjectList>>(
               flushConstruct.t))
-    genObjectList(*ompObjectList, converter, operandRange);
+    genObjectList2(*ompObjectList, converter, operandRange);
   const auto &memOrderClause =
       std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>(
           flushConstruct.t);
@@ -1419,8 +1413,7 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
                      loopVarTypeSize);
   cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
   cp.processReduction(loc, reductionVars, reductionDeclSymbols);
-  cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd,
-               ifClauseOperand);
+  cp.processIf(clause::If::DirectiveNameModifier::Simd, ifClauseOperand);
   cp.processSimdlen(simdlenClauseOperand);
   cp.processSafelen(safelenClauseOperand);
   cp.processTODO<Fortran::parser::OmpClause::Aligned,
@@ -2223,106 +2216,99 @@ void Fortran::lower::genOpenMPReduction(
     const Fortran::parser::OmpClauseList &clauseList) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
 
-  for (const Fortran::parser::OmpClause &clause : clauseList.v) {
+  List<Clause> clauses{makeList(clauseList, semaCtx)};
+
+  for (const Clause &clause : clauses) {
     if (const auto &reductionClause =
-            std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) {
-      const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>(
-          reductionClause->v.t)};
-      const auto &objectList{
-          std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)};
+            std::get_if<clause::Reduction>(&clause.u)) {
+      const auto &redOperator{
+          std::get<clause::ReductionOperator>(reductionClause->t)};
+      const auto &objects{std::get<ObjectList>(reductionClause->t)};
       if (const auto *reductionOp =
-              std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
+              std::get_if<clause::DefinedOperator>(&redOperator.u)) {
         const auto &intrinsicOp{
-            std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
+            std::get<clause::DefinedOperator::IntrinsicOperator>(
                 reductionOp->u)};
 
         switch (intrinsicOp) {
-        case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
-        case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
-        case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
-        case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
-        case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
-        case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+        case clause::DefinedOperator::IntrinsicOperator::Add:
+        case clause::DefinedOperator::IntrinsicOperator::Multiply:
+        case clause::DefinedOperator::IntrinsicOperator::AND:
+        case clause::DefinedOperator::IntrinsicOperator::EQV:
+        case clause::DefinedOperator::IntrinsicOperator::OR:
+        case clause::DefinedOperator::IntrinsicOperator::NEQV:
           break;
         default:
           continue;
         }
-        for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
-          if (const auto *name{
-                  Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
-            if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
-              mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
-              if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
-                reductionVal = declOp.getBase();
-              mlir::Type reductionType =
-                  reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
-              if (!reductionType.isa<fir::LogicalType>()) {
-                if (!reductionType.isIntOrIndexOrFloat())
-                  continue;
-              }
-              for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
-                if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
-                        reductionValUse.getOwner())) {
-                  mlir::Value loadVal = loadOp.getRes();
-                  if (reductionType.isa<fir::LogicalType>()) {
-                    mlir::Operation *reductionOp = findReductionChain(loadVal);
-                    fir::ConvertOp convertOp =
-                        getConvertFromReductionOp(reductionOp, loadVal);
-                    updateReduction(reductionOp, firOpBuilder, loadVal,
-                                    reductionVal, &convertOp);
-                    removeStoreOp(reductionOp, reductionVal);
-                  } else if (mlir::Operation *reductionOp =
-                                 findReductionChain(loadVal, &reductionVal)) {
-                    updateReduction(reductionOp, firOpBuilder, loadVal,
-                                    reductionVal);
-                  }
+        for (const Object &object : objects) {
+          if (const Fortran::semantics::Symbol *symbol = object.id()) {
+            mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+            if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
+              reductionVal = declOp.getBase();
+            mlir::Type reductionType =
+                reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
+            if (!reductionType.isa<fir::LogicalType>()) {
+              if (!reductionType.isIntOrIndexOrFloat())
+                continue;
+            }
+            for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
+              if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
+                mlir::Value loadVal = loadOp.getRes();
+                if (reductionType.isa<fir::LogicalType>()) {
+                  mlir::Operation *reductionOp = findReductionChain(loadVal);
+                  fir::ConvertOp convertOp =
+                      getConvertFromReductionOp(reductionOp, loadVal);
+                  updateReduction(reductionOp, firOpBuilder, loadVal,
+                                  reductionVal, &convertOp);
+                  removeStoreOp(reductionOp, reductionVal);
+                } else if (mlir::Operation *reductionOp =
+                               findReductionChain(loadVal, &reductionVal)) {
+                  updateReduction(reductionOp, firOpBuilder, loadVal,
+                                  reductionVal);
                 }
               }
             }
           }
         }
       } else if (const auto *reductionIntrinsic =
-                     std::get_if<Fortran::parser::ProcedureDesignator>(
+                     std::get_if<clause::ProcedureDesignator>(
                          &redOperator.u)) {
         if (!ReductionProcessor::supportedIntrinsicProcReduction(
                 *reductionIntrinsic))
           continue;
         ReductionProcessor::ReductionIdentifier redId =
             ReductionProcessor::getReductionType(*reductionIntrinsic);
-        for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
-          if (const auto *name{
-                  Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
-            if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
-              mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
-              if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
-                reductionVal = declOp.getBase();
-              for (const mlir::OpOperand &reductionValUse :
-                   reductionVal.getUses()) {
-                if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
-                        reductionValUse.getOwner())) {
-                  mlir::Value loadVal = loadOp.getRes();
-                  // Max is lowered as a compare -> select.
-                  // Match the pattern here.
-                  mlir::Operation *reductionOp =
-                      findReductionChain(loadVal, &reductionVal);
-                  if (reductionOp == nullptr)
-                    continue;
-
-                  if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
-                      redId == ReductionProcessor::ReductionIdentifier::MIN) {
-                    assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
-                           "Selection Op not found in reduction intrinsic");
-                    mlir::Operation *compareOp =
-                        getCompareFromReductionOp(reductionOp, loadVal);
-                    updateReduction(compareOp, firOpBuilder, loadVal,
-                                    reductionVal);
-                  }
-                  if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
-                      redId == ReductionProcessor::ReductionIdentifier::IEOR ||
-                      redId == ReductionProcessor::ReductionIdentifier::IAND) {
-                    updateReduction(reductionOp, firOpBuilder, loadVal,
-                                    reductionVal);
-                  }
+        for (const Object &object : objects) {
+          if (const Fortran::semantics::Symbol *symbol = object.id()) {
+            mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+            if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
+              reductionVal = declOp.getBase();
+            for (const mlir::OpOperand &reductionValUse :
+                 reductionVal.getUses()) {
+              if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
+                mlir::Value loadVal = loadOp.getRes();
+                // Max is lowered as a compare -> select.
+                // Match the pattern here.
+                mlir::Operation *reductionOp =
+                    findReductionChain(loadVal, &reductionVal);
+                if (reductionOp == nullptr)
+                  continue;
+
+                if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
+                    redId == ReductionProcessor::ReductionIdentifier::MIN) {
+                  assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
+                         "Selection Op not found in reduction intrinsic");
+                  mlir::Operation *compareOp =
+                      getCompareFromReductionOp(reductionOp, loadVal);
+                  updateReduction(compareOp, firOpBuilder, loadVal,
+                                  reductionVal);
+                }
+                if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
+                    redId == ReductionProcessor::ReductionIdentifier::IEOR ||
+                    redId == ReductionProcessor::ReductionIdentifier::IAND) {
+                  updateReduction(reductionOp, firOpBuilder, loadVal,
+                                  reductionVal);
                 }
               }
             }
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index a8b98f3f567249..bf755b27487d95 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -23,9 +23,9 @@ namespace lower {
 namespace omp {
 
 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
-    const Fortran::parser::ProcedureDesignator &pd) {
+    const omp::clause::ProcedureDesignator &pd) {
   auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
-                     ReductionProcessor::getRealName(pd).ToString())
+                     getRealName(pd.v.id()).ToString())
                      .Case("max", ReductionIdentifier::MAX)
                      .Case("min", ReductionIdentifier::MIN)
                      .Case("iand", ReductionIdentifier::IAND)
@@ -37,21 +37,21 @@ ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
 }
 
 ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
-    Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) {
+    omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
   switch (intrinsicOp) {
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+  case omp::clause::DefinedOperator::IntrinsicOperator::Add:
     return ReductionIdentifier::ADD;
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract:
+  case omp::clause::DefinedOperator::IntrinsicOperator::Subtract:
     return ReductionIdentifier::SUBTRACT;
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+  case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
     return ReductionIdentifier::MULTIPLY;
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+  case omp::clause::DefinedOperator::IntrinsicOperator::AND:
     return ReductionIdentifier::AND;
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+  case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
     return ReductionIdentifier::EQV;
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+  case omp::clause::DefinedOperator::IntrinsicOperator::OR:
     return ReductionIdentifier::OR;
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+  case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
     return ReductionIdentifier::NEQV;
   default:
     llvm_unreachable("unexpected intrinsic operator in reduction");
@@ -59,13 +59,11 @@ ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
 }
 
 bool ReductionProcessor::supportedIntrinsicProcReduction(
-    const Fortran::parser::ProcedureDesignator &pd) {
-  const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
-  assert(name && "Invalid Reduction Intrinsic.");
-  if (!name->symbol->GetUltimate().attrs().test(
-          Fortran::semantics::Attr::INTRINSIC))
+    const omp::clause::ProcedureDesignator &pd) {
+  Fortran::semantics::Symbol *sym = pd.v.id();
+  if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC))
     return false;
-  auto redType = llvm::StringSwitch<bool>(getRealName(name).ToString())
+  auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
                      .Case("max", true)
                      .Case("min", true)
                      .Case("iand", true)
@@ -84,24 +82,24 @@ std::string ReductionProcessor::getReductionName(llvm::StringRef name,
 }
 
 std::string ReductionProcessor::getReductionName(
-    Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+    omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
     mlir::Type ty) {
   std::string reductionName;
 
   switch (intrinsicOp) {
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+  case omp::clause::DefinedOperator::IntrinsicOperator::Add:
     reductionName = "add_reduction";
     break;
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+  case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
     reductionName = "multiply_reduction";
     break;
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+  case omp::clause::DefinedOperator::IntrinsicOperator::AND:
     return "and_reduction";
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+  case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
     return "eqv_reduction";
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+  case omp::clause::DefinedOperator::IntrinsicOperator::OR:
     return "or_reduction";
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+  case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
     return "neqv_reduction";
   default:
     reductionName = "other_reduction";
@@ -305,7 +303,7 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
 void ReductionProcessor::addReductionDecl(
     mlir::Location currentLocation,
     Fortran::lower::AbstractConverter &converter,
-    const Fortran::parser::OmpReductionClause &reduction,
+    const omp::clause::Reduction &reduction,
     llvm::SmallVectorImpl<mlir::Value> &reductionVars,
     llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
     llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
@@ -313,12 +311,12 @@ void ReductionProcessor::addReductionDecl(
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   mlir::omp::ReductionDeclareOp decl;
   const auto &redOperator{
-      std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
-  const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};
+      std::get<omp::clause::ReductionOperator>(reduction.t)};
+  const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
   if (const auto &redDefinedOp =
-          std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
+          std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
     const auto &intrinsicOp{
-        std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
+        std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
             redDefinedOp->u)};
     ReductionIdentifier redId = getReductionType(intrinsicOp);
     switch (redId) {
@@ -334,10 +332,41 @@ void ReductionProcessor::addReductionDecl(
            "Reduction of some intrinsic operators is not supported");
       break;
     }
-    for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
-      if (const auto *name{
-              Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
-        if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+    for (const omp::Object &object : objectList) {
+      if (const Fortran::semantics::Symbol *symbol = object.id()) {
+        if (reductionSymbols)
+          reductionSymbols->push_back(symbol);
+        mlir::Value symVal = converter.getSymbolAddress(*symbol);
+        if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+          symVal = declOp.getBase();
+        mlir::Type redType =
+            symVal.getType().cast<fir::ReferenceType>().getEleTy();
+        reductionVars.push_back(symVal);
+        if (redType.isa<fir::LogicalType>())
+          decl = createReductionDecl(
+              firOpBuilder,
+              getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId,
+              redType, currentLocation);
+        else if (redType.isIntOrIndexOrFloat()) {
+          decl = createReductionDecl(firOpBuilder,
+                                     getReductionName(intrinsicOp, redType),
+                                     redId, redType, currentLocation);
+        } else {
+          TODO(currentLocation, "Reduction of some types is not supported");
+        }
+        reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+            firOpBuilder.getContext(), decl.getSymName()));
+      }
+    }
+  } else if (const auto *reductionIntrinsic =
+                 std::get_if<omp::clause::ProcedureDesignator>(
+                     &redOperator.u)) {
+    if (ReductionProcessor::supportedIntrinsicProcReduction(
+            *reductionIntrinsic)) {
+      ReductionProcessor::ReductionIdentifier redId =
+          ReductionProcessor::getReductionType(*reductionIntrinsic);
+      for (const omp::Object &object : objectList) {
+        if (const Fortran::semantics::Symbol *symbol = object.id()) {
           if (reductionSymbols)
             reductionSymbols->push_back(symbol);
           mlir::Value symVal = converter.getSymbolAddress(*symbol);
@@ -346,68 +375,28 @@ void ReductionProcessor::addReductionDecl(
           mlir::Type redType =
               symVal.getType().cast<fir::ReferenceType>().getEleTy();
           reductionVars.push_back(symVal);
-          if (redType.isa<fir::LogicalType>())
-            decl = createReductionDecl(
-                firOpBuilder,
-                getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId,
-                redType, currentLocation);
-          else if (redType.isIntOrIndexOrFloat()) {
-            decl = createReductionDecl(firOpBuilder,
-                                       getReductionName(intrinsicOp, redType),
-                                       redId, redType, currentLocation);
-          } else {
-            TODO(currentLocation, "Reduction of some types is not supported");
-          }
+          assert(redType.isIntOrIndexOrFloat() && "Unsupported reduction type");
+          decl = createReductionDecl(
+              firOpBuilder,
+              getReductionName(getRealName(*reductionIntrinsic).ToString(),
+                               redType),
+              redId, redType, currentLocation);
           reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
               firOpBuilder.getContext(), decl.getSymName()));
         }
       }
     }
-  } else if (const auto *reductionIntrinsic =
-                 std::get_if<Fortran::parser::ProcedureDesignator>(
-                     &redOperator.u)) {
-    if (ReductionProcessor::supportedIntrinsicProcReduction(
-            *reductionIntrinsic)) {
-      ReductionProcessor::ReductionIdentifier redId =
-          ReductionProcessor::getReductionType(*reductionIntrinsic);
-      for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
-        if (const auto *name{
-                Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
-          if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
-            if (reductionSymbols)
-              reductionSymbols->push_back(symbol);
-            mlir::Value symVal = converter.getSymbolAddress(*symbol);
-            if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
-              symVal = declOp.getBase();
-            mlir::Type redType =
-                symVal.getType().cast<fir::ReferenceType>().getEleTy();
-            reductionVars.push_back(symVal);
-            assert(redType.isIntOrIndexOrFloat() &&
-                   "Unsupported reduction type");
-            decl = createReductionDecl(
-                firOpBuilder,
-                getReductionName(getRealName(*reductionIntrinsic).ToString(),
-                                 redType),
-                redId, redType, currentLocation);
-            reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
-                firOpBuilder.getContext(), decl.getSymName()));
-          }
-        }
-      }
-    }
   }
 }
 
 const Fortran::semantics::SourceName
-ReductionProcessor::getRealName(const Fortran::parser::Name *name) {
-  return name->symbol->GetUltimate().name();
+ReductionProcessor::getRealName(const Fortran::semantics::Symbol *symbol) {
+  return symbol->GetUltimate().name();
 }
 
-const Fortran::semantics::SourceName ReductionProcessor::getRealName(
-    const Fortran::parser::ProcedureDesignator &pd) {
-  const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
-  assert(name && "Invalid Reduction Intrinsic.");
-  return getRealName(name);
+const Fortran::semantics::SourceName
+ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) {
+  return getRealName(pd.v.id());
 }
 
 int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index 00770fe81d1ef6..855e2aa4ad13cd 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -13,6 +13,7 @@
 #ifndef FORTRAN_LOWER_REDUCTIONPROCESSOR_H
 #define FORTRAN_LOWER_REDUCTIONPROCESSOR_H
 
+#include "Clauses.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Parser/parse-tree.h"
 #include "flang/Semantics/symbol.h"
@@ -57,25 +58,25 @@ class ReductionProcessor {
   };
 
   static ReductionIdentifier
-  getReductionType(const Fortran::parser::ProcedureDesignator &pd);
+  getReductionType(const omp::clause::ProcedureDesignator &pd);
 
-  static ReductionIdentifier getReductionType(
-      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp);
+  static ReductionIdentifier
+  getReductionType(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp);
 
-  static bool supportedIntrinsicProcReduction(
-      const Fortran::parser::ProcedureDesignator &pd);
+  static bool
+  supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd);
 
   static const Fortran::semantics::SourceName
-  getRealName(const Fortran::parser::Name *name);
+  getRealName(const Fortran::semantics::Symbol *symbol);
 
   static const Fortran::semantics::SourceName
-  getRealName(const Fortran::parser::ProcedureDesignator &pd);
+  getRealName(const omp::clause::ProcedureDesignator &pd);
 
   static std::string getReductionName(llvm::StringRef name, mlir::Type ty);
 
-  static std::string getReductionName(
-      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-      mlir::Type ty);
+  static std::string
+  getReductionName(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
+                   mlir::Type ty);
 
   /// This function returns the identity value of the operator \p
   /// reductionOpName. For example:
@@ -112,7 +113,7 @@ class ReductionProcessor {
   static void
   addReductionDecl(mlir::Location currentLocation,
                    Fortran::lower::AbstractConverter &converter,
-                   const Fortran::parser::OmpReductionClause &reduction,
+                   const omp::clause::Reduction &reduction,
                    llvm::SmallVectorImpl<mlir::Value> &reductionVars,
                    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
                    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 31b15257d18687..9a6a28ded7006d 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "Utils.h"
+#include "Clauses.h"
 
 #include <flang/Lower/AbstractConverter.h>
 #include <flang/Lower/ConvertType.h>
@@ -28,9 +29,27 @@ namespace Fortran {
 namespace lower {
 namespace omp {
 
-void genObjectList(const Fortran::parser::OmpObjectList &objectList,
+void genObjectList(const ObjectList &objects,
                    Fortran::lower::AbstractConverter &converter,
                    llvm::SmallVectorImpl<mlir::Value> &operands) {
+  for (const Object &object : objects) {
+    const Fortran::semantics::Symbol *sym = object.id();
+    assert(sym && "Expected Symbol");
+    if (mlir::Value variable = converter.getSymbolAddress(*sym)) {
+      operands.push_back(variable);
+    } else {
+      if (const auto *details =
+              sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
+        operands.push_back(converter.getSymbolAddress(details->symbol()));
+        converter.copySymbolBinding(details->symbol(), *sym);
+      }
+    }
+  }
+}
+
+void genObjectList2(const Fortran::parser::OmpObjectList &objectList,
+                    Fortran::lower::AbstractConverter &converter,
+                    llvm::SmallVectorImpl<mlir::Value> &operands) {
   auto addOperands = [&](Fortran::lower::SymbolRef sym) {
     const mlir::Value variable = converter.getSymbolAddress(sym);
     if (variable) {
@@ -50,24 +69,10 @@ void genObjectList(const Fortran::parser::OmpObjectList &objectList,
 }
 
 void gatherFuncAndVarSyms(
-    const Fortran::parser::OmpObjectList &objList,
-    mlir::omp::DeclareTargetCaptureClause clause,
+    const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
     llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
-  for (const Fortran::parser::OmpObject &ompObject : objList.v) {
-    Fortran::common::visit(
-        Fortran::common::visitors{
-            [&](const Fortran::parser::Designator &designator) {
-              if (const Fortran::parser::Name *name =
-                      Fortran::semantics::getDesignatorNameIfDataRef(
-                          designator)) {
-                symbolAndClause.emplace_back(clause, *name->symbol);
-              }
-            },
-            [&](const Fortran::parser::Name &name) {
-              symbolAndClause.emplace_back(clause, *name.symbol);
-            }},
-        ompObject.u);
-  }
+  for (const Object &object : objects)
+    symbolAndClause.emplace_back(clause, *object.id());
 }
 
 Fortran::semantics::Symbol *
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index c346f891f0797e..4ab4bc9c137071 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -9,6 +9,7 @@
 #ifndef FORTRAN_LOWER_OPENMPUTILS_H
 #define FORTRAN_LOWER_OPENMPUTILS_H
 
+#include "Clauses.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Value.h"
@@ -50,17 +51,20 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
                 bool isVal = false);
 
 void gatherFuncAndVarSyms(
-    const Fortran::parser::OmpObjectList &objList,
-    mlir::omp::DeclareTargetCaptureClause clause,
+    const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
     llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause);
 
 Fortran::semantics::Symbol *
 getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject);
 
-void genObjectList(const Fortran::parser::OmpObjectList &objectList,
+void genObjectList(const ObjectList &objects,
                    Fortran::lower::AbstractConverter &converter,
                    llvm::SmallVectorImpl<mlir::Value> &operands);
 
+void genObjectList2(const Fortran::parser::OmpObjectList &objectList,
+                    Fortran::lower::AbstractConverter &converter,
+                    llvm::SmallVectorImpl<mlir::Value> &operands);
+
 } // namespace omp
 } // namespace lower
 } // namespace Fortran



More information about the llvm-branch-commits mailing list