[flang-commits] [flang] c8dca5b - [Flang][OpenMP][Lower] Refactor lowering of compound constructs (#87070)

via flang-commits flang-commits at lists.llvm.org
Wed Apr 17 04:17:54 PDT 2024


Author: Sergio Afonso
Date: 2024-04-17T12:17:50+01:00
New Revision: c8dca5bc0733e2fba81008fc33fcad1f45ba666a

URL: https://github.com/llvm/llvm-project/commit/c8dca5bc0733e2fba81008fc33fcad1f45ba666a
DIFF: https://github.com/llvm/llvm-project/commit/c8dca5bc0733e2fba81008fc33fcad1f45ba666a.diff

LOG: [Flang][OpenMP][Lower] Refactor lowering of compound constructs (#87070)

This patch simplifies the lowering from PFT to MLIR of OpenMP compound
constructs (i.e. combined and composite).

The new approach consists of iteratively processing the outermost leaf
construct of the given combined construct until it cannot be split
further. Both leaf constructs and composite ones have `gen...()`
functions that are called when appropriate.

This approach enables treating a leaf construct the same way regardless
of whether it appears as part of a combined construct, and it also enables
the lowering of composite constructs as a single unit.

Previous corner cases are now handled in a more straightforward way, and
comments pointing to the relevant spec sections are added. Directive sets
are also completed with the missing LOOP-related constructs.

Added: 
    

Modified: 
    flang/include/flang/Semantics/openmp-directive-sets.h
    flang/lib/Lower/OpenMP/OpenMP.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h
index 91773ae3ea9a3e..842d251b682aa9 100644
--- a/flang/include/flang/Semantics/openmp-directive-sets.h
+++ b/flang/include/flang/Semantics/openmp-directive-sets.h
@@ -32,14 +32,14 @@ static const OmpDirectiveSet topDistributeSet{
 
 static const OmpDirectiveSet allDistributeSet{
     OmpDirectiveSet{
-        llvm::omp::OMPD_target_teams_distribute,
-        llvm::omp::OMPD_target_teams_distribute_parallel_do,
-        llvm::omp::OMPD_target_teams_distribute_parallel_do_simd,
-        llvm::omp::OMPD_target_teams_distribute_simd,
-        llvm::omp::OMPD_teams_distribute,
-        llvm::omp::OMPD_teams_distribute_parallel_do,
-        llvm::omp::OMPD_teams_distribute_parallel_do_simd,
-        llvm::omp::OMPD_teams_distribute_simd,
+        Directive::OMPD_target_teams_distribute,
+        Directive::OMPD_target_teams_distribute_parallel_do,
+        Directive::OMPD_target_teams_distribute_parallel_do_simd,
+        Directive::OMPD_target_teams_distribute_simd,
+        Directive::OMPD_teams_distribute,
+        Directive::OMPD_teams_distribute_parallel_do,
+        Directive::OMPD_teams_distribute_parallel_do_simd,
+        Directive::OMPD_teams_distribute_simd,
     } | topDistributeSet,
 };
 
@@ -63,10 +63,24 @@ static const OmpDirectiveSet allDoSet{
     } | topDoSet,
 };
 
+static const OmpDirectiveSet topLoopSet{
+    Directive::OMPD_loop,
+};
+
+static const OmpDirectiveSet allLoopSet{
+    OmpDirectiveSet{
+        Directive::OMPD_parallel_loop,
+        Directive::OMPD_target_parallel_loop,
+        Directive::OMPD_target_teams_loop,
+        Directive::OMPD_teams_loop,
+    } | topLoopSet,
+};
+
 static const OmpDirectiveSet topParallelSet{
     Directive::OMPD_parallel,
     Directive::OMPD_parallel_do,
     Directive::OMPD_parallel_do_simd,
+    Directive::OMPD_parallel_loop,
     Directive::OMPD_parallel_masked_taskloop,
     Directive::OMPD_parallel_masked_taskloop_simd,
     Directive::OMPD_parallel_master_taskloop,
@@ -82,6 +96,7 @@ static const OmpDirectiveSet allParallelSet{
         Directive::OMPD_target_parallel,
         Directive::OMPD_target_parallel_do,
         Directive::OMPD_target_parallel_do_simd,
+        Directive::OMPD_target_parallel_loop,
         Directive::OMPD_target_teams_distribute_parallel_do,
         Directive::OMPD_target_teams_distribute_parallel_do_simd,
         Directive::OMPD_teams_distribute_parallel_do,
@@ -118,12 +133,14 @@ static const OmpDirectiveSet topTargetSet{
     Directive::OMPD_target_parallel,
     Directive::OMPD_target_parallel_do,
     Directive::OMPD_target_parallel_do_simd,
+    Directive::OMPD_target_parallel_loop,
     Directive::OMPD_target_simd,
     Directive::OMPD_target_teams,
     Directive::OMPD_target_teams_distribute,
     Directive::OMPD_target_teams_distribute_parallel_do,
     Directive::OMPD_target_teams_distribute_parallel_do_simd,
     Directive::OMPD_target_teams_distribute_simd,
+    Directive::OMPD_target_teams_loop,
 };
 
 static const OmpDirectiveSet allTargetSet{topTargetSet};
@@ -156,11 +173,12 @@ static const OmpDirectiveSet topTeamsSet{
 
 static const OmpDirectiveSet allTeamsSet{
     OmpDirectiveSet{
-        llvm::omp::OMPD_target_teams,
-        llvm::omp::OMPD_target_teams_distribute,
-        llvm::omp::OMPD_target_teams_distribute_parallel_do,
-        llvm::omp::OMPD_target_teams_distribute_parallel_do_simd,
-        llvm::omp::OMPD_target_teams_distribute_simd,
+        Directive::OMPD_target_teams,
+        Directive::OMPD_target_teams_distribute,
+        Directive::OMPD_target_teams_distribute_parallel_do,
+        Directive::OMPD_target_teams_distribute_parallel_do_simd,
+        Directive::OMPD_target_teams_distribute_simd,
+        Directive::OMPD_target_teams_loop,
     } | topTeamsSet,
 };
 
@@ -178,6 +196,14 @@ static const OmpDirectiveSet allDistributeSimdSet{
 static const OmpDirectiveSet allDoSimdSet{allDoSet & allSimdSet};
 static const OmpDirectiveSet allTaskloopSimdSet{allTaskloopSet & allSimdSet};
 
+static const OmpDirectiveSet compositeConstructSet{
+    Directive::OMPD_distribute_parallel_do,
+    Directive::OMPD_distribute_parallel_do_simd,
+    Directive::OMPD_distribute_simd,
+    Directive::OMPD_do_simd,
+    Directive::OMPD_taskloop_simd,
+};
+
 static const OmpDirectiveSet blockConstructSet{
     Directive::OMPD_master,
     Directive::OMPD_ordered,
@@ -201,12 +227,14 @@ static const OmpDirectiveSet loopConstructSet{
     Directive::OMPD_distribute_simd,
     Directive::OMPD_do,
     Directive::OMPD_do_simd,
+    Directive::OMPD_loop,
     Directive::OMPD_masked_taskloop,
     Directive::OMPD_masked_taskloop_simd,
     Directive::OMPD_master_taskloop,
     Directive::OMPD_master_taskloop_simd,
     Directive::OMPD_parallel_do,
     Directive::OMPD_parallel_do_simd,
+    Directive::OMPD_parallel_loop,
     Directive::OMPD_parallel_masked_taskloop,
     Directive::OMPD_parallel_masked_taskloop_simd,
     Directive::OMPD_parallel_master_taskloop,
@@ -214,17 +242,20 @@ static const OmpDirectiveSet loopConstructSet{
     Directive::OMPD_simd,
     Directive::OMPD_target_parallel_do,
     Directive::OMPD_target_parallel_do_simd,
+    Directive::OMPD_target_parallel_loop,
     Directive::OMPD_target_simd,
     Directive::OMPD_target_teams_distribute,
     Directive::OMPD_target_teams_distribute_parallel_do,
     Directive::OMPD_target_teams_distribute_parallel_do_simd,
     Directive::OMPD_target_teams_distribute_simd,
+    Directive::OMPD_target_teams_loop,
     Directive::OMPD_taskloop,
     Directive::OMPD_taskloop_simd,
     Directive::OMPD_teams_distribute,
     Directive::OMPD_teams_distribute_parallel_do,
     Directive::OMPD_teams_distribute_parallel_do_simd,
     Directive::OMPD_teams_distribute_simd,
+    Directive::OMPD_teams_loop,
     Directive::OMPD_tile,
     Directive::OMPD_unroll,
 };

diff  --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c31d63625dbb17..bb38082b245ef5 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -488,6 +488,81 @@ markDeclareTarget(mlir::Operation *op,
   declareTargetOp.setDeclareTarget(deviceType, captureClause);
 }
 
+/// Split a combined directive into an outer leaf directive and the (possibly
+/// combined) rest of the combined directive. Composite directives and
+/// non-compound directives are not split, in which case it will return the
+/// input directive as its first output and an empty value as its second output.
+static std::pair<llvm::omp::Directive, std::optional<llvm::omp::Directive>>
+splitCombinedDirective(llvm::omp::Directive dir) {
+  using D = llvm::omp::Directive;
+  switch (dir) {
+  case D::OMPD_masked_taskloop:
+    return {D::OMPD_masked, D::OMPD_taskloop};
+  case D::OMPD_masked_taskloop_simd:
+    return {D::OMPD_masked, D::OMPD_taskloop_simd};
+  case D::OMPD_master_taskloop:
+    return {D::OMPD_master, D::OMPD_taskloop};
+  case D::OMPD_master_taskloop_simd:
+    return {D::OMPD_master, D::OMPD_taskloop_simd};
+  case D::OMPD_parallel_do:
+    return {D::OMPD_parallel, D::OMPD_do};
+  case D::OMPD_parallel_do_simd:
+    return {D::OMPD_parallel, D::OMPD_do_simd};
+  case D::OMPD_parallel_masked:
+    return {D::OMPD_parallel, D::OMPD_masked};
+  case D::OMPD_parallel_masked_taskloop:
+    return {D::OMPD_parallel, D::OMPD_masked_taskloop};
+  case D::OMPD_parallel_masked_taskloop_simd:
+    return {D::OMPD_parallel, D::OMPD_masked_taskloop_simd};
+  case D::OMPD_parallel_master:
+    return {D::OMPD_parallel, D::OMPD_master};
+  case D::OMPD_parallel_master_taskloop:
+    return {D::OMPD_parallel, D::OMPD_master_taskloop};
+  case D::OMPD_parallel_master_taskloop_simd:
+    return {D::OMPD_parallel, D::OMPD_master_taskloop_simd};
+  case D::OMPD_parallel_sections:
+    return {D::OMPD_parallel, D::OMPD_sections};
+  case D::OMPD_parallel_workshare:
+    return {D::OMPD_parallel, D::OMPD_workshare};
+  case D::OMPD_target_parallel:
+    return {D::OMPD_target, D::OMPD_parallel};
+  case D::OMPD_target_parallel_do:
+    return {D::OMPD_target, D::OMPD_parallel_do};
+  case D::OMPD_target_parallel_do_simd:
+    return {D::OMPD_target, D::OMPD_parallel_do_simd};
+  case D::OMPD_target_simd:
+    return {D::OMPD_target, D::OMPD_simd};
+  case D::OMPD_target_teams:
+    return {D::OMPD_target, D::OMPD_teams};
+  case D::OMPD_target_teams_distribute:
+    return {D::OMPD_target, D::OMPD_teams_distribute};
+  case D::OMPD_target_teams_distribute_parallel_do:
+    return {D::OMPD_target, D::OMPD_teams_distribute_parallel_do};
+  case D::OMPD_target_teams_distribute_parallel_do_simd:
+    return {D::OMPD_target, D::OMPD_teams_distribute_parallel_do_simd};
+  case D::OMPD_target_teams_distribute_simd:
+    return {D::OMPD_target, D::OMPD_teams_distribute_simd};
+  case D::OMPD_teams_distribute:
+    return {D::OMPD_teams, D::OMPD_distribute};
+  case D::OMPD_teams_distribute_parallel_do:
+    return {D::OMPD_teams, D::OMPD_distribute_parallel_do};
+  case D::OMPD_teams_distribute_parallel_do_simd:
+    return {D::OMPD_teams, D::OMPD_distribute_parallel_do_simd};
+  case D::OMPD_teams_distribute_simd:
+    return {D::OMPD_teams, D::OMPD_distribute_simd};
+  case D::OMPD_parallel_loop:
+    return {D::OMPD_parallel, D::OMPD_loop};
+  case D::OMPD_target_parallel_loop:
+    return {D::OMPD_target, D::OMPD_parallel_loop};
+  case D::OMPD_target_teams_loop:
+    return {D::OMPD_target, D::OMPD_teams_loop};
+  case D::OMPD_teams_loop:
+    return {D::OMPD_teams, D::OMPD_loop};
+  default:
+    return {dir, std::nullopt};
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Op body generation helper structures and functions
 //===----------------------------------------------------------------------===//
@@ -1676,7 +1751,6 @@ static OpTy genTargetEnterExitUpdateDataOp(
   } else {
     llvm_unreachable("Unexpected TARGET DATA construct");
   }
-
   mlir::omp::TargetEnterExitUpdateDataClauseOps clauseOps;
   genTargetEnterExitUpdateDataClauses(converter, semaCtx, stmtCtx, clauseList,
                                       loc, directive, clauseOps);
@@ -1804,16 +1878,44 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter,
 // Code generation functions for composite constructs
 //===----------------------------------------------------------------------===//
 
-static void genCompositeDoSimd(
+static void genCompositeDistributeParallelDo(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::semantics::SemanticsContext &semaCtx,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OmpClauseList &beginClauseList,
+    const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
+  TODO(loc, "Composite DISTRIBUTE PARALLEL DO");
+}
+
+static void genCompositeDistributeParallelDoSimd(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::semantics::SemanticsContext &semaCtx,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OmpClauseList &beginClauseList,
+    const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
+  TODO(loc, "Composite DISTRIBUTE PARALLEL DO SIMD");
+}
+
+static void genCompositeDistributeSimd(
     Fortran::lower::AbstractConverter &converter,
     Fortran::semantics::SemanticsContext &semaCtx,
-    Fortran::lower::pft::Evaluation &eval, llvm::omp::Directive ompDirective,
+    Fortran::lower::pft::Evaluation &eval,
     const Fortran::parser::OmpClauseList &beginClauseList,
     const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
+  TODO(loc, "Composite DISTRIBUTE SIMD");
+}
+
+static void
+genCompositeDoSimd(Fortran::lower::AbstractConverter &converter,
+                   Fortran::semantics::SemanticsContext &semaCtx,
+                   Fortran::lower::pft::Evaluation &eval,
+                   const Fortran::parser::OmpClauseList &beginClauseList,
+                   const Fortran::parser::OmpClauseList *endClauseList,
+                   mlir::Location loc) {
   ClauseProcessor cp(converter, semaCtx, beginClauseList);
   cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear,
-                 clause::Order, clause::Safelen, clause::Simdlen>(loc,
-                                                                  ompDirective);
+                 clause::Order, clause::Safelen, clause::Simdlen>(
+      loc, llvm::omp::OMPD_do_simd);
   // TODO: Add support for vectorization - add vectorization hints inside loop
   // body.
   // OpenMP standard does not specify the length of vector instructions.
@@ -1825,6 +1927,16 @@ static void genCompositeDoSimd(
   genWsloopOp(converter, semaCtx, eval, loc, beginClauseList, endClauseList);
 }
 
+static void
+genCompositeTaskloopSimd(Fortran::lower::AbstractConverter &converter,
+                         Fortran::semantics::SemanticsContext &semaCtx,
+                         Fortran::lower::pft::Evaluation &eval,
+                         const Fortran::parser::OmpClauseList &beginClauseList,
+                         const Fortran::parser::OmpClauseList *endClauseList,
+                         mlir::Location loc) {
+  TODO(loc, "Composite TASKLOOP SIMD");
+}
+
 //===----------------------------------------------------------------------===//
 // OpenMPDeclarativeConstruct visitors
 //===----------------------------------------------------------------------===//
@@ -2082,13 +2194,18 @@ genOMP(Fortran::lower::AbstractConverter &converter,
       std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
   const auto &endBlockDirective =
       std::get<Fortran::parser::OmpEndBlockDirective>(blockConstruct.t);
-  const auto &directive =
-      std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
+  mlir::Location currentLocation =
+      converter.genLocation(beginBlockDirective.source);
+  const auto origDirective =
+      std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t).v;
   const auto &beginClauseList =
       std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
   const auto &endClauseList =
       std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t);
 
+  assert(llvm::omp::blockConstructSet.test(origDirective) &&
+         "Expected block construct");
+
   for (const Fortran::parser::OmpClause &clause : beginClauseList.v) {
     mlir::Location clauseLocation = converter.genLocation(clause.source);
     if (!std::get_if<Fortran::parser::OmpClause::If>(&clause.u) &&
@@ -2124,93 +2241,74 @@ genOMP(Fortran::lower::AbstractConverter &converter,
       TODO(clauseLocation, "OpenMP Block construct clause");
   }
 
-  bool singleDirective = true;
-  mlir::Location currentLocation = converter.genLocation(directive.source);
-  switch (directive.v) {
-  case llvm::omp::Directive::OMPD_master:
-    genMasterOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation);
-    break;
-  case llvm::omp::Directive::OMPD_ordered:
-    genOrderedRegionOp(converter, semaCtx, eval, /*genNested=*/true,
-                       currentLocation, beginClauseList);
-    break;
-  case llvm::omp::Directive::OMPD_parallel:
-    genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/true,
-                  currentLocation, beginClauseList);
-    break;
-  case llvm::omp::Directive::OMPD_single:
-    genSingleOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
-                beginClauseList, endClauseList);
-    break;
-  case llvm::omp::Directive::OMPD_target:
-    genTargetOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
+  std::optional<llvm::omp::Directive> nextDir = origDirective;
+  bool outermostLeafConstruct = true;
+  while (nextDir) {
+    llvm::omp::Directive leafDir;
+    std::tie(leafDir, nextDir) = splitCombinedDirective(*nextDir);
+    const bool genNested = !nextDir;
+    const bool outerCombined = outermostLeafConstruct && nextDir.has_value();
+    switch (leafDir) {
+    case llvm::omp::Directive::OMPD_master:
+      // 2.16 MASTER construct.
+      genMasterOp(converter, semaCtx, eval, genNested, currentLocation);
+      break;
+    case llvm::omp::Directive::OMPD_ordered:
+      // 2.17.9 ORDERED construct.
+      genOrderedRegionOp(converter, semaCtx, eval, genNested, currentLocation,
+                         beginClauseList);
+      break;
+    case llvm::omp::Directive::OMPD_parallel:
+      // 2.6 PARALLEL construct.
+      genParallelOp(converter, symTable, semaCtx, eval, genNested,
+                    currentLocation, beginClauseList, outerCombined);
+      break;
+    case llvm::omp::Directive::OMPD_single:
+      // 2.8.2 SINGLE construct.
+      genSingleOp(converter, semaCtx, eval, genNested, currentLocation,
+                  beginClauseList, endClauseList);
+      break;
+    case llvm::omp::Directive::OMPD_target:
+      // 2.12.5 TARGET construct.
+      genTargetOp(converter, semaCtx, eval, genNested, currentLocation,
+                  beginClauseList, outerCombined);
+      break;
+    case llvm::omp::Directive::OMPD_target_data:
+      // 2.12.2 TARGET DATA construct.
+      genTargetDataOp(converter, semaCtx, eval, genNested, currentLocation,
+                      beginClauseList);
+      break;
+    case llvm::omp::Directive::OMPD_task:
+      // 2.10.1 TASK construct.
+      genTaskOp(converter, semaCtx, eval, genNested, currentLocation,
                 beginClauseList);
-    break;
-  case llvm::omp::Directive::OMPD_target_data:
-    genTargetDataOp(converter, semaCtx, eval, /*genNested=*/true,
-                    currentLocation, beginClauseList);
-    break;
-  case llvm::omp::Directive::OMPD_task:
-    genTaskOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
-              beginClauseList);
-    break;
-  case llvm::omp::Directive::OMPD_taskgroup:
-    genTaskgroupOp(converter, semaCtx, eval, /*genNested=*/true,
-                   currentLocation, beginClauseList);
-    break;
-  case llvm::omp::Directive::OMPD_teams:
-    genTeamsOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
-               beginClauseList);
-    break;
-  case llvm::omp::Directive::OMPD_workshare:
-    // FIXME: Workshare is not a commonly used OpenMP construct, an
-    // implementation for this feature will come later. For the codes
-    // that use this construct, add a single construct for now.
-    genSingleOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
-                beginClauseList, endClauseList);
-    break;
-  default:
-    singleDirective = false;
-    break;
-  }
-
-  if (singleDirective)
-    return;
-
-  // Codegen for combined directives
-  bool combinedDirective = false;
-  if ((llvm::omp::allTargetSet & llvm::omp::blockConstructSet)
-          .test(directive.v)) {
-    genTargetOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation,
-                beginClauseList, /*outerCombined=*/true);
-    combinedDirective = true;
-  }
-  if ((llvm::omp::allTeamsSet & llvm::omp::blockConstructSet)
-          .test(directive.v)) {
-    genTeamsOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation,
-               beginClauseList);
-    combinedDirective = true;
-  }
-  if ((llvm::omp::allParallelSet & llvm::omp::blockConstructSet)
-          .test(directive.v)) {
-    bool outerCombined =
-        directive.v != llvm::omp::Directive::OMPD_target_parallel;
-    genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/false,
-                  currentLocation, beginClauseList, outerCombined);
-    combinedDirective = true;
-  }
-  if ((llvm::omp::workShareSet & llvm::omp::blockConstructSet)
-          .test(directive.v)) {
-    genSingleOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation,
-                beginClauseList, endClauseList);
-    combinedDirective = true;
+      break;
+    case llvm::omp::Directive::OMPD_taskgroup:
+      // 2.17.6 TASKGROUP construct.
+      genTaskgroupOp(converter, semaCtx, eval, genNested, currentLocation,
+                     beginClauseList);
+      break;
+    case llvm::omp::Directive::OMPD_teams:
+      // 2.7 TEAMS construct.
+      // FIXME Pass the outerCombined argument or rename it to better describe
+      // what it represents if it must always be `false` in this context.
+      genTeamsOp(converter, semaCtx, eval, genNested, currentLocation,
+                 beginClauseList);
+      break;
+    case llvm::omp::Directive::OMPD_workshare:
+      // 2.8.3 WORKSHARE construct.
+      // FIXME: Workshare is not a commonly used OpenMP construct, an
+      // implementation for this feature will come later. For the codes
+      // that use this construct, add a single construct for now.
+      genSingleOp(converter, semaCtx, eval, genNested, currentLocation,
+                  beginClauseList, endClauseList);
+      break;
+    default:
+      llvm_unreachable("Unexpected block construct");
+      break;
+    }
+    outermostLeafConstruct = false;
   }
-  if (!combinedDirective)
-    TODO(currentLocation, "Unhandled block directive (" +
-                              llvm::omp::getOpenMPDirectiveName(directive.v) +
-                              ")");
-
-  genNestedEvaluations(converter, eval);
 }
 
 static void
@@ -2248,9 +2346,12 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
       std::get<Fortran::parser::OmpClauseList>(beginLoopDirective.t);
   mlir::Location currentLocation =
       converter.genLocation(beginLoopDirective.source);
-  const auto ompDirective =
+  const auto origDirective =
       std::get<Fortran::parser::OmpLoopDirective>(beginLoopDirective.t).v;
 
+  assert(llvm::omp::loopConstructSet.test(origDirective) &&
+         "Expected loop construct");
+
   const auto *endClauseList = [&]() {
     using RetTy = const Fortran::parser::OmpClauseList *;
     if (auto &endLoopDirective =
@@ -2262,56 +2363,103 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
     return RetTy();
   }();
 
-  bool validDirective = false;
-  if (llvm::omp::topTaskloopSet.test(ompDirective)) {
-    validDirective = true;
-    genTaskloopOp(converter, semaCtx, eval, currentLocation, beginClauseList);
-  } else {
-    // Create omp.{target, teams, distribute, parallel} nested operations
-    if ((llvm::omp::allTargetSet & llvm::omp::loopConstructSet)
-            .test(ompDirective)) {
-      validDirective = true;
-      genTargetOp(converter, semaCtx, eval, /*genNested=*/false,
-                  currentLocation, beginClauseList, /*outerCombined=*/true);
-    }
-    if ((llvm::omp::allTeamsSet & llvm::omp::loopConstructSet)
-            .test(ompDirective)) {
-      validDirective = true;
-      genTeamsOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation,
-                 beginClauseList, /*outerCombined=*/true);
-    }
-    if (llvm::omp::allDistributeSet.test(ompDirective)) {
-      validDirective = true;
-      genDistributeOp(converter, semaCtx, eval, /*genNested=*/false,
-                      currentLocation, beginClauseList);
-    }
-    if ((llvm::omp::allParallelSet & llvm::omp::loopConstructSet)
-            .test(ompDirective)) {
-      validDirective = true;
-      genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/false,
-                    currentLocation, beginClauseList, /*outerCombined=*/true);
+  std::optional<llvm::omp::Directive> nextDir = origDirective;
+  while (nextDir) {
+    llvm::omp::Directive leafDir;
+    std::tie(leafDir, nextDir) = splitCombinedDirective(*nextDir);
+    if (llvm::omp::compositeConstructSet.test(leafDir)) {
+      assert(!nextDir && "Composite construct cannot be split");
+      switch (leafDir) {
+      case llvm::omp::Directive::OMPD_distribute_parallel_do:
+        // 2.9.4.3 DISTRIBUTE PARALLEL Worksharing-Loop construct.
+        genCompositeDistributeParallelDo(converter, semaCtx, eval,
+                                         beginClauseList, endClauseList,
+                                         currentLocation);
+        break;
+      case llvm::omp::Directive::OMPD_distribute_parallel_do_simd:
+        // 2.9.4.4 DISTRIBUTE PARALLEL Worksharing-Loop SIMD construct.
+        genCompositeDistributeParallelDoSimd(converter, semaCtx, eval,
+                                             beginClauseList, endClauseList,
+                                             currentLocation);
+        break;
+      case llvm::omp::Directive::OMPD_distribute_simd:
+        // 2.9.4.2 DISTRIBUTE SIMD construct.
+        genCompositeDistributeSimd(converter, semaCtx, eval, beginClauseList,
+                                   endClauseList, currentLocation);
+        break;
+      case llvm::omp::Directive::OMPD_do_simd:
+        // 2.9.3.2 Worksharing-Loop SIMD construct.
+        genCompositeDoSimd(converter, semaCtx, eval, beginClauseList,
+                           endClauseList, currentLocation);
+        break;
+      case llvm::omp::Directive::OMPD_taskloop_simd:
+        // 2.10.3 TASKLOOP SIMD construct.
+        genCompositeTaskloopSimd(converter, semaCtx, eval, beginClauseList,
+                                 endClauseList, currentLocation);
+        break;
+      default:
+        llvm_unreachable("Unexpected composite construct");
+      }
+    } else {
+      const bool genNested = !nextDir;
+      switch (leafDir) {
+      case llvm::omp::Directive::OMPD_distribute:
+        // 2.9.4.1 DISTRIBUTE construct.
+        genDistributeOp(converter, semaCtx, eval, genNested, currentLocation,
+                        beginClauseList);
+        break;
+      case llvm::omp::Directive::OMPD_do:
+        // 2.9.2 Worksharing-Loop construct.
+        genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauseList,
+                    endClauseList);
+        break;
+      case llvm::omp::Directive::OMPD_parallel:
+        // 2.6 PARALLEL construct.
+        // FIXME This is not necessarily always the outer leaf construct of a
+        // combined construct in this context (e.g. DISTRIBUTE PARALLEL DO).
+        // Maybe rename the argument if it represents something else or
+        // initialize it properly.
+        genParallelOp(converter, symTable, semaCtx, eval, genNested,
+                      currentLocation, beginClauseList,
+                      /*outerCombined=*/true);
+        break;
+      case llvm::omp::Directive::OMPD_simd:
+        // 2.9.3.1 SIMD construct.
+        genSimdOp(converter, semaCtx, eval, currentLocation, beginClauseList);
+        break;
+      case llvm::omp::Directive::OMPD_target:
+        // 2.12.5 TARGET construct.
+        genTargetOp(converter, semaCtx, eval, genNested, currentLocation,
+                    beginClauseList, /*outerCombined=*/true);
+        break;
+      case llvm::omp::Directive::OMPD_taskloop:
+        // 2.10.2 TASKLOOP construct.
+        genTaskloopOp(converter, semaCtx, eval, currentLocation,
+                      beginClauseList);
+        break;
+      case llvm::omp::Directive::OMPD_teams:
+        // 2.7 TEAMS construct.
+        // FIXME This is not necessarily always the outer leaf construct of a
+        // combined construct in this context (e.g. TARGET TEAMS DISTRIBUTE).
+        // Maybe rename the argument if it represents something else or
+        // initialize it properly.
+        genTeamsOp(converter, semaCtx, eval, genNested, currentLocation,
+                   beginClauseList, /*outerCombined=*/true);
+        break;
+      case llvm::omp::Directive::OMPD_loop:
+      case llvm::omp::Directive::OMPD_masked:
+      case llvm::omp::Directive::OMPD_master:
+      case llvm::omp::Directive::OMPD_tile:
+      case llvm::omp::Directive::OMPD_unroll:
+        TODO(currentLocation, "Unhandled loop directive (" +
+                                  llvm::omp::getOpenMPDirectiveName(leafDir) +
+                                  ")");
+        break;
+      default:
+        llvm_unreachable("Unexpected loop construct");
+      }
     }
   }
-  if ((llvm::omp::allDoSet | llvm::omp::allSimdSet).test(ompDirective))
-    validDirective = true;
-
-  if (!validDirective) {
-    TODO(currentLocation, "Unhandled loop directive (" +
-                              llvm::omp::getOpenMPDirectiveName(ompDirective) +
-                              ")");
-  }
-
-  if (llvm::omp::allDoSimdSet.test(ompDirective)) {
-    // 2.9.3.2 Workshare SIMD construct
-    genCompositeDoSimd(converter, semaCtx, eval, ompDirective, beginClauseList,
-                       endClauseList, currentLocation);
-  } else if (llvm::omp::allSimdSet.test(ompDirective)) {
-    // 2.9.3.1 SIMD construct
-    genSimdOp(converter, semaCtx, eval, currentLocation, beginClauseList);
-  } else {
-    genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauseList,
-                endClauseList);
-  }
 }
 
 static void


        


More information about the flang-commits mailing list