[llvm-branch-commits] [flang] [Flang][OpenMP][Lower] Refactor lowering of compound constructs (PR #87070)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Mar 29 07:09:41 PDT 2024


https://github.com/skatrak created https://github.com/llvm/llvm-project/pull/87070

This patch simplifies the lowering from PFT to MLIR of OpenMP compound constructs (i.e. combined and composite).

The new approach consists of iteratively processing the outermost leaf construct of the given combined construct until it cannot be split further. Both leaf constructs and composite ones have `gen...()` functions that are called when appropriate.

This approach enables treating a leaf construct the same way regardless of whether it appeared as part of a combined construct, and it also enables the lowering of composite constructs as a single unit.

Previous corner cases are now handled in a more straightforward way, and comments pointing to the relevant spec sections have been added. Directive sets are also completed with the missing LOOP-related constructs.

>From ec0ed50b0d5f9606f0e9a1a3a9999f601bec310f Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 29 Mar 2024 13:57:40 +0000
Subject: [PATCH] [Flang][OpenMP][Lower] Refactor lowering of compound
 constructs

This patch simplifies the lowering from PFT to MLIR of OpenMP compound
constructs (i.e. combined and composite).

The new approach consists of iteratively processing the outermost leaf
construct of the given combined construct until it cannot be split further.
Both leaf constructs and composite ones have `gen...()` functions that are
called when appropriate.

This approach enables treating a leaf construct the same way regardless of whether
it appeared as part of a combined construct, and it also enables the lowering
of composite constructs as a single unit.

Previous corner cases are now handled in a more straightforward way, and
comments pointing to the relevant spec sections have been added. Directive sets
are also completed with the missing LOOP-related constructs.
---
 .../flang/Semantics/openmp-directive-sets.h   |  57 ++-
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 432 ++++++++++++------
 2 files changed, 335 insertions(+), 154 deletions(-)

diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h
index 91773ae3ea9a3e..842d251b682aa9 100644
--- a/flang/include/flang/Semantics/openmp-directive-sets.h
+++ b/flang/include/flang/Semantics/openmp-directive-sets.h
@@ -32,14 +32,14 @@ static const OmpDirectiveSet topDistributeSet{
 
 static const OmpDirectiveSet allDistributeSet{
     OmpDirectiveSet{
-        llvm::omp::OMPD_target_teams_distribute,
-        llvm::omp::OMPD_target_teams_distribute_parallel_do,
-        llvm::omp::OMPD_target_teams_distribute_parallel_do_simd,
-        llvm::omp::OMPD_target_teams_distribute_simd,
-        llvm::omp::OMPD_teams_distribute,
-        llvm::omp::OMPD_teams_distribute_parallel_do,
-        llvm::omp::OMPD_teams_distribute_parallel_do_simd,
-        llvm::omp::OMPD_teams_distribute_simd,
+        Directive::OMPD_target_teams_distribute,
+        Directive::OMPD_target_teams_distribute_parallel_do,
+        Directive::OMPD_target_teams_distribute_parallel_do_simd,
+        Directive::OMPD_target_teams_distribute_simd,
+        Directive::OMPD_teams_distribute,
+        Directive::OMPD_teams_distribute_parallel_do,
+        Directive::OMPD_teams_distribute_parallel_do_simd,
+        Directive::OMPD_teams_distribute_simd,
     } | topDistributeSet,
 };
 
@@ -63,10 +63,24 @@ static const OmpDirectiveSet allDoSet{
     } | topDoSet,
 };
 
+static const OmpDirectiveSet topLoopSet{
+    Directive::OMPD_loop,
+};
+
+static const OmpDirectiveSet allLoopSet{
+    OmpDirectiveSet{
+        Directive::OMPD_parallel_loop,
+        Directive::OMPD_target_parallel_loop,
+        Directive::OMPD_target_teams_loop,
+        Directive::OMPD_teams_loop,
+    } | topLoopSet,
+};
+
 static const OmpDirectiveSet topParallelSet{
     Directive::OMPD_parallel,
     Directive::OMPD_parallel_do,
     Directive::OMPD_parallel_do_simd,
+    Directive::OMPD_parallel_loop,
     Directive::OMPD_parallel_masked_taskloop,
     Directive::OMPD_parallel_masked_taskloop_simd,
     Directive::OMPD_parallel_master_taskloop,
@@ -82,6 +96,7 @@ static const OmpDirectiveSet allParallelSet{
         Directive::OMPD_target_parallel,
         Directive::OMPD_target_parallel_do,
         Directive::OMPD_target_parallel_do_simd,
+        Directive::OMPD_target_parallel_loop,
         Directive::OMPD_target_teams_distribute_parallel_do,
         Directive::OMPD_target_teams_distribute_parallel_do_simd,
         Directive::OMPD_teams_distribute_parallel_do,
@@ -118,12 +133,14 @@ static const OmpDirectiveSet topTargetSet{
     Directive::OMPD_target_parallel,
     Directive::OMPD_target_parallel_do,
     Directive::OMPD_target_parallel_do_simd,
+    Directive::OMPD_target_parallel_loop,
     Directive::OMPD_target_simd,
     Directive::OMPD_target_teams,
     Directive::OMPD_target_teams_distribute,
     Directive::OMPD_target_teams_distribute_parallel_do,
     Directive::OMPD_target_teams_distribute_parallel_do_simd,
     Directive::OMPD_target_teams_distribute_simd,
+    Directive::OMPD_target_teams_loop,
 };
 
 static const OmpDirectiveSet allTargetSet{topTargetSet};
@@ -156,11 +173,12 @@ static const OmpDirectiveSet topTeamsSet{
 
 static const OmpDirectiveSet allTeamsSet{
     OmpDirectiveSet{
-        llvm::omp::OMPD_target_teams,
-        llvm::omp::OMPD_target_teams_distribute,
-        llvm::omp::OMPD_target_teams_distribute_parallel_do,
-        llvm::omp::OMPD_target_teams_distribute_parallel_do_simd,
-        llvm::omp::OMPD_target_teams_distribute_simd,
+        Directive::OMPD_target_teams,
+        Directive::OMPD_target_teams_distribute,
+        Directive::OMPD_target_teams_distribute_parallel_do,
+        Directive::OMPD_target_teams_distribute_parallel_do_simd,
+        Directive::OMPD_target_teams_distribute_simd,
+        Directive::OMPD_target_teams_loop,
     } | topTeamsSet,
 };
 
@@ -178,6 +196,14 @@ static const OmpDirectiveSet allDistributeSimdSet{
 static const OmpDirectiveSet allDoSimdSet{allDoSet & allSimdSet};
 static const OmpDirectiveSet allTaskloopSimdSet{allTaskloopSet & allSimdSet};
 
+static const OmpDirectiveSet compositeConstructSet{
+    Directive::OMPD_distribute_parallel_do,
+    Directive::OMPD_distribute_parallel_do_simd,
+    Directive::OMPD_distribute_simd,
+    Directive::OMPD_do_simd,
+    Directive::OMPD_taskloop_simd,
+};
+
 static const OmpDirectiveSet blockConstructSet{
     Directive::OMPD_master,
     Directive::OMPD_ordered,
@@ -201,12 +227,14 @@ static const OmpDirectiveSet loopConstructSet{
     Directive::OMPD_distribute_simd,
     Directive::OMPD_do,
     Directive::OMPD_do_simd,
+    Directive::OMPD_loop,
     Directive::OMPD_masked_taskloop,
     Directive::OMPD_masked_taskloop_simd,
     Directive::OMPD_master_taskloop,
     Directive::OMPD_master_taskloop_simd,
     Directive::OMPD_parallel_do,
     Directive::OMPD_parallel_do_simd,
+    Directive::OMPD_parallel_loop,
     Directive::OMPD_parallel_masked_taskloop,
     Directive::OMPD_parallel_masked_taskloop_simd,
     Directive::OMPD_parallel_master_taskloop,
@@ -214,17 +242,20 @@ static const OmpDirectiveSet loopConstructSet{
     Directive::OMPD_simd,
     Directive::OMPD_target_parallel_do,
     Directive::OMPD_target_parallel_do_simd,
+    Directive::OMPD_target_parallel_loop,
     Directive::OMPD_target_simd,
     Directive::OMPD_target_teams_distribute,
     Directive::OMPD_target_teams_distribute_parallel_do,
     Directive::OMPD_target_teams_distribute_parallel_do_simd,
     Directive::OMPD_target_teams_distribute_simd,
+    Directive::OMPD_target_teams_loop,
     Directive::OMPD_taskloop,
     Directive::OMPD_taskloop_simd,
     Directive::OMPD_teams_distribute,
     Directive::OMPD_teams_distribute_parallel_do,
     Directive::OMPD_teams_distribute_parallel_do_simd,
     Directive::OMPD_teams_distribute_simd,
+    Directive::OMPD_teams_loop,
     Directive::OMPD_tile,
     Directive::OMPD_unroll,
 };
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 692d81f9188be3..edae453972d3d9 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -710,6 +710,81 @@ genOpenMPReduction(Fortran::lower::AbstractConverter &converter,
   }
 }
 
+/// Split a combined directive into an outer leaf directive and the (possibly
+/// combined) rest of the combined directive. Composite directives and
+/// non-compound directives are not split, in which case it will return the
+/// input directive as its first output and an empty value as its second output.
+static std::pair<llvm::omp::Directive, std::optional<llvm::omp::Directive>>
+splitCombinedDirective(llvm::omp::Directive dir) {
+  using D = llvm::omp::Directive;
+  switch (dir) {
+  case D::OMPD_masked_taskloop:
+    return {D::OMPD_masked, D::OMPD_taskloop};
+  case D::OMPD_masked_taskloop_simd:
+    return {D::OMPD_masked, D::OMPD_taskloop_simd};
+  case D::OMPD_master_taskloop:
+    return {D::OMPD_master, D::OMPD_taskloop};
+  case D::OMPD_master_taskloop_simd:
+    return {D::OMPD_master, D::OMPD_taskloop_simd};
+  case D::OMPD_parallel_do:
+    return {D::OMPD_parallel, D::OMPD_do};
+  case D::OMPD_parallel_do_simd:
+    return {D::OMPD_parallel, D::OMPD_do_simd};
+  case D::OMPD_parallel_masked:
+    return {D::OMPD_parallel, D::OMPD_masked};
+  case D::OMPD_parallel_masked_taskloop:
+    return {D::OMPD_parallel, D::OMPD_masked_taskloop};
+  case D::OMPD_parallel_masked_taskloop_simd:
+    return {D::OMPD_parallel, D::OMPD_masked_taskloop_simd};
+  case D::OMPD_parallel_master:
+    return {D::OMPD_parallel, D::OMPD_master};
+  case D::OMPD_parallel_master_taskloop:
+    return {D::OMPD_parallel, D::OMPD_master_taskloop};
+  case D::OMPD_parallel_master_taskloop_simd:
+    return {D::OMPD_parallel, D::OMPD_master_taskloop_simd};
+  case D::OMPD_parallel_sections:
+    return {D::OMPD_parallel, D::OMPD_sections};
+  case D::OMPD_parallel_workshare:
+    return {D::OMPD_parallel, D::OMPD_workshare};
+  case D::OMPD_target_parallel:
+    return {D::OMPD_target, D::OMPD_parallel};
+  case D::OMPD_target_parallel_do:
+    return {D::OMPD_target, D::OMPD_parallel_do};
+  case D::OMPD_target_parallel_do_simd:
+    return {D::OMPD_target, D::OMPD_parallel_do_simd};
+  case D::OMPD_target_simd:
+    return {D::OMPD_target, D::OMPD_simd};
+  case D::OMPD_target_teams:
+    return {D::OMPD_target, D::OMPD_teams};
+  case D::OMPD_target_teams_distribute:
+    return {D::OMPD_target, D::OMPD_teams_distribute};
+  case D::OMPD_target_teams_distribute_parallel_do:
+    return {D::OMPD_target, D::OMPD_teams_distribute_parallel_do};
+  case D::OMPD_target_teams_distribute_parallel_do_simd:
+    return {D::OMPD_target, D::OMPD_teams_distribute_parallel_do_simd};
+  case D::OMPD_target_teams_distribute_simd:
+    return {D::OMPD_target, D::OMPD_teams_distribute_simd};
+  case D::OMPD_teams_distribute:
+    return {D::OMPD_teams, D::OMPD_distribute};
+  case D::OMPD_teams_distribute_parallel_do:
+    return {D::OMPD_teams, D::OMPD_distribute_parallel_do};
+  case D::OMPD_teams_distribute_parallel_do_simd:
+    return {D::OMPD_teams, D::OMPD_distribute_parallel_do_simd};
+  case D::OMPD_teams_distribute_simd:
+    return {D::OMPD_teams, D::OMPD_distribute_simd};
+  case D::OMPD_parallel_loop:
+    return {D::OMPD_parallel, D::OMPD_loop};
+  case D::OMPD_target_parallel_loop:
+    return {D::OMPD_target, D::OMPD_parallel_loop};
+  case D::OMPD_target_teams_loop:
+    return {D::OMPD_target, D::OMPD_teams_loop};
+  case D::OMPD_teams_loop:
+    return {D::OMPD_teams, D::OMPD_loop};
+  default:
+    return {dir, std::nullopt};
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Op body generation helper structures and functions
 //===----------------------------------------------------------------------===//
@@ -1962,16 +2037,44 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter,
 // Code generation functions for composite constructs
 //===----------------------------------------------------------------------===//
 
-static void genCompositeDoSimd(
+static void genCompositeDistributeParallelDo(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::semantics::SemanticsContext &semaCtx,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OmpClauseList &beginClauseList,
+    const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
+  TODO(loc, "Composite DISTRIBUTE PARALLEL DO");
+}
+
+static void genCompositeDistributeParallelDoSimd(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::semantics::SemanticsContext &semaCtx,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OmpClauseList &beginClauseList,
+    const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
+  TODO(loc, "Composite DISTRIBUTE PARALLEL DO SIMD");
+}
+
+static void genCompositeDistributeSimd(
     Fortran::lower::AbstractConverter &converter,
     Fortran::semantics::SemanticsContext &semaCtx,
-    Fortran::lower::pft::Evaluation &eval, llvm::omp::Directive ompDirective,
+    Fortran::lower::pft::Evaluation &eval,
     const Fortran::parser::OmpClauseList &beginClauseList,
     const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
+  TODO(loc, "Composite DISTRIBUTE SIMD");
+}
+
+static void
+genCompositeDoSimd(Fortran::lower::AbstractConverter &converter,
+                   Fortran::semantics::SemanticsContext &semaCtx,
+                   Fortran::lower::pft::Evaluation &eval,
+                   const Fortran::parser::OmpClauseList &beginClauseList,
+                   const Fortran::parser::OmpClauseList *endClauseList,
+                   mlir::Location loc) {
   ClauseProcessor cp(converter, semaCtx, beginClauseList);
   cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear,
-                 clause::Order, clause::Safelen, clause::Simdlen>(loc,
-                                                                  ompDirective);
+                 clause::Order, clause::Safelen, clause::Simdlen>(
+      loc, llvm::omp::OMPD_do_simd);
   // TODO: Add support for vectorization - add vectorization hints inside loop
   // body.
   // OpenMP standard does not specify the length of vector instructions.
@@ -1983,6 +2086,16 @@ static void genCompositeDoSimd(
   genWsloopOp(converter, semaCtx, eval, loc, beginClauseList, endClauseList);
 }
 
+static void
+genCompositeTaskloopSimd(Fortran::lower::AbstractConverter &converter,
+                         Fortran::semantics::SemanticsContext &semaCtx,
+                         Fortran::lower::pft::Evaluation &eval,
+                         const Fortran::parser::OmpClauseList &beginClauseList,
+                         const Fortran::parser::OmpClauseList *endClauseList,
+                         mlir::Location loc) {
+  TODO(loc, "Composite TASKLOOP SIMD");
+}
+
 //===----------------------------------------------------------------------===//
 // OpenMPDeclarativeConstruct visitors
 //===----------------------------------------------------------------------===//
@@ -2240,13 +2353,18 @@ genOMP(Fortran::lower::AbstractConverter &converter,
       std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
   const auto &endBlockDirective =
       std::get<Fortran::parser::OmpEndBlockDirective>(blockConstruct.t);
-  const auto &directive =
-      std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
+  mlir::Location currentLocation =
+      converter.genLocation(beginBlockDirective.source);
+  const auto origDirective =
+      std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t).v;
   const auto &beginClauseList =
       std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
   const auto &endClauseList =
       std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t);
 
+  assert(llvm::omp::blockConstructSet.test(origDirective) &&
+         "Expected block construct");
+
   for (const Fortran::parser::OmpClause &clause : beginClauseList.v) {
     mlir::Location clauseLocation = converter.genLocation(clause.source);
     if (!std::get_if<Fortran::parser::OmpClause::If>(&clause.u) &&
@@ -2280,93 +2398,74 @@ genOMP(Fortran::lower::AbstractConverter &converter,
       TODO(clauseLocation, "OpenMP Block construct clause");
   }
 
-  bool singleDirective = true;
-  mlir::Location currentLocation = converter.genLocation(directive.source);
-  switch (directive.v) {
-  case llvm::omp::Directive::OMPD_master:
-    genMasterOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation);
-    break;
-  case llvm::omp::Directive::OMPD_ordered:
-    genOrderedRegionOp(converter, semaCtx, eval, /*genNested=*/true,
-                       currentLocation, beginClauseList);
-    break;
-  case llvm::omp::Directive::OMPD_parallel:
-    genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/true,
-                  currentLocation, beginClauseList);
-    break;
-  case llvm::omp::Directive::OMPD_single:
-    genSingleOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
-                beginClauseList, endClauseList);
-    break;
-  case llvm::omp::Directive::OMPD_target:
-    genTargetOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
+  std::optional<llvm::omp::Directive> nextDir = origDirective;
+  bool outermostLeafConstruct = true;
+  while (nextDir) {
+    llvm::omp::Directive leafDir;
+    std::tie(leafDir, nextDir) = splitCombinedDirective(*nextDir);
+    const bool genNested = !nextDir;
+    const bool outerCombined = outermostLeafConstruct && nextDir.has_value();
+    switch (leafDir) {
+    case llvm::omp::Directive::OMPD_master:
+      // 2.16 MASTER construct.
+      genMasterOp(converter, semaCtx, eval, genNested, currentLocation);
+      break;
+    case llvm::omp::Directive::OMPD_ordered:
+      // 2.17.9 ORDERED construct.
+      genOrderedRegionOp(converter, semaCtx, eval, genNested, currentLocation,
+                         beginClauseList);
+      break;
+    case llvm::omp::Directive::OMPD_parallel:
+      // 2.6 PARALLEL construct.
+      genParallelOp(converter, symTable, semaCtx, eval, genNested,
+                    currentLocation, beginClauseList, outerCombined);
+      break;
+    case llvm::omp::Directive::OMPD_single:
+      // 2.8.2 SINGLE construct.
+      genSingleOp(converter, semaCtx, eval, genNested, currentLocation,
+                  beginClauseList, endClauseList);
+      break;
+    case llvm::omp::Directive::OMPD_target:
+      // 2.12.5 TARGET construct.
+      genTargetOp(converter, semaCtx, eval, genNested, currentLocation,
+                  beginClauseList, outerCombined);
+      break;
+    case llvm::omp::Directive::OMPD_target_data:
+      // 2.12.2 TARGET DATA construct.
+      genTargetDataOp(converter, semaCtx, eval, genNested, currentLocation,
+                      beginClauseList);
+      break;
+    case llvm::omp::Directive::OMPD_task:
+      // 2.10.1 TASK construct.
+      genTaskOp(converter, semaCtx, eval, genNested, currentLocation,
                 beginClauseList);
-    break;
-  case llvm::omp::Directive::OMPD_target_data:
-    genTargetDataOp(converter, semaCtx, eval, /*genNested=*/true,
-                    currentLocation, beginClauseList);
-    break;
-  case llvm::omp::Directive::OMPD_task:
-    genTaskOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
-              beginClauseList);
-    break;
-  case llvm::omp::Directive::OMPD_taskgroup:
-    genTaskgroupOp(converter, semaCtx, eval, /*genNested=*/true,
-                   currentLocation, beginClauseList);
-    break;
-  case llvm::omp::Directive::OMPD_teams:
-    genTeamsOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
-               beginClauseList);
-    break;
-  case llvm::omp::Directive::OMPD_workshare:
-    // FIXME: Workshare is not a commonly used OpenMP construct, an
-    // implementation for this feature will come later. For the codes
-    // that use this construct, add a single construct for now.
-    genSingleOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
-                beginClauseList, endClauseList);
-    break;
-  default:
-    singleDirective = false;
-    break;
-  }
-
-  if (singleDirective)
-    return;
-
-  // Codegen for combined directives
-  bool combinedDirective = false;
-  if ((llvm::omp::allTargetSet & llvm::omp::blockConstructSet)
-          .test(directive.v)) {
-    genTargetOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation,
-                beginClauseList, /*outerCombined=*/true);
-    combinedDirective = true;
-  }
-  if ((llvm::omp::allTeamsSet & llvm::omp::blockConstructSet)
-          .test(directive.v)) {
-    genTeamsOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation,
-               beginClauseList);
-    combinedDirective = true;
-  }
-  if ((llvm::omp::allParallelSet & llvm::omp::blockConstructSet)
-          .test(directive.v)) {
-    bool outerCombined =
-        directive.v != llvm::omp::Directive::OMPD_target_parallel;
-    genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/false,
-                  currentLocation, beginClauseList, outerCombined);
-    combinedDirective = true;
-  }
-  if ((llvm::omp::workShareSet & llvm::omp::blockConstructSet)
-          .test(directive.v)) {
-    genSingleOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation,
-                beginClauseList, endClauseList);
-    combinedDirective = true;
+      break;
+    case llvm::omp::Directive::OMPD_taskgroup:
+      // 2.17.6 TASKGROUP construct.
+      genTaskgroupOp(converter, semaCtx, eval, genNested, currentLocation,
+                     beginClauseList);
+      break;
+    case llvm::omp::Directive::OMPD_teams:
+      // 2.7 TEAMS construct.
+      // FIXME Pass the outerCombined argument or rename it to better describe
+      // what it represents if it must always be `false` in this context.
+      genTeamsOp(converter, semaCtx, eval, genNested, currentLocation,
+                 beginClauseList);
+      break;
+    case llvm::omp::Directive::OMPD_workshare:
+      // 2.8.3 WORKSHARE construct.
+      // FIXME: Workshare is not a commonly used OpenMP construct, an
+      // implementation for this feature will come later. For the codes
+      // that use this construct, add a single construct for now.
+      genSingleOp(converter, semaCtx, eval, genNested, currentLocation,
+                  beginClauseList, endClauseList);
+      break;
+    default:
+      llvm_unreachable("Unexpected block construct");
+      break;
+    }
+    outermostLeafConstruct = false;
   }
-  if (!combinedDirective)
-    TODO(currentLocation, "Unhandled block directive (" +
-                              llvm::omp::getOpenMPDirectiveName(directive.v) +
-                              ")");
-
-  genNestedEvaluations(converter, eval);
 }
 
 static void
@@ -2404,9 +2503,12 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
       std::get<Fortran::parser::OmpClauseList>(beginLoopDirective.t);
   mlir::Location currentLocation =
       converter.genLocation(beginLoopDirective.source);
-  const auto ompDirective =
+  const auto origDirective =
       std::get<Fortran::parser::OmpLoopDirective>(beginLoopDirective.t).v;
 
+  assert(llvm::omp::loopConstructSet.test(origDirective) &&
+         "Expected loop construct");
+
   const auto *endClauseList = [&]() {
     using RetTy = const Fortran::parser::OmpClauseList *;
     if (auto &endLoopDirective =
@@ -2418,57 +2520,105 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
     return RetTy();
   }();
 
-  bool validDirective = false;
-  if (llvm::omp::topTaskloopSet.test(ompDirective)) {
-    validDirective = true;
-    genTaskloopOp(converter, semaCtx, eval, currentLocation, beginClauseList);
-  } else {
-    // Create omp.{target, teams, distribute, parallel} nested operations
-    if ((llvm::omp::allTargetSet & llvm::omp::loopConstructSet)
-            .test(ompDirective)) {
-      validDirective = true;
-      genTargetOp(converter, semaCtx, eval, /*genNested=*/false,
-                  currentLocation, beginClauseList, /*outerCombined=*/true);
-    }
-    if ((llvm::omp::allTeamsSet & llvm::omp::loopConstructSet)
-            .test(ompDirective)) {
-      validDirective = true;
-      genTeamsOp(converter, semaCtx, eval, /*genNested=*/false, currentLocation,
-                 beginClauseList, /*outerCombined=*/true);
-    }
-    if (llvm::omp::allDistributeSet.test(ompDirective)) {
-      validDirective = true;
-      genDistributeOp(converter, semaCtx, eval, /*genNested=*/false,
-                      currentLocation, beginClauseList);
-    }
-    if ((llvm::omp::allParallelSet & llvm::omp::loopConstructSet)
-            .test(ompDirective)) {
-      validDirective = true;
-      genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/false,
-                    currentLocation, beginClauseList, /*outerCombined=*/true);
+  std::optional<llvm::omp::Directive> nextDir = origDirective;
+  while (nextDir) {
+    llvm::omp::Directive leafDir;
+    std::tie(leafDir, nextDir) = splitCombinedDirective(*nextDir);
+    if (llvm::omp::compositeConstructSet.test(leafDir)) {
+      assert(!nextDir && "Composite construct cannot be split");
+      switch (leafDir) {
+      case llvm::omp::Directive::OMPD_distribute_parallel_do:
+        // 2.9.4.3 DISTRIBUTE PARALLEL Worksharing-Loop construct.
+        genCompositeDistributeParallelDo(converter, semaCtx, eval,
+                                         beginClauseList, endClauseList,
+                                         currentLocation);
+        break;
+      case llvm::omp::Directive::OMPD_distribute_parallel_do_simd:
+        // 2.9.4.4 DISTRIBUTE PARALLEL Worksharing-Loop SIMD construct.
+        genCompositeDistributeParallelDoSimd(converter, semaCtx, eval,
+                                             beginClauseList, endClauseList,
+                                             currentLocation);
+        break;
+      case llvm::omp::Directive::OMPD_distribute_simd:
+        // 2.9.4.2 DISTRIBUTE SIMD construct.
+        genCompositeDistributeSimd(converter, semaCtx, eval, beginClauseList,
+                                   endClauseList, currentLocation);
+        break;
+      case llvm::omp::Directive::OMPD_do_simd:
+        // 2.9.3.2 Worksharing-Loop SIMD construct.
+        genCompositeDoSimd(converter, semaCtx, eval, beginClauseList,
+                           endClauseList, currentLocation);
+        break;
+      case llvm::omp::Directive::OMPD_taskloop_simd:
+        // 2.10.3 TASKLOOP SIMD construct.
+        genCompositeTaskloopSimd(converter, semaCtx, eval, beginClauseList,
+                                 endClauseList, currentLocation);
+        break;
+      default:
+        llvm_unreachable("Unexpected composite construct");
+      }
+    } else {
+      const bool genNested = !nextDir;
+      switch (leafDir) {
+      case llvm::omp::Directive::OMPD_distribute:
+        // 2.9.4.1 DISTRIBUTE construct.
+        genDistributeOp(converter, semaCtx, eval, genNested, currentLocation,
+                        beginClauseList);
+        break;
+      case llvm::omp::Directive::OMPD_do:
+        // 2.9.2 Worksharing-Loop construct.
+        genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauseList,
+                    endClauseList);
+        break;
+      case llvm::omp::Directive::OMPD_parallel:
+        // 2.6 PARALLEL construct.
+        // FIXME This is not necessarily always the outer leaf construct of a
+        // combined construct in this context (e.g. distribute parallel do).
+        // Maybe rename the argument if it represents something else or
+        // initialize it properly.
+        genParallelOp(converter, symTable, semaCtx, eval, genNested,
+                      currentLocation, beginClauseList,
+                      /*outerCombined=*/true);
+        break;
+      case llvm::omp::Directive::OMPD_simd:
+        // 2.9.3.1 SIMD construct.
+        genSimdLoopOp(converter, semaCtx, eval, currentLocation,
+                      beginClauseList);
+        genOpenMPReduction(converter, semaCtx, beginClauseList);
+        break;
+      case llvm::omp::Directive::OMPD_target:
+        // 2.12.5 TARGET construct.
+        genTargetOp(converter, semaCtx, eval, genNested, currentLocation,
+                    beginClauseList, /*outerCombined=*/true);
+        break;
+      case llvm::omp::Directive::OMPD_taskloop:
+        // 2.10.2 TASKLOOP construct.
+        genTaskloopOp(converter, semaCtx, eval, currentLocation,
+                      beginClauseList);
+        break;
+      case llvm::omp::Directive::OMPD_teams:
+        // 2.7 TEAMS construct.
+        // FIXME This is not necessarily always the outer leaf construct of a
+        // combined construct in this context (e.g. target teams distribute).
+        // Maybe rename the argument if it represents something else or
+        // initialize it properly.
+        genTeamsOp(converter, semaCtx, eval, genNested, currentLocation,
+                   beginClauseList, /*outerCombined=*/true);
+        break;
+      case llvm::omp::Directive::OMPD_loop:
+      case llvm::omp::Directive::OMPD_masked:
+      case llvm::omp::Directive::OMPD_master:
+      case llvm::omp::Directive::OMPD_tile:
+      case llvm::omp::Directive::OMPD_unroll:
+        TODO(currentLocation, "Unhandled loop directive (" +
+                                  llvm::omp::getOpenMPDirectiveName(leafDir) +
+                                  ")");
+        break;
+      default:
+        llvm_unreachable("Unexpected loop construct");
+      }
     }
   }
-  if ((llvm::omp::allDoSet | llvm::omp::allSimdSet).test(ompDirective))
-    validDirective = true;
-
-  if (!validDirective) {
-    TODO(currentLocation, "Unhandled loop directive (" +
-                              llvm::omp::getOpenMPDirectiveName(ompDirective) +
-                              ")");
-  }
-
-  if (llvm::omp::allDoSimdSet.test(ompDirective)) {
-    // 2.9.3.2 Workshare SIMD construct
-    genCompositeDoSimd(converter, semaCtx, eval, ompDirective, beginClauseList,
-                       endClauseList, currentLocation);
-  } else if (llvm::omp::allSimdSet.test(ompDirective)) {
-    // 2.9.3.1 SIMD construct
-    genSimdLoopOp(converter, semaCtx, eval, currentLocation, beginClauseList);
-    genOpenMPReduction(converter, semaCtx, beginClauseList);
-  } else {
-    genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauseList,
-                endClauseList);
-  }
 }
 
 static void



More information about the llvm-branch-commits mailing list