[flang-commits] [flang] 8c177ae - [Flang][OpenMP][Lower] Refactor MLIR codegen for OpenMP constructs

Sergio Afonso via flang-commits flang-commits at lists.llvm.org
Tue Aug 15 05:17:51 PDT 2023


Author: Sergio Afonso
Date: 2023-08-15T13:17:35+01:00
New Revision: 8c177ae9ddba6aefc52289e5d14089e7c838ab8b

URL: https://github.com/llvm/llvm-project/commit/8c177ae9ddba6aefc52289e5d14089e7c838ab8b
DIFF: https://github.com/llvm/llvm-project/commit/8c177ae9ddba6aefc52289e5d14089e7c838ab8b.diff

LOG: [Flang][OpenMP][Lower] Refactor MLIR codegen for OpenMP constructs

This patch extracts MLIR codegen logic from various types of OpenMP constructs
and places it into operation-specific functions. This refactoring mainly
targets block constructs and unifies the logic for directives that can appear
either on their own or combined with others.

Clauses that do not apply to the directive being lowered are no longer
processed, and code duplication for combined constructs is reduced.

Depends on D156455.

Differential Revision: https://reviews.llvm.org/D156809

Added: 
    

Modified: 
    flang/lib/Lower/OpenMP.cpp
    flang/test/Lower/OpenMP/parallel-sections.f90

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index a3323e8105b6fd..79d54e232b777d 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -2071,14 +2071,13 @@ static void createBodyOfOp(
   }
 }
 
-static void
-createBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
-                     mlir::omp::DataOp &dataOp,
-                     const llvm::SmallVector<mlir::Type> &useDeviceTypes,
-                     const llvm::SmallVector<mlir::Location> &useDeviceLocs,
-                     const llvm::SmallVector<const Fortran::semantics::Symbol *>
-                         &useDeviceSymbols,
-                     const mlir::Location &currentLocation) {
+static void createBodyOfTargetDataOp(
+    Fortran::lower::AbstractConverter &converter, mlir::omp::DataOp &dataOp,
+    const llvm::SmallVector<mlir::Type> &useDeviceTypes,
+    const llvm::SmallVector<mlir::Location> &useDeviceLocs,
+    const llvm::SmallVector<const Fortran::semantics::Symbol *>
+        &useDeviceSymbols,
+    const mlir::Location &currentLocation) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   mlir::Region &region = dataOp.getRegion();
 
@@ -2115,15 +2114,148 @@ createBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
   }
 }
 
-static void createTargetOp(Fortran::lower::AbstractConverter &converter,
-                           const Fortran::parser::OmpClauseList &opClauseList,
-                           const llvm::omp::Directive &directive,
-                           mlir::Location currentLocation,
-                           Fortran::lower::pft::Evaluation *eval = nullptr) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+template <typename OpTy, typename... Args>
+static OpTy genOpWithBody(Fortran::lower::AbstractConverter &converter,
+                          Fortran::lower::pft::Evaluation &eval,
+                          mlir::Location currentLocation, bool outerCombined,
+                          const Fortran::parser::OmpClauseList *clauseList,
+                          Args &&...args) {
+  auto op = converter.getFirOpBuilder().create<OpTy>(
+      currentLocation, std::forward<Args>(args)...);
+  createBodyOfOp<OpTy>(op, converter, currentLocation, eval, clauseList,
+                       /*args=*/{}, outerCombined);
+  return op;
+}
+
+static mlir::omp::MasterOp
+genMasterOp(Fortran::lower::AbstractConverter &converter,
+            Fortran::lower::pft::Evaluation &eval,
+            mlir::Location currentLocation) {
+  return genOpWithBody<mlir::omp::MasterOp>(converter, eval, currentLocation,
+                                            /*outerCombined=*/false,
+                                            /*clauseList=*/nullptr,
+                                            /*resultTypes=*/mlir::TypeRange());
+}
+
+static mlir::omp::OrderedRegionOp
+genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
+                   Fortran::lower::pft::Evaluation &eval,
+                   mlir::Location currentLocation) {
+  return genOpWithBody<mlir::omp::OrderedRegionOp>(
+      converter, eval, currentLocation, /*outerCombined=*/false,
+      /*clauseList=*/nullptr, /*simd=*/false);
+}
+
+static mlir::omp::ParallelOp
+genParallelOp(Fortran::lower::AbstractConverter &converter,
+              Fortran::lower::pft::Evaluation &eval,
+              mlir::Location currentLocation,
+              const Fortran::parser::OmpClauseList &clauseList,
+              bool outerCombined = false) {
   Fortran::lower::StatementContext stmtCtx;
-  mlir::Value ifClauseOperand, deviceOperand, threadLmtOperand;
+  mlir::Value ifClauseOperand, numThreadsClauseOperand;
+  mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
+  llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
+      reductionVars;
+  llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
+
+  ClauseProcessor cp(converter, clauseList);
+  cp.processIf(stmtCtx,
+               Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
+               ifClauseOperand);
+  cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
+  cp.processProcBind(procBindKindAttr);
+  cp.processDefault();
+  cp.processAllocate(allocatorOperands, allocateOperands);
+  cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
+
+  return genOpWithBody<mlir::omp::ParallelOp>(
+      converter, eval, currentLocation, outerCombined, &clauseList,
+      /*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
+      numThreadsClauseOperand, allocateOperands, allocatorOperands,
+      reductionVars,
+      reductionDeclSymbols.empty()
+          ? nullptr
+          : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
+                                 reductionDeclSymbols),
+      procBindKindAttr);
+}
+
+static mlir::omp::SingleOp
+genSingleOp(Fortran::lower::AbstractConverter &converter,
+            Fortran::lower::pft::Evaluation &eval,
+            mlir::Location currentLocation,
+            const Fortran::parser::OmpClauseList &beginClauseList,
+            const Fortran::parser::OmpClauseList &endClauseList) {
+  llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
   mlir::UnitAttr nowaitAttr;
+
+  ClauseProcessor(converter, beginClauseList)
+      .processAllocate(allocatorOperands, allocateOperands);
+  ClauseProcessor(converter, endClauseList).processNowait(nowaitAttr);
+
+  return genOpWithBody<mlir::omp::SingleOp>(
+      converter, eval, currentLocation, /*outerCombined=*/false,
+      &beginClauseList, allocateOperands, allocatorOperands, nowaitAttr);
+}
+
+static mlir::omp::TaskOp
+genTaskOp(Fortran::lower::AbstractConverter &converter,
+          Fortran::lower::pft::Evaluation &eval, mlir::Location currentLocation,
+          const Fortran::parser::OmpClauseList &clauseList) {
+  Fortran::lower::StatementContext stmtCtx;
+  mlir::Value ifClauseOperand, finalClauseOperand, priorityClauseOperand;
+  mlir::UnitAttr untiedAttr, mergeableAttr;
+  llvm::SmallVector<mlir::Attribute> dependTypeOperands;
+  llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
+      dependOperands;
+
+  ClauseProcessor cp(converter, clauseList);
+  cp.processIf(stmtCtx,
+               Fortran::parser::OmpIfClause::DirectiveNameModifier::Task,
+               ifClauseOperand);
+  cp.processAllocate(allocatorOperands, allocateOperands);
+  cp.processDefault();
+  cp.processFinal(stmtCtx, finalClauseOperand);
+  cp.processUntied(untiedAttr);
+  cp.processMergeable(mergeableAttr);
+  cp.processPriority(stmtCtx, priorityClauseOperand);
+  cp.processDepend(dependTypeOperands, dependOperands);
+
+  return genOpWithBody<mlir::omp::TaskOp>(
+      converter, eval, currentLocation, /*outerCombined=*/false, &clauseList,
+      ifClauseOperand, finalClauseOperand, untiedAttr, mergeableAttr,
+      /*in_reduction_vars=*/mlir::ValueRange(),
+      /*in_reductions=*/nullptr, priorityClauseOperand,
+      dependTypeOperands.empty()
+          ? nullptr
+          : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
+                                 dependTypeOperands),
+      dependOperands, allocateOperands, allocatorOperands);
+}
+
+static mlir::omp::TaskGroupOp
+genTaskGroupOp(Fortran::lower::AbstractConverter &converter,
+               Fortran::lower::pft::Evaluation &eval,
+               mlir::Location currentLocation,
+               const Fortran::parser::OmpClauseList &clauseList) {
+  llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
+  // TODO: Add task_reduction support
+  ClauseProcessor(converter, clauseList)
+      .processAllocate(allocatorOperands, allocateOperands);
+  return genOpWithBody<mlir::omp::TaskGroupOp>(
+      converter, eval, currentLocation, /*outerCombined=*/false, &clauseList,
+      /*task_reduction_vars=*/mlir::ValueRange(),
+      /*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
+}
+
+static mlir::omp::DataOp
+genDataOp(Fortran::lower::AbstractConverter &converter,
+          mlir::Location currentLocation,
+          const Fortran::parser::OmpClauseList &clauseList) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  Fortran::lower::StatementContext stmtCtx;
+  mlir::Value ifClauseOperand, deviceOperand;
   llvm::SmallVector<mlir::Value> mapOperands, devicePtrOperands,
       deviceAddrOperands;
   llvm::SmallVector<mlir::IntegerAttr> mapTypes;
@@ -2131,79 +2263,106 @@ static void createTargetOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::Location> useDeviceLocs;
   llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
 
+  ClauseProcessor cp(converter, clauseList);
+  cp.processIf(stmtCtx,
+               Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData,
+               ifClauseOperand);
+  cp.processDevice(stmtCtx, deviceOperand);
+  cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
+                         useDeviceSymbols);
+  cp.processUseDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs,
+                          useDeviceSymbols);
+  cp.processMap(mapOperands, mapTypes);
+
+  llvm::SmallVector<mlir::Attribute> mapTypesAttr(mapTypes.begin(),
+                                                  mapTypes.end());
+  mlir::ArrayAttr mapTypesArrayAttr =
+      mlir::ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr);
+
+  auto dataOp = converter.getFirOpBuilder().create<mlir::omp::DataOp>(
+      currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
+      deviceAddrOperands, mapOperands, mapTypesArrayAttr);
+  createBodyOfTargetDataOp(converter, dataOp, useDeviceTypes, useDeviceLocs,
+                           useDeviceSymbols, currentLocation);
+  return dataOp;
+}
+
+template <typename OpTy>
+static OpTy
+genEnterExitDataOp(Fortran::lower::AbstractConverter &converter,
+                   mlir::Location currentLocation,
+                   const Fortran::parser::OmpClauseList &clauseList) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  Fortran::lower::StatementContext stmtCtx;
+  mlir::Value ifClauseOperand, deviceOperand;
+  mlir::UnitAttr nowaitAttr;
+  llvm::SmallVector<mlir::Value> mapOperands;
+  llvm::SmallVector<mlir::IntegerAttr> mapTypes;
+
   Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName;
-  switch (directive) {
-  case llvm::omp::Directive::OMPD_target:
-    directiveName = Fortran::parser::OmpIfClause::DirectiveNameModifier::Target;
-    break;
-  case llvm::omp::Directive::OMPD_target_data:
-    directiveName =
-        Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData;
-    break;
-  case llvm::omp::Directive::OMPD_target_enter_data:
+  if constexpr (std::is_same_v<OpTy, mlir::omp::EnterDataOp>) {
     directiveName =
         Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData;
-    break;
-  case llvm::omp::Directive::OMPD_target_exit_data:
+  } else if constexpr (std::is_same_v<OpTy, mlir::omp::ExitDataOp>) {
     directiveName =
         Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData;
-    break;
-  default:
-    TODO(currentLocation, "OMPD_target directive unknown");
-    break;
+  } else {
+    return nullptr;
   }
 
-  ClauseProcessor cp(converter, opClauseList);
+  ClauseProcessor cp(converter, clauseList);
   cp.processIf(stmtCtx, directiveName, ifClauseOperand);
   cp.processDevice(stmtCtx, deviceOperand);
-  cp.processThreadLimit(stmtCtx, threadLmtOperand);
   cp.processNowait(nowaitAttr);
-  cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
-                         useDeviceSymbols);
-  cp.processUseDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs,
-                          useDeviceSymbols);
   cp.processMap(mapOperands, mapTypes);
 
-  for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
-    if (!std::get_if<Fortran::parser::OmpClause::If>(&clause.u) &&
-        !std::get_if<Fortran::parser::OmpClause::Device>(&clause.u) &&
-        !std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) &&
-        !std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u) &&
-        !std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) &&
-        !std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) &&
-        !std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
-      mlir::Location clauseLocation = converter.genLocation(clause.source);
-      TODO(clauseLocation, "OMPD_target unhandled clause");
-    }
-  }
+  llvm::SmallVector<mlir::Attribute> mapTypesAttr(mapTypes.begin(),
+                                                  mapTypes.end());
+  mlir::ArrayAttr mapTypesArrayAttr =
+      mlir::ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr);
+
+  return firOpBuilder.create<OpTy>(currentLocation, ifClauseOperand,
+                                   deviceOperand, nowaitAttr, mapOperands,
+                                   mapTypesArrayAttr);
+}
+
+static mlir::omp::TargetOp
+genTargetOp(Fortran::lower::AbstractConverter &converter,
+            Fortran::lower::pft::Evaluation &eval,
+            mlir::Location currentLocation,
+            const Fortran::parser::OmpClauseList &clauseList,
+            bool outerCombined = false) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  Fortran::lower::StatementContext stmtCtx;
+  mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand;
+  mlir::UnitAttr nowaitAttr;
+  llvm::SmallVector<mlir::Value> mapOperands;
+  llvm::SmallVector<mlir::IntegerAttr> mapTypes;
+
+  ClauseProcessor cp(converter, clauseList);
+  cp.processIf(stmtCtx,
+               Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
+               ifClauseOperand);
+  cp.processDevice(stmtCtx, deviceOperand);
+  cp.processThreadLimit(stmtCtx, threadLimitOperand);
+  cp.processNowait(nowaitAttr);
+  cp.processMap(mapOperands, mapTypes);
 
   llvm::SmallVector<mlir::Attribute> mapTypesAttr(mapTypes.begin(),
                                                   mapTypes.end());
   mlir::ArrayAttr mapTypesArrayAttr =
       mlir::ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr);
 
-  if (directive == llvm::omp::Directive::OMPD_target) {
-    auto targetOp = firOpBuilder.create<mlir::omp::TargetOp>(
-        currentLocation, ifClauseOperand, deviceOperand, threadLmtOperand,
-        nowaitAttr, mapOperands, mapTypesArrayAttr);
-    createBodyOfOp(targetOp, converter, currentLocation, *eval, &opClauseList);
-  } else if (directive == llvm::omp::Directive::OMPD_target_data) {
-    auto dataOp = firOpBuilder.create<mlir::omp::DataOp>(
-        currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
-        deviceAddrOperands, mapOperands, mapTypesArrayAttr);
-    createBodyOfTargetOp(converter, dataOp, useDeviceTypes, useDeviceLocs,
-                         useDeviceSymbols, currentLocation);
-  } else if (directive == llvm::omp::Directive::OMPD_target_enter_data) {
-    firOpBuilder.create<mlir::omp::EnterDataOp>(
-        currentLocation, ifClauseOperand, deviceOperand, nowaitAttr,
-        mapOperands, mapTypesArrayAttr);
-  } else if (directive == llvm::omp::Directive::OMPD_target_exit_data) {
-    firOpBuilder.create<mlir::omp::ExitDataOp>(currentLocation, ifClauseOperand,
-                                               deviceOperand, nowaitAttr,
-                                               mapOperands, mapTypesArrayAttr);
-  }
+  return genOpWithBody<mlir::omp::TargetOp>(
+      converter, eval, currentLocation, outerCombined, &clauseList,
+      ifClauseOperand, deviceOperand, threadLimitOperand, nowaitAttr,
+      mapOperands, mapTypesArrayAttr);
 }
 
+//===----------------------------------------------------------------------===//
+// genOMP() Code generation helper functions
+//===----------------------------------------------------------------------===//
+
 static void genOMP(Fortran::lower::AbstractConverter &converter,
                    Fortran::lower::pft::Evaluation &eval,
                    const Fortran::parser::OpenMPSimpleStandaloneConstruct
@@ -2229,9 +2388,15 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
     firOpBuilder.create<mlir::omp::TaskyieldOp>(currentLocation);
     break;
   case llvm::omp::Directive::OMPD_target_data:
+    genDataOp(converter, currentLocation, opClauseList);
+    break;
   case llvm::omp::Directive::OMPD_target_enter_data:
+    genEnterExitDataOp<mlir::omp::EnterDataOp>(converter, currentLocation,
+                                               opClauseList);
+    break;
   case llvm::omp::Directive::OMPD_target_exit_data:
-    createTargetOp(converter, opClauseList, directive.v, currentLocation);
+    genEnterExitDataOp<mlir::omp::ExitDataOp>(converter, currentLocation,
+                                              opClauseList);
     break;
   case llvm::omp::Directive::OMPD_target_update:
     TODO(currentLocation, "OMPD_target_update");
@@ -2276,89 +2441,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
       standaloneConstruct.u);
 }
 
-/* When parallel is used in a combined construct, then use this function to
- * create the parallel operation. It handles the parallel specific clauses
- * and leaves the rest for handling at the inner operations.
- */
-template <typename Directive>
-static void
-createCombinedParallelOp(Fortran::lower::AbstractConverter &converter,
-                         Fortran::lower::pft::Evaluation &eval,
-                         const Directive &directive) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  mlir::Location currentLocation = converter.getCurrentLocation();
-  Fortran::lower::StatementContext stmtCtx;
-  llvm::ArrayRef<mlir::Type> argTy;
-  mlir::Value ifClauseOperand, numThreadsClauseOperand;
-  llvm::SmallVector<mlir::Value> allocatorOperands, allocateOperands;
-  mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
-  const auto &opClauseList =
-      std::get<Fortran::parser::OmpClauseList>(directive.t);
-  // TODO: Handle the following clauses
-  // 1. default
-  // Note: rest of the clauses are handled when the inner operation is created
-  ClauseProcessor cp(converter, opClauseList);
-  cp.processIf(stmtCtx,
-               Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
-               ifClauseOperand);
-  cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
-  cp.processProcBind(procBindKindAttr);
-
-  // Create and insert the operation.
-  auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
-      currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
-      allocateOperands, allocatorOperands,
-      /*reduction_vars=*/mlir::ValueRange(),
-      /*reductions=*/nullptr, procBindKindAttr);
-
-  createBodyOfOp<mlir::omp::ParallelOp>(parallelOp, converter, currentLocation,
-                                        eval, &opClauseList, /*iv=*/{},
-                                        /*isCombined=*/true);
-}
-
-/* When target is used in a combined construct, then use this function to
- * create the target operation. It handles the target specific clauses
- * and leaves the rest for handling at the inner operations.
- */
-template <typename Directive>
-static void createCombinedTargetOp(Fortran::lower::AbstractConverter &converter,
-                                   Fortran::lower::pft::Evaluation &eval,
-                                   const Directive &directive) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  mlir::Location currentLocation = converter.getCurrentLocation();
-  Fortran::lower::StatementContext stmtCtx;
-  mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand;
-  llvm::SmallVector<mlir::Value> mapOperands;
-  llvm::SmallVector<mlir::IntegerAttr> mapTypes;
-  mlir::UnitAttr nowaitAttr;
-  const auto &opClauseList =
-      std::get<Fortran::parser::OmpClauseList>(directive.t);
-
-  // Note: rest of the clauses are handled when the inner operation is created
-  ClauseProcessor cp(converter, opClauseList);
-  cp.processIf(stmtCtx,
-               Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
-               ifClauseOperand);
-  cp.processDevice(stmtCtx, deviceOperand);
-  cp.processThreadLimit(stmtCtx, threadLimitOperand);
-  cp.processNowait(nowaitAttr);
-  cp.processMap(mapOperands, mapTypes);
-
-  llvm::SmallVector<mlir::Attribute> mapTypesAttr(mapTypes.begin(),
-                                                  mapTypes.end());
-  mlir::ArrayAttr mapTypesArrayAttr =
-      mlir::ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr);
-
-  // Create and insert the operation.
-  auto targetOp = firOpBuilder.create<mlir::omp::TargetOp>(
-      currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
-      nowaitAttr, mapOperands, mapTypesArrayAttr);
-
-  createBodyOfOp<mlir::omp::TargetOp>(targetOp, converter, currentLocation,
-                                      eval, &opClauseList,
-                                      /*iv=*/{}, /*isCombined=*/true);
-}
-
 static void genOMP(Fortran::lower::AbstractConverter &converter,
                    Fortran::lower::pft::Evaluation &eval,
                    const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
@@ -2395,8 +2477,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
     if ((llvm::omp::allTargetSet & llvm::omp::loopConstructSet)
             .test(ompDirective)) {
       validDirective = true;
-      createCombinedTargetOp<Fortran::parser::OmpBeginLoopDirective>(
-          converter, eval, beginLoopDirective);
+      genTargetOp(converter, eval, currentLocation, loopOpClauseList,
+                  /*outerCombined=*/true);
     }
     if ((llvm::omp::allTeamsSet & llvm::omp::loopConstructSet)
             .test(ompDirective)) {
@@ -2410,8 +2492,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
     if ((llvm::omp::allParallelSet & llvm::omp::loopConstructSet)
             .test(ompDirective)) {
       validDirective = true;
-      createCombinedParallelOp<Fortran::parser::OmpBeginLoopDirective>(
-          converter, eval, beginLoopDirective);
+      genParallelOp(converter, eval, currentLocation, loopOpClauseList,
+                    /*outerCombined=*/true);
     }
   }
   if ((llvm::omp::allDoSet | llvm::omp::allSimdSet).test(ompDirective))
@@ -2516,67 +2598,16 @@ genOMP(Fortran::lower::AbstractConverter &converter,
        const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
   const auto &beginBlockDirective =
       std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
-  const auto &blockDirective =
-      std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
   const auto &endBlockDirective =
       std::get<Fortran::parser::OmpEndBlockDirective>(blockConstruct.t);
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  mlir::Location currentLocation = converter.genLocation(blockDirective.source);
-
-  Fortran::lower::StatementContext stmtCtx;
-  llvm::ArrayRef<mlir::Type> argTy;
-  mlir::Value ifClauseOperand, numThreadsClauseOperand, finalClauseOperand,
-      priorityClauseOperand;
-  mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
-  llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
-      dependOperands, reductionVars;
-  llvm::SmallVector<mlir::Attribute> dependTypeOperands, reductionDeclSymbols;
-  mlir::UnitAttr nowaitAttr, untiedAttr, mergeableAttr;
-
-  // Use placeholder value to avoid uninitialized `directiveName` compiler
-  // errors. The 'if clause' obtained won't be used for these directives.
-  Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName =
-      Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel;
-  switch (blockDirective.v) {
-  case llvm::omp::OMPD_parallel:
-    directiveName =
-        Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel;
-    break;
-  case llvm::omp::OMPD_task:
-    directiveName = Fortran::parser::OmpIfClause::DirectiveNameModifier::Task;
-    break;
-  // Target-related 'if' clauses handled by createTargetOp().
-  case llvm::omp::OMPD_target:
-  case llvm::omp::OMPD_target_data:
-  // These block directives do not accept an 'if' clause.
-  case llvm::omp::OMPD_master:
-  case llvm::omp::OMPD_single:
-  case llvm::omp::OMPD_ordered:
-  case llvm::omp::OMPD_taskgroup:
-    break;
-  default:
-    TODO(currentLocation,
-         "Unhandled block directive (" +
-             llvm::omp::getOpenMPDirectiveName(blockDirective.v) + ")");
-    break;
-  }
-
-  const auto &opClauseList =
+  const auto &directive =
+      std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
+  const auto &beginClauseList =
       std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
-  ClauseProcessor cp(converter, opClauseList);
-  cp.processIf(stmtCtx, directiveName, ifClauseOperand);
-  cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
-  cp.processProcBind(procBindKindAttr);
-  cp.processAllocate(allocatorOperands, allocateOperands);
-  cp.processDefault();
-  cp.processFinal(stmtCtx, finalClauseOperand);
-  cp.processUntied(untiedAttr);
-  cp.processMergeable(mergeableAttr);
-  cp.processPriority(stmtCtx, priorityClauseOperand);
-  cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
-  cp.processDepend(dependTypeOperands, dependOperands);
+  const auto &endClauseList =
+      std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t);
 
-  for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
+  for (const Fortran::parser::OmpClause &clause : beginClauseList.v) {
     mlir::Location clauseLocation = converter.genLocation(clause.source);
     if (!std::get_if<Fortran::parser::OmpClause::If>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::NumThreads>(&clause.u) &&
@@ -2589,17 +2620,11 @@ genOMP(Fortran::lower::AbstractConverter &converter,
         !std::get_if<Fortran::parser::OmpClause::Priority>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::Depend>(&clause.u) &&
-        // Privatisation and copyin clauses are handled elsewhere.
         !std::get_if<Fortran::parser::OmpClause::Private>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::Firstprivate>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u) &&
-        // Shared is the default behavior in the IR, so no handling is required.
         !std::get_if<Fortran::parser::OmpClause::Shared>(&clause.u) &&
-        // Nothing needs to be done for threads clause.
         !std::get_if<Fortran::parser::OmpClause::Threads>(&clause.u) &&
-        // Map, UseDevicePtr, UseDeviceAddr and ThreadLimit clauses are
-        // exclusive to Target directives. They are handled as part of the
-        // TargetOp creation.
         !std::get_if<Fortran::parser::OmpClause::Map>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) &&
@@ -2608,67 +2633,78 @@ genOMP(Fortran::lower::AbstractConverter &converter,
     }
   }
 
-  ClauseProcessor(converter,
-                  std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t))
-      .processNowait(nowaitAttr);
-  for (const auto &clause :
-       std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t).v) {
+  for (const auto &clause : endClauseList.v) {
     mlir::Location clauseLocation = converter.genLocation(clause.source);
     if (!std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u))
       TODO(clauseLocation, "OpenMP Block construct clause");
   }
 
-  if (blockDirective.v == llvm::omp::OMPD_parallel) {
-    // Create and insert the operation.
-    auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
-        currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
-        allocateOperands, allocatorOperands, reductionVars,
-        reductionDeclSymbols.empty()
-            ? nullptr
-            : mlir::ArrayAttr::get(firOpBuilder.getContext(),
-                                   reductionDeclSymbols),
-        procBindKindAttr);
-    createBodyOfOp<mlir::omp::ParallelOp>(parallelOp, converter,
-                                          currentLocation, eval, &opClauseList);
-  } else if (blockDirective.v == llvm::omp::OMPD_master) {
-    auto masterOp =
-        firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy);
-    createBodyOfOp<mlir::omp::MasterOp>(masterOp, converter, currentLocation,
-                                        eval);
-  } else if (blockDirective.v == llvm::omp::OMPD_single) {
-    auto singleOp = firOpBuilder.create<mlir::omp::SingleOp>(
-        currentLocation, allocateOperands, allocatorOperands, nowaitAttr);
-    createBodyOfOp<mlir::omp::SingleOp>(singleOp, converter, currentLocation,
-                                        eval, &opClauseList);
-  } else if (blockDirective.v == llvm::omp::OMPD_ordered) {
-    auto orderedOp = firOpBuilder.create<mlir::omp::OrderedRegionOp>(
-        currentLocation, /*simd=*/false);
-    createBodyOfOp<mlir::omp::OrderedRegionOp>(orderedOp, converter,
-                                               currentLocation, eval);
-  } else if (blockDirective.v == llvm::omp::OMPD_task) {
-    auto taskOp = firOpBuilder.create<mlir::omp::TaskOp>(
-        currentLocation, ifClauseOperand, finalClauseOperand, untiedAttr,
-        mergeableAttr, /*in_reduction_vars=*/mlir::ValueRange(),
-        /*in_reductions=*/nullptr, priorityClauseOperand,
-        dependTypeOperands.empty()
-            ? nullptr
-            : mlir::ArrayAttr::get(firOpBuilder.getContext(),
-                                   dependTypeOperands),
-        dependOperands, allocateOperands, allocatorOperands);
-    createBodyOfOp(taskOp, converter, currentLocation, eval, &opClauseList);
-  } else if (blockDirective.v == llvm::omp::OMPD_taskgroup) {
-    // TODO: Add task_reduction support
-    auto taskGroupOp = firOpBuilder.create<mlir::omp::TaskGroupOp>(
-        currentLocation, /*task_reduction_vars=*/mlir::ValueRange(),
-        /*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
-    createBodyOfOp(taskGroupOp, converter, currentLocation, eval,
-                   &opClauseList);
-  } else if (blockDirective.v == llvm::omp::OMPD_target) {
-    createTargetOp(converter, opClauseList, blockDirective.v, currentLocation,
-                   &eval);
-  } else if (blockDirective.v == llvm::omp::OMPD_target_data) {
-    createTargetOp(converter, opClauseList, blockDirective.v, currentLocation,
-                   &eval);
+  mlir::Location currentLocation = converter.genLocation(directive.source);
+  switch (directive.v) {
+  case llvm::omp::Directive::OMPD_master:
+    genMasterOp(converter, eval, currentLocation);
+    break;
+  case llvm::omp::Directive::OMPD_ordered:
+    genOrderedRegionOp(converter, eval, currentLocation);
+    break;
+  case llvm::omp::Directive::OMPD_parallel:
+    genParallelOp(converter, eval, currentLocation, beginClauseList);
+    break;
+  case llvm::omp::Directive::OMPD_single:
+    genSingleOp(converter, eval, currentLocation, beginClauseList,
+                endClauseList);
+    break;
+  case llvm::omp::Directive::OMPD_target:
+    genTargetOp(converter, eval, currentLocation, beginClauseList);
+    break;
+  case llvm::omp::Directive::OMPD_target_data:
+    genDataOp(converter, currentLocation, beginClauseList);
+    break;
+  case llvm::omp::Directive::OMPD_task:
+    genTaskOp(converter, eval, currentLocation, beginClauseList);
+    break;
+  case llvm::omp::Directive::OMPD_taskgroup:
+    genTaskGroupOp(converter, eval, currentLocation, beginClauseList);
+    break;
+  case llvm::omp::Directive::OMPD_teams:
+    TODO(currentLocation, "Teams construct");
+    break;
+  case llvm::omp::Directive::OMPD_workshare:
+    TODO(currentLocation, "Workshare construct");
+    break;
+  default: {
+    // Codegen for combined directives
+    bool combinedDirective = false;
+    if ((llvm::omp::allTargetSet & llvm::omp::blockConstructSet)
+            .test(directive.v)) {
+      genTargetOp(converter, eval, currentLocation, beginClauseList,
+                  /*outerCombined=*/true);
+      combinedDirective = true;
+    }
+    if ((llvm::omp::allTeamsSet & llvm::omp::blockConstructSet)
+            .test(directive.v)) {
+      TODO(currentLocation, "Teams construct");
+      combinedDirective = true;
+    }
+    if ((llvm::omp::allParallelSet & llvm::omp::blockConstructSet)
+            .test(directive.v)) {
+      bool outerCombined =
+          directive.v != llvm::omp::Directive::OMPD_target_parallel;
+      genParallelOp(converter, eval, currentLocation, beginClauseList,
+                    outerCombined);
+      combinedDirective = true;
+    }
+    if ((llvm::omp::workShareSet & llvm::omp::blockConstructSet)
+            .test(directive.v)) {
+      TODO(currentLocation, "Workshare construct");
+      combinedDirective = true;
+    }
+    if (!combinedDirective)
+      TODO(currentLocation, "Unhandled block directive (" +
+                                llvm::omp::getOpenMPDirectiveName(directive.v) +
+                                ")");
+    break;
+  }
   }
 }
 
@@ -2742,51 +2778,44 @@ static void
 genOMP(Fortran::lower::AbstractConverter &converter,
        Fortran::lower::pft::Evaluation &eval,
        const Fortran::parser::OpenMPSectionsConstruct &sectionsConstruct) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   mlir::Location currentLocation = converter.getCurrentLocation();
-  llvm::SmallVector<mlir::Value> reductionVars, allocateOperands,
-      allocatorOperands;
+  llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
   mlir::UnitAttr nowaitClauseOperand;
   const auto &beginSectionsDirective =
       std::get<Fortran::parser::OmpBeginSectionsDirective>(sectionsConstruct.t);
   const auto &sectionsClauseList =
       std::get<Fortran::parser::OmpClauseList>(beginSectionsDirective.t);
 
+  // Process clauses before optional omp.parallel, so that new variables are
+  // allocated outside of the parallel region
   ClauseProcessor cp(converter, sectionsClauseList);
   cp.processSectionsReduction(currentLocation);
   cp.processAllocate(allocatorOperands, allocateOperands);
 
-  const auto &endSectionsDirective =
-      std::get<Fortran::parser::OmpEndSectionsDirective>(sectionsConstruct.t);
-  const auto &endSectionsClauseList =
-      std::get<Fortran::parser::OmpClauseList>(endSectionsDirective.t);
-  ClauseProcessor(converter, endSectionsClauseList)
-      .processNowait(nowaitClauseOperand);
-
   llvm::omp::Directive dir =
       std::get<Fortran::parser::OmpSectionsDirective>(beginSectionsDirective.t)
           .v;
 
-  // Parallel Sections Construct
+  // Parallel wrapper of PARALLEL SECTIONS construct
   if (dir == llvm::omp::Directive::OMPD_parallel_sections) {
-    createCombinedParallelOp<Fortran::parser::OmpBeginSectionsDirective>(
-        converter, eval,
-        std::get<Fortran::parser::OmpBeginSectionsDirective>(
-            sectionsConstruct.t));
-    auto sectionsOp = firOpBuilder.create<mlir::omp::SectionsOp>(
-        currentLocation, /*reduction_vars*/ mlir::ValueRange(),
-        /*reductions=*/nullptr, allocateOperands, allocatorOperands,
-        /*nowait=*/nullptr);
-    createBodyOfOp(sectionsOp, converter, currentLocation, eval);
-
-    // Sections Construct
-  } else if (dir == llvm::omp::Directive::OMPD_sections) {
-    auto sectionsOp = firOpBuilder.create<mlir::omp::SectionsOp>(
-        currentLocation, reductionVars, /*reductions=*/nullptr,
-        allocateOperands, allocatorOperands, nowaitClauseOperand);
-    createBodyOfOp<mlir::omp::SectionsOp>(sectionsOp, converter,
-                                          currentLocation, eval);
-  }
+    genParallelOp(converter, eval, currentLocation, sectionsClauseList,
+                  /*outerCombined=*/true);
+  } else {
+    const auto &endSectionsDirective =
+        std::get<Fortran::parser::OmpEndSectionsDirective>(sectionsConstruct.t);
+    const auto &endSectionsClauseList =
+        std::get<Fortran::parser::OmpClauseList>(endSectionsDirective.t);
+    ClauseProcessor(converter, endSectionsClauseList)
+        .processNowait(nowaitClauseOperand);
+  }
+
+  // SECTIONS construct
+  genOpWithBody<mlir::omp::SectionsOp>(converter, eval, currentLocation,
+                                       /*outerCombined=*/false,
+                                       /*clauseList=*/nullptr,
+                                       /*reduction_vars=*/mlir::ValueRange(),
+                                       /*reductions=*/nullptr, allocateOperands,
+                                       allocatorOperands, nowaitClauseOperand);
 }
 
 static bool checkForSingleVariableOnRHS(

diff  --git a/flang/test/Lower/OpenMP/parallel-sections.f90 b/flang/test/Lower/OpenMP/parallel-sections.f90
index 0b04bfadfb8490..a638fdf293392f 100644
--- a/flang/test/Lower/OpenMP/parallel-sections.f90
+++ b/flang/test/Lower/OpenMP/parallel-sections.f90
@@ -38,12 +38,16 @@ end subroutine omp_parallel_sections
 subroutine omp_parallel_sections_allocate(x, y)
   use omp_lib
   integer, intent(inout) :: x, y
-  !FIRDialect: %[[allocator:.*]] = arith.constant 1 : i32
-  !LLVMDialect: %[[allocator:.*]] = llvm.mlir.constant(1 : i32) : i32
-  !OMPDialect: omp.parallel {
+  !FIRDialect: %[[allocator_1:.*]] = arith.constant 1 : i32
+  !FIRDialect: %[[allocator_2:.*]] = arith.constant 1 : i32
+  !LLVMDialect: %[[allocator_1:.*]] = llvm.mlir.constant(1 : i32) : i32
+  !LLVMDialect: %[[allocator_2:.*]] = llvm.mlir.constant(1 : i32) : i32
+  !OMPDialect: omp.parallel allocate(
+  !FIRDialect: %[[allocator_2]] : i32 -> %{{.*}} : !fir.ref<i32>) {
+  !LLVMDialect: %[[allocator_2]] : i32 -> %{{.*}} : !llvm.ptr<i32>) {
   !OMPDialect: omp.sections allocate(
-  !FIRDialect: %[[allocator]] : i32 -> %{{.*}} : !fir.ref<i32>) {
-  !LLVMDialect: %[[allocator]] : i32 -> %{{.*}} : !llvm.ptr<i32>) {
+  !FIRDialect: %[[allocator_1]] : i32 -> %{{.*}} : !fir.ref<i32>) {
+  !LLVMDialect: %[[allocator_1]] : i32 -> %{{.*}} : !llvm.ptr<i32>) {
   !$omp parallel sections allocate(omp_high_bw_mem_alloc: x)
     !OMPDialect: omp.section {
     !$omp section


        


More information about the flang-commits mailing list