[flang-commits] [flang] e5a34f9 - [Flang][OpenMP] Push genEval closer to leaf lowering functions (#77760)

via flang-commits flang-commits at lists.llvm.org
Thu Jan 18 05:47:39 PST 2024


Author: Krzysztof Parzyszek
Date: 2024-01-18T07:47:35-06:00
New Revision: e5a34f9226ef56669f670dc32661934ee3e56f37

URL: https://github.com/llvm/llvm-project/commit/e5a34f9226ef56669f670dc32661934ee3e56f37
DIFF: https://github.com/llvm/llvm-project/commit/e5a34f9226ef56669f670dc32661934ee3e56f37.diff

LOG: [Flang][OpenMP] Push genEval closer to leaf lowering functions (#77760)

This moves the lowering of the nested evaluations all the way to the
bottom of the call stack. This PR does not attempt to change the leaf
lowering functions beyond placing the call to `genEval` in there.
Whether the nested evaluations should be lowered for any given op
depends on the context in which that op is created, hence a `genNested`
parameter was added.

The contexts in which nested evaluations should not be lowered arise during
the lowering of composite constructs, such as PARALLEL SECTIONS. This
particular construct is treated as a block construct tied to the SECTIONS
directive: the lowering code first creates an empty parallel op, and then
recursively lowers the SECTIONS code. Similar situations occur when
lowering most (if not all) compound/composite constructs.

Recursive lowering [4/5]

Added: 
    

Modified: 
    flang/lib/Lower/OpenMP.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index c770d1c60718c35..f269155e4519494 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -112,7 +112,8 @@ static void gatherFuncAndVarSyms(
 
 static Fortran::lower::pft::Evaluation *
 getCollapsedLoopEval(Fortran::lower::pft::Evaluation &eval, int collapseValue) {
-  // Return the Evaluation of the innermost collapsed loop.
+  // Return the Evaluation of the innermost collapsed loop, or the current one
+  // if there was no COLLAPSE.
   if (collapseValue == 0)
     return &eval;
 
@@ -131,7 +132,7 @@ static void genNestedEvaluations(Fortran::lower::AbstractConverter &converter,
                                  Fortran::lower::pft::Evaluation &eval,
                                  int collapseValue = 0) {
   Fortran::lower::pft::Evaluation *curEval =
-      collapseValue == 0 ? &eval : getCollapsedLoopEval(eval, collapseValue);
+      getCollapsedLoopEval(eval, collapseValue);
 
   for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
     converter.genEval(e);
@@ -2171,6 +2172,7 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
 /// \param [inout] converter - converter to use for the clauses.
 /// \param [in]    loc - location in source code.
 /// \param [in]    eval - current PFT node/evaluation.
+/// \param [in]    genNested - whether to generate FIR for nested evaluations
 /// \oaran [in]    clauses - list of clauses to process.
 /// \param [in]    args - block arguments (induction variable[s]) for the
 ////                      region.
@@ -2179,7 +2181,7 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
 template <typename Op>
 static void createBodyOfOp(
     Op &op, Fortran::lower::AbstractConverter &converter, mlir::Location &loc,
-    Fortran::lower::pft::Evaluation &eval,
+    Fortran::lower::pft::Evaluation &eval, bool genNested,
     const Fortran::parser::OmpClauseList *clauses = nullptr,
     const llvm::SmallVector<const Fortran::semantics::Symbol *> &args = {},
     bool outerCombined = false, DataSharingProcessor *dsp = nullptr) {
@@ -2249,11 +2251,15 @@ static void createBodyOfOp(
     if (clauses)
       ClauseProcessor(converter, *clauses).processCopyin();
   }
+
+  if (genNested)
+    genNestedEvaluations(converter, eval);
 }
 
 static void genBodyOfTargetDataOp(
     Fortran::lower::AbstractConverter &converter,
-    Fortran::lower::pft::Evaluation &eval, mlir::omp::DataOp &dataOp,
+    Fortran::lower::pft::Evaluation &eval, bool genNested,
+    mlir::omp::DataOp &dataOp,
     const llvm::SmallVector<mlir::Type> &useDeviceTypes,
     const llvm::SmallVector<mlir::Location> &useDeviceLocs,
     const llvm::SmallVector<const Fortran::semantics::Symbol *>
@@ -2311,26 +2317,30 @@ static void genBodyOfTargetDataOp(
 
   // Set the insertion point after the marker.
   firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp());
+  if (genNested)
+    genNestedEvaluations(converter, eval);
 }
 
 template <typename OpTy, typename... Args>
 static OpTy genOpWithBody(Fortran::lower::AbstractConverter &converter,
-                          Fortran::lower::pft::Evaluation &eval,
+                          Fortran::lower::pft::Evaluation &eval, bool genNested,
                           mlir::Location currentLocation, bool outerCombined,
                           const Fortran::parser::OmpClauseList *clauseList,
                           Args &&...args) {
   auto op = converter.getFirOpBuilder().create<OpTy>(
       currentLocation, std::forward<Args>(args)...);
-  createBodyOfOp<OpTy>(op, converter, currentLocation, eval, clauseList,
+  createBodyOfOp<OpTy>(op, converter, currentLocation, eval, genNested,
+                       clauseList,
                        /*args=*/{}, outerCombined);
   return op;
 }
 
 static mlir::omp::MasterOp
 genMasterOp(Fortran::lower::AbstractConverter &converter,
-            Fortran::lower::pft::Evaluation &eval,
+            Fortran::lower::pft::Evaluation &eval, bool genNested,
             mlir::Location currentLocation) {
-  return genOpWithBody<mlir::omp::MasterOp>(converter, eval, currentLocation,
+  return genOpWithBody<mlir::omp::MasterOp>(converter, eval, genNested,
+                                            currentLocation,
                                             /*outerCombined=*/false,
                                             /*clauseList=*/nullptr,
                                             /*resultTypes=*/mlir::TypeRange());
@@ -2338,16 +2348,17 @@ genMasterOp(Fortran::lower::AbstractConverter &converter,
 
 static mlir::omp::OrderedRegionOp
 genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
-                   Fortran::lower::pft::Evaluation &eval,
+                   Fortran::lower::pft::Evaluation &eval, bool genNested,
                    mlir::Location currentLocation) {
   return genOpWithBody<mlir::omp::OrderedRegionOp>(
-      converter, eval, currentLocation, /*outerCombined=*/false,
+      converter, eval, genNested, currentLocation,
+      /*outerCombined=*/false,
       /*clauseList=*/nullptr, /*simd=*/false);
 }
 
 static mlir::omp::ParallelOp
 genParallelOp(Fortran::lower::AbstractConverter &converter,
-              Fortran::lower::pft::Evaluation &eval,
+              Fortran::lower::pft::Evaluation &eval, bool genNested,
               mlir::Location currentLocation,
               const Fortran::parser::OmpClauseList &clauseList,
               bool outerCombined = false) {
@@ -2369,7 +2380,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
     cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
 
   return genOpWithBody<mlir::omp::ParallelOp>(
-      converter, eval, currentLocation, outerCombined, &clauseList,
+      converter, eval, genNested, currentLocation, outerCombined, &clauseList,
       /*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
       numThreadsClauseOperand, allocateOperands, allocatorOperands,
       reductionVars,
@@ -2382,19 +2393,19 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
 
 static mlir::omp::SectionOp
 genSectionOp(Fortran::lower::AbstractConverter &converter,
-             Fortran::lower::pft::Evaluation &eval,
+             Fortran::lower::pft::Evaluation &eval, bool genNested,
              mlir::Location currentLocation,
              const Fortran::parser::OmpClauseList &sectionsClauseList) {
   // Currently only private/firstprivate clause is handled, and
   // all privatization is done within `omp.section` operations.
-  return genOpWithBody<mlir::omp::SectionOp>(converter, eval, currentLocation,
-                                             /*outerCombined=*/false,
-                                             &sectionsClauseList);
+  return genOpWithBody<mlir::omp::SectionOp>(
+      converter, eval, genNested, currentLocation,
+      /*outerCombined=*/false, &sectionsClauseList);
 }
 
 static mlir::omp::SingleOp
 genSingleOp(Fortran::lower::AbstractConverter &converter,
-            Fortran::lower::pft::Evaluation &eval,
+            Fortran::lower::pft::Evaluation &eval, bool genNested,
             mlir::Location currentLocation,
             const Fortran::parser::OmpClauseList &beginClauseList,
             const Fortran::parser::OmpClauseList &endClauseList) {
@@ -2409,13 +2420,15 @@ genSingleOp(Fortran::lower::AbstractConverter &converter,
   ClauseProcessor(converter, endClauseList).processNowait(nowaitAttr);
 
   return genOpWithBody<mlir::omp::SingleOp>(
-      converter, eval, currentLocation, /*outerCombined=*/false,
-      &beginClauseList, allocateOperands, allocatorOperands, nowaitAttr);
+      converter, eval, genNested, currentLocation,
+      /*outerCombined=*/false, &beginClauseList, allocateOperands,
+      allocatorOperands, nowaitAttr);
 }
 
 static mlir::omp::TaskOp
 genTaskOp(Fortran::lower::AbstractConverter &converter,
-          Fortran::lower::pft::Evaluation &eval, mlir::Location currentLocation,
+          Fortran::lower::pft::Evaluation &eval, bool genNested,
+          mlir::Location currentLocation,
           const Fortran::parser::OmpClauseList &clauseList) {
   Fortran::lower::StatementContext stmtCtx;
   mlir::Value ifClauseOperand, finalClauseOperand, priorityClauseOperand;
@@ -2440,8 +2453,9 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
       currentLocation, llvm::omp::Directive::OMPD_task);
 
   return genOpWithBody<mlir::omp::TaskOp>(
-      converter, eval, currentLocation, /*outerCombined=*/false, &clauseList,
-      ifClauseOperand, finalClauseOperand, untiedAttr, mergeableAttr,
+      converter, eval, genNested, currentLocation,
+      /*outerCombined=*/false, &clauseList, ifClauseOperand, finalClauseOperand,
+      untiedAttr, mergeableAttr,
       /*in_reduction_vars=*/mlir::ValueRange(),
       /*in_reductions=*/nullptr, priorityClauseOperand,
       dependTypeOperands.empty()
@@ -2453,7 +2467,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
 
 static mlir::omp::TaskGroupOp
 genTaskGroupOp(Fortran::lower::AbstractConverter &converter,
-               Fortran::lower::pft::Evaluation &eval,
+               Fortran::lower::pft::Evaluation &eval, bool genNested,
                mlir::Location currentLocation,
                const Fortran::parser::OmpClauseList &clauseList) {
   llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
@@ -2462,7 +2476,8 @@ genTaskGroupOp(Fortran::lower::AbstractConverter &converter,
   cp.processTODO<Fortran::parser::OmpClause::TaskReduction>(
       currentLocation, llvm::omp::Directive::OMPD_taskgroup);
   return genOpWithBody<mlir::omp::TaskGroupOp>(
-      converter, eval, currentLocation, /*outerCombined=*/false, &clauseList,
+      converter, eval, genNested, currentLocation,
+      /*outerCombined=*/false, &clauseList,
       /*task_reduction_vars=*/mlir::ValueRange(),
       /*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
 }
@@ -2471,7 +2486,7 @@ static mlir::omp::DataOp
 genDataOp(Fortran::lower::AbstractConverter &converter,
           Fortran::lower::pft::Evaluation &eval,
           Fortran::semantics::SemanticsContext &semanticsContext,
-          mlir::Location currentLocation,
+          bool genNested, mlir::Location currentLocation,
           const Fortran::parser::OmpClauseList &clauseList) {
   Fortran::lower::StatementContext stmtCtx;
   mlir::Value ifClauseOperand, deviceOperand;
@@ -2495,8 +2510,8 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
   auto dataOp = converter.getFirOpBuilder().create<mlir::omp::DataOp>(
       currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
       deviceAddrOperands, mapOperands);
-  genBodyOfTargetDataOp(converter, eval, dataOp, useDeviceTypes, useDeviceLocs,
-                        useDeviceSymbols, currentLocation);
+  genBodyOfTargetDataOp(converter, eval, genNested, dataOp, useDeviceTypes,
+                        useDeviceLocs, useDeviceSymbols, currentLocation);
   return dataOp;
 }
 
@@ -2557,7 +2572,8 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
 // all the symbols present in mapSymbols as block arguments to this block.
 static void genBodyOfTargetOp(
     Fortran::lower::AbstractConverter &converter,
-    Fortran::lower::pft::Evaluation &eval, mlir::omp::TargetOp &targetOp,
+    Fortran::lower::pft::Evaluation &eval, bool genNested,
+    mlir::omp::TargetOp &targetOp,
     const llvm::SmallVector<mlir::Type> &mapSymTypes,
     const llvm::SmallVector<mlir::Location> &mapSymLocs,
     const llvm::SmallVector<const Fortran::semantics::Symbol *> &mapSymbols,
@@ -2687,7 +2703,6 @@ static void genBodyOfTargetOp(
 
   // Create blocks for unstructured regions. This has to be done since
   // blocks are initially allocated with the function as the parent region.
-  // the parent region of blocks.
   if (eval.lowerAsUnstructured()) {
     Fortran::lower::createEmptyRegionBlocks<mlir::omp::TerminatorOp,
                                             mlir::omp::YieldOp>(
@@ -2698,13 +2713,15 @@ static void genBodyOfTargetOp(
 
   // Create the insertion point after the marker.
   firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp());
+  if (genNested)
+    genNestedEvaluations(converter, eval);
 }
 
 static mlir::omp::TargetOp
 genTargetOp(Fortran::lower::AbstractConverter &converter,
             Fortran::lower::pft::Evaluation &eval,
             Fortran::semantics::SemanticsContext &semanticsContext,
-            mlir::Location currentLocation,
+            bool genNested, mlir::Location currentLocation,
             const Fortran::parser::OmpClauseList &clauseList,
             llvm::omp::Directive directive, bool outerCombined = false) {
   Fortran::lower::StatementContext stmtCtx;
@@ -2804,15 +2821,15 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
       currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
       nowaitAttr, mapOperands);
 
-  genBodyOfTargetOp(converter, eval, targetOp, mapSymTypes, mapSymLocs,
-                    mapSymbols, currentLocation);
+  genBodyOfTargetOp(converter, eval, genNested, targetOp, mapSymTypes,
+                    mapSymLocs, mapSymbols, currentLocation);
 
   return targetOp;
 }
 
 static mlir::omp::TeamsOp
 genTeamsOp(Fortran::lower::AbstractConverter &converter,
-           Fortran::lower::pft::Evaluation &eval,
+           Fortran::lower::pft::Evaluation &eval, bool genNested,
            mlir::Location currentLocation,
            const Fortran::parser::OmpClauseList &clauseList,
            bool outerCombined = false) {
@@ -2833,7 +2850,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
       currentLocation, llvm::omp::Directive::OMPD_teams);
 
   return genOpWithBody<mlir::omp::TeamsOp>(
-      converter, eval, currentLocation, outerCombined, &clauseList,
+      converter, eval, genNested, currentLocation, outerCombined, &clauseList,
       /*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand,
       threadLimitClauseOperand, allocateOperands, allocatorOperands,
       reductionVars,
@@ -2917,6 +2934,7 @@ static void
 genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter,
                        Fortran::lower::pft::Evaluation &eval,
                        Fortran::semantics::SemanticsContext &semanticsContext,
+                       bool genNested,
                        const Fortran::parser::OpenMPSimpleStandaloneConstruct
                            &simpleStandaloneConstruct) {
   const auto &directive =
@@ -2944,7 +2962,8 @@ genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter,
     firOpBuilder.create<mlir::omp::TaskyieldOp>(currentLocation);
     break;
   case llvm::omp::Directive::OMPD_target_data:
-    genDataOp(converter, eval, semanticsContext, currentLocation, opClauseList);
+    genDataOp(converter, eval, semanticsContext, genNested, currentLocation,
+              opClauseList);
     break;
   case llvm::omp::Directive::OMPD_target_enter_data:
     genEnterExitUpdateDataOp<mlir::omp::EnterDataOp>(
@@ -2992,6 +3011,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
           [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct
                   &simpleStandaloneConstruct) {
             genOmpSimpleStandalone(converter, eval, semanticsContext,
+                                   /*genNested=*/true,
                                    simpleStandaloneConstruct);
           },
           [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) {
@@ -3071,12 +3091,12 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
       /*alignment_values=*/nullptr, ifClauseOperand, nontemporalVars,
       orderClauseOperand, simdlenClauseOperand, safelenClauseOperand,
       /*inclusive=*/firOpBuilder.getUnitAttr());
-  createBodyOfOp<mlir::omp::SimdLoopOp>(simdLoopOp, converter, loc, eval,
-                                        &loopOpClauseList, iv,
-                                        /*outer=*/false, &dsp);
 
-  genNestedEvaluations(converter, eval,
-                       Fortran::lower::getCollapseValue(loopOpClauseList));
+  auto *nestedEval = getCollapsedLoopEval(
+      eval, Fortran::lower::getCollapseValue(loopOpClauseList));
+  createBodyOfOp<mlir::omp::SimdLoopOp>(simdLoopOp, converter, loc, *nestedEval,
+                                        /*genNested=*/true, &loopOpClauseList,
+                                        iv, /*outer=*/false, &dsp);
 }
 
 static void createWsLoop(Fortran::lower::AbstractConverter &converter,
@@ -3147,12 +3167,11 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
       wsLoopOp.setNowaitAttr(nowaitClauseOperand);
   }
 
-  createBodyOfOp<mlir::omp::WsLoopOp>(wsLoopOp, converter, loc, eval,
-                                      &beginClauseList, iv,
+  auto *nestedEval = getCollapsedLoopEval(
+      eval, Fortran::lower::getCollapseValue(beginClauseList));
+  createBodyOfOp<mlir::omp::WsLoopOp>(wsLoopOp, converter, loc, *nestedEval,
+                                      /*genNested=*/true, &beginClauseList, iv,
                                       /*outer=*/false, &dsp);
-
-  genNestedEvaluations(converter, eval,
-                       Fortran::lower::getCollapseValue(beginClauseList));
 }
 
 static void genOMP(Fortran::lower::AbstractConverter &converter,
@@ -3189,13 +3208,15 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
     if ((llvm::omp::allTargetSet & llvm::omp::loopConstructSet)
             .test(ompDirective)) {
       validDirective = true;
-      genTargetOp(converter, eval, semanticsContext, currentLocation,
-                  loopOpClauseList, ompDirective, /*outerCombined=*/true);
+      genTargetOp(converter, eval, semanticsContext, /*genNested=*/false,
+                  currentLocation, loopOpClauseList, ompDirective,
+                  /*outerCombined=*/true);
     }
     if ((llvm::omp::allTeamsSet & llvm::omp::loopConstructSet)
             .test(ompDirective)) {
       validDirective = true;
-      genTeamsOp(converter, eval, currentLocation, loopOpClauseList,
+      genTeamsOp(converter, eval, /*genNested=*/false, currentLocation,
+                 loopOpClauseList,
                  /*outerCombined=*/true);
     }
     if (llvm::omp::allDistributeSet.test(ompDirective)) {
@@ -3205,7 +3226,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
     if ((llvm::omp::allParallelSet & llvm::omp::loopConstructSet)
             .test(ompDirective)) {
       validDirective = true;
-      genParallelOp(converter, eval, currentLocation, loopOpClauseList,
+      genParallelOp(converter, eval, /*genNested=*/false, currentLocation,
+                    loopOpClauseList,
                     /*outerCombined=*/true);
     }
   }
@@ -3278,76 +3300,89 @@ genOMP(Fortran::lower::AbstractConverter &converter,
       TODO(clauseLocation, "OpenMP Block construct clause");
   }
 
+  bool singleDirective = true;
   mlir::Location currentLocation = converter.genLocation(directive.source);
   switch (directive.v) {
   case llvm::omp::Directive::OMPD_master:
-    genMasterOp(converter, eval, currentLocation);
+    genMasterOp(converter, eval, /*genNested=*/true, currentLocation);
     break;
   case llvm::omp::Directive::OMPD_ordered:
-    genOrderedRegionOp(converter, eval, currentLocation);
+    genOrderedRegionOp(converter, eval, /*genNested=*/true, currentLocation);
     break;
   case llvm::omp::Directive::OMPD_parallel:
-    genParallelOp(converter, eval, currentLocation, beginClauseList);
+    genParallelOp(converter, eval, /*genNested=*/true, currentLocation,
+                  beginClauseList);
     break;
   case llvm::omp::Directive::OMPD_single:
-    genSingleOp(converter, eval, currentLocation, beginClauseList,
-                endClauseList);
+    genSingleOp(converter, eval, /*genNested=*/true, currentLocation,
+                beginClauseList, endClauseList);
     break;
   case llvm::omp::Directive::OMPD_target:
-    genTargetOp(converter, eval, semanticsContext, currentLocation,
-                beginClauseList, directive.v);
+    genTargetOp(converter, eval, semanticsContext, /*genNested=*/true,
+                currentLocation, beginClauseList, directive.v);
     break;
   case llvm::omp::Directive::OMPD_target_data:
-    genDataOp(converter, eval, semanticsContext, currentLocation,
-              beginClauseList);
+    genDataOp(converter, eval, semanticsContext, /*genNested=*/true,
+              currentLocation, beginClauseList);
     break;
   case llvm::omp::Directive::OMPD_task:
-    genTaskOp(converter, eval, currentLocation, beginClauseList);
+    genTaskOp(converter, eval, /*genNested=*/true, currentLocation,
+              beginClauseList);
     break;
   case llvm::omp::Directive::OMPD_taskgroup:
-    genTaskGroupOp(converter, eval, currentLocation, beginClauseList);
+    genTaskGroupOp(converter, eval, /*genNested=*/true, currentLocation,
+                   beginClauseList);
     break;
   case llvm::omp::Directive::OMPD_teams:
-    genTeamsOp(converter, eval, currentLocation, beginClauseList,
+    genTeamsOp(converter, eval, /*genNested=*/true, currentLocation,
+               beginClauseList,
                /*outerCombined=*/false);
     break;
   case llvm::omp::Directive::OMPD_workshare:
     TODO(currentLocation, "Workshare construct");
     break;
-  default: {
-    // Codegen for combined directives
-    bool combinedDirective = false;
-    if ((llvm::omp::allTargetSet & llvm::omp::blockConstructSet)
-            .test(directive.v)) {
-      genTargetOp(converter, eval, semanticsContext, currentLocation,
-                  beginClauseList, directive.v, /*outerCombined=*/true);
-      combinedDirective = true;
-    }
-    if ((llvm::omp::allTeamsSet & llvm::omp::blockConstructSet)
-            .test(directive.v)) {
-      genTeamsOp(converter, eval, currentLocation, beginClauseList);
-      combinedDirective = true;
-    }
-    if ((llvm::omp::allParallelSet & llvm::omp::blockConstructSet)
-            .test(directive.v)) {
-      bool outerCombined =
-          directive.v != llvm::omp::Directive::OMPD_target_parallel;
-      genParallelOp(converter, eval, currentLocation, beginClauseList,
-                    outerCombined);
-      combinedDirective = true;
-    }
-    if ((llvm::omp::workShareSet & llvm::omp::blockConstructSet)
-            .test(directive.v)) {
-      TODO(currentLocation, "Workshare construct");
-      combinedDirective = true;
-    }
-    if (!combinedDirective)
-      TODO(currentLocation, "Unhandled block directive (" +
-                                llvm::omp::getOpenMPDirectiveName(directive.v) +
-                                ")");
+  default:
+    singleDirective = false;
     break;
   }
+
+  if (singleDirective) {
+    genOpenMPReduction(converter, beginClauseList);
+    return;
+  }
+
+  // Codegen for combined directives
+  bool combinedDirective = false;
+  if ((llvm::omp::allTargetSet & llvm::omp::blockConstructSet)
+          .test(directive.v)) {
+    genTargetOp(converter, eval, semanticsContext, /*genNested=*/false,
+                currentLocation, beginClauseList, directive.v,
+                /*outerCombined=*/true);
+    combinedDirective = true;
+  }
+  if ((llvm::omp::allTeamsSet & llvm::omp::blockConstructSet)
+          .test(directive.v)) {
+    genTeamsOp(converter, eval, /*genNested=*/false, currentLocation,
+               beginClauseList);
+    combinedDirective = true;
   }
+  if ((llvm::omp::allParallelSet & llvm::omp::blockConstructSet)
+          .test(directive.v)) {
+    bool outerCombined =
+        directive.v != llvm::omp::Directive::OMPD_target_parallel;
+    genParallelOp(converter, eval, /*genNested=*/false, currentLocation,
+                  beginClauseList, outerCombined);
+    combinedDirective = true;
+  }
+  if ((llvm::omp::workShareSet & llvm::omp::blockConstructSet)
+          .test(directive.v)) {
+    TODO(currentLocation, "Workshare construct");
+    combinedDirective = true;
+  }
+  if (!combinedDirective)
+    TODO(currentLocation, "Unhandled block directive (" +
+                              llvm::omp::getOpenMPDirectiveName(directive.v) +
+                              ")");
 
   genNestedEvaluations(converter, eval);
   genOpenMPReduction(converter, beginClauseList);
@@ -3390,8 +3425,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
                                                       global.getSymName()));
   }();
   createBodyOfOp<mlir::omp::CriticalOp>(criticalOp, converter, currentLocation,
-                                        eval);
-  genNestedEvaluations(converter, eval);
+                                        eval, /*genNested=*/true);
 }
 
 static void
@@ -3420,7 +3454,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
 
   // Parallel wrapper of PARALLEL SECTIONS construct
   if (dir == llvm::omp::Directive::OMPD_parallel_sections) {
-    genParallelOp(converter, eval, currentLocation, sectionsClauseList,
+    genParallelOp(converter, eval,
+                  /*genNested=*/false, currentLocation, sectionsClauseList,
                   /*outerCombined=*/true);
   } else {
     const auto &endSectionsDirective =
@@ -3432,7 +3467,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
   }
 
   // SECTIONS construct
-  genOpWithBody<mlir::omp::SectionsOp>(converter, eval, currentLocation,
+  genOpWithBody<mlir::omp::SectionsOp>(converter, eval,
+                                       /*genNested=*/false, currentLocation,
                                        /*outerCombined=*/false,
                                        /*clauseList=*/nullptr,
                                        /*reduction_vars=*/mlir::ValueRange(),
@@ -3446,8 +3482,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
   for (const auto &[nblock, neval] :
        llvm::zip(sectionBlocks.v, eval.getNestedEvaluations())) {
     symTable.pushScope();
-    genSectionOp(converter, neval, currentLocation, sectionsClauseList);
-    genNestedEvaluations(converter, neval);
+    genSectionOp(converter, neval, /*genNested=*/true, currentLocation,
+                 sectionsClauseList);
     symTable.popScope();
     firOpBuilder.restoreInsertionPoint(ip);
   }


        


More information about the flang-commits mailing list