[llvm-branch-commits] [flang] [llvm] [flang][OpenMP] Decompose compound constructs, do recursive lowering (PR #90098)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Apr 25 10:52:18 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-openmp

Author: Krzysztof Parzyszek (kparzysz)

<details>
<summary>Changes</summary>

A compound construct with a list of clauses is broken up into individual leaf/composite constructs. Each such construct has the list of clauses that apply to it based on the OpenMP spec.

Each lowering function (i.e. a function that generates MLIR ops) is now responsible for generating its body as described below.

Functions that receive AST nodes extract the construct and its clauses from the node. They then create a work queue consisting of individual constructs, and invoke a common dispatch function.

The dispatch function examines the current position in the queue, and invokes the appropriate lowering function. Each lowering function receives the queue as well, and once it needs to generate its body, it either invokes the dispatch function on the rest of the queue (if any), or processes nested evaluations if the work queue is at the end.

---

Patch is 93.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/90098.diff


2 Files Affected:

- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+391-393) 
- (added) llvm/include/llvm/Frontend/OpenMP/ConstructDecompositionT.h (+985) 


``````````diff
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 47935e6cf8efcf..4b8afd42f639d5 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -36,6 +36,7 @@
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 
 using namespace Fortran::lower::omp;
@@ -72,6 +73,89 @@ static void genNestedEvaluations(Fortran::lower::AbstractConverter &converter,
     converter.genEval(e);
 }
 
+//===----------------------------------------------------------------------===//
+// Directive decomposition
+//===----------------------------------------------------------------------===//
+
+namespace {
+using DirectiveWithClauses = tomp::DirectiveWithClauses<lower::omp::Clause>;
+using ConstructQueue = List<DirectiveWithClauses>;
+} // namespace
+
+static void genOMPDispatch(Fortran::lower::AbstractConverter &converter,
+                           Fortran::lower::SymMap &symTable,
+                           Fortran::semantics::SemanticsContext &semaCtx,
+                           Fortran::lower::pft::Evaluation &eval,
+                           mlir::Location loc, const ConstructQueue &queue,
+                           ConstructQueue::iterator item);
+
+namespace {
+struct ConstructDecomposition {
+  ConstructDecomposition(mlir::ModuleOp modOp,
+                         semantics::SemanticsContext &semaCtx,
+                         lower::pft::Evaluation &ev,
+                         llvm::omp::Directive construct,
+                         const List<Clause> &clauses)
+      : semaCtx(semaCtx), mod(modOp), eval(ev) {
+    tomp::ConstructDecompositionT decompose(getOpenMPVersion(modOp), *this,
+                                            construct, llvm::ArrayRef(clauses));
+    output = std::move(decompose.output);
+  }
+
+  // Given an object, return its base object if one exists.
+  std::optional<Object> getBaseObject(const Object &object) {
+    return lower::omp::getBaseObject(object, semaCtx);
+  }
+
+  // Return the iteration variable of the associated loop if any.
+  std::optional<Object> getLoopIterVar() {
+    if (semantics::Symbol *symbol = getIterationVariableSymbol(eval))
+      return Object{symbol, /*designator=*/{}};
+    return std::nullopt;
+  }
+
+  semantics::SemanticsContext &semaCtx;
+  mlir::ModuleOp mod;
+  lower::pft::Evaluation &eval;
+  List<DirectiveWithClauses> output;
+};
+} // namespace
+
+LLVM_DUMP_METHOD static llvm::raw_ostream &
+operator<<(llvm::raw_ostream &os, const DirectiveWithClauses &dwc) {
+  os << llvm::omp::getOpenMPDirectiveName(dwc.id);
+  for (auto [index, clause] : llvm::enumerate(dwc.clauses)) {
+    os << (index == 0 ? '\t' : ' ');
+    os << llvm::omp::getOpenMPClauseName(clause.id);
+  }
+  return os;
+}
+
+static void splitCompoundConstruct(
+    mlir::ModuleOp modOp, Fortran::semantics::SemanticsContext &semaCtx,
+    Fortran::lower::pft::Evaluation &eval, llvm::omp::Directive construct,
+    const List<Clause> &clauses, List<DirectiveWithClauses> &directives) {
+
+  ConstructDecomposition decompose(modOp, semaCtx, eval, construct, clauses);
+  assert(!decompose.output.empty());
+
+  llvm::SmallVector<llvm::omp::Directive> loweringUnits;
+  std::ignore =
+      llvm::omp::getLeafOrCompositeConstructs(construct, loweringUnits);
+
+  int leafIndex = 0;
+  for (llvm::omp::Directive dir_id : loweringUnits) {
+    directives.push_back(DirectiveWithClauses{dir_id});
+    DirectiveWithClauses &dwc = directives.back();
+    llvm::ArrayRef<llvm::omp::Directive> leafsOrSelf =
+        llvm::omp::getLeafConstructsOrSelf(dir_id);
+    for (int i = 0, e = leafsOrSelf.size(); i != e; ++i) {
+      dwc.clauses.append(decompose.output[leafIndex].clauses);
+      ++leafIndex;
+    }
+  }
+}
+
 static fir::GlobalOp globalInitialization(
     Fortran::lower::AbstractConverter &converter,
     fir::FirOpBuilder &firOpBuilder, const Fortran::semantics::Symbol &sym,
@@ -460,81 +544,6 @@ markDeclareTarget(mlir::Operation *op,
   declareTargetOp.setDeclareTarget(deviceType, captureClause);
 }
 
-/// Split a combined directive into an outer leaf directive and the (possibly
-/// combined) rest of the combined directive. Composite directives and
-/// non-compound directives are not split, in which case it will return the
-/// input directive as its first output and an empty value as its second output.
-static std::pair<llvm::omp::Directive, std::optional<llvm::omp::Directive>>
-splitCombinedDirective(llvm::omp::Directive dir) {
-  using D = llvm::omp::Directive;
-  switch (dir) {
-  case D::OMPD_masked_taskloop:
-    return {D::OMPD_masked, D::OMPD_taskloop};
-  case D::OMPD_masked_taskloop_simd:
-    return {D::OMPD_masked, D::OMPD_taskloop_simd};
-  case D::OMPD_master_taskloop:
-    return {D::OMPD_master, D::OMPD_taskloop};
-  case D::OMPD_master_taskloop_simd:
-    return {D::OMPD_master, D::OMPD_taskloop_simd};
-  case D::OMPD_parallel_do:
-    return {D::OMPD_parallel, D::OMPD_do};
-  case D::OMPD_parallel_do_simd:
-    return {D::OMPD_parallel, D::OMPD_do_simd};
-  case D::OMPD_parallel_masked:
-    return {D::OMPD_parallel, D::OMPD_masked};
-  case D::OMPD_parallel_masked_taskloop:
-    return {D::OMPD_parallel, D::OMPD_masked_taskloop};
-  case D::OMPD_parallel_masked_taskloop_simd:
-    return {D::OMPD_parallel, D::OMPD_masked_taskloop_simd};
-  case D::OMPD_parallel_master:
-    return {D::OMPD_parallel, D::OMPD_master};
-  case D::OMPD_parallel_master_taskloop:
-    return {D::OMPD_parallel, D::OMPD_master_taskloop};
-  case D::OMPD_parallel_master_taskloop_simd:
-    return {D::OMPD_parallel, D::OMPD_master_taskloop_simd};
-  case D::OMPD_parallel_sections:
-    return {D::OMPD_parallel, D::OMPD_sections};
-  case D::OMPD_parallel_workshare:
-    return {D::OMPD_parallel, D::OMPD_workshare};
-  case D::OMPD_target_parallel:
-    return {D::OMPD_target, D::OMPD_parallel};
-  case D::OMPD_target_parallel_do:
-    return {D::OMPD_target, D::OMPD_parallel_do};
-  case D::OMPD_target_parallel_do_simd:
-    return {D::OMPD_target, D::OMPD_parallel_do_simd};
-  case D::OMPD_target_simd:
-    return {D::OMPD_target, D::OMPD_simd};
-  case D::OMPD_target_teams:
-    return {D::OMPD_target, D::OMPD_teams};
-  case D::OMPD_target_teams_distribute:
-    return {D::OMPD_target, D::OMPD_teams_distribute};
-  case D::OMPD_target_teams_distribute_parallel_do:
-    return {D::OMPD_target, D::OMPD_teams_distribute_parallel_do};
-  case D::OMPD_target_teams_distribute_parallel_do_simd:
-    return {D::OMPD_target, D::OMPD_teams_distribute_parallel_do_simd};
-  case D::OMPD_target_teams_distribute_simd:
-    return {D::OMPD_target, D::OMPD_teams_distribute_simd};
-  case D::OMPD_teams_distribute:
-    return {D::OMPD_teams, D::OMPD_distribute};
-  case D::OMPD_teams_distribute_parallel_do:
-    return {D::OMPD_teams, D::OMPD_distribute_parallel_do};
-  case D::OMPD_teams_distribute_parallel_do_simd:
-    return {D::OMPD_teams, D::OMPD_distribute_parallel_do_simd};
-  case D::OMPD_teams_distribute_simd:
-    return {D::OMPD_teams, D::OMPD_distribute_simd};
-  case D::OMPD_parallel_loop:
-    return {D::OMPD_parallel, D::OMPD_loop};
-  case D::OMPD_target_parallel_loop:
-    return {D::OMPD_target, D::OMPD_parallel_loop};
-  case D::OMPD_target_teams_loop:
-    return {D::OMPD_target, D::OMPD_teams_loop};
-  case D::OMPD_teams_loop:
-    return {D::OMPD_teams, D::OMPD_loop};
-  default:
-    return {dir, std::nullopt};
-  }
-}
-
 //===----------------------------------------------------------------------===//
 // Op body generation helper structures and functions
 //===----------------------------------------------------------------------===//
@@ -555,11 +564,6 @@ struct OpWithBodyGenInfo {
       : converter(converter), symTable(symTable), semaCtx(semaCtx), loc(loc),
         eval(eval), dir(dir) {}
 
-  OpWithBodyGenInfo &setGenNested(bool value) {
-    genNested = value;
-    return *this;
-  }
-
   OpWithBodyGenInfo &setOuterCombined(bool value) {
     outerCombined = value;
     return *this;
@@ -600,8 +604,6 @@ struct OpWithBodyGenInfo {
   Fortran::lower::pft::Evaluation &eval;
   /// [in] leaf directive for which to generate the op body.
   llvm::omp::Directive dir;
-  /// [in] whether to generate FIR for nested evaluations
-  bool genNested = true;
   /// [in] is this an outer operation - prevents privatization.
   bool outerCombined = false;
   /// [in] list of clauses to process.
@@ -622,7 +624,9 @@ struct OpWithBodyGenInfo {
 ///
 /// \param [in]   op - the operation the body belongs to.
 /// \param [in] info - options controlling code-gen for the construction.
-static void createBodyOfOp(mlir::Operation &op, OpWithBodyGenInfo &info) {
+static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
+                           const ConstructQueue &queue,
+                           ConstructQueue::iterator item) {
   fir::FirOpBuilder &firOpBuilder = info.converter.getFirOpBuilder();
 
   auto insertMarker = [](fir::FirOpBuilder &builder) {
@@ -678,7 +682,10 @@ static void createBodyOfOp(mlir::Operation &op, OpWithBodyGenInfo &info) {
     }
   }
 
-  if (info.genNested) {
+  if (ConstructQueue::iterator next = std::next(item); next != queue.end()) {
+    genOMPDispatch(info.converter, info.symTable, info.semaCtx, info.eval,
+                   info.loc, queue, next);
+  } else {
     // genFIR(Evaluation&) tries to patch up unterminated blocks, causing
     // a lot of complications for our approach if the terminator generation
     // is delayed past this point. Insert a temporary terminator here, then
@@ -769,11 +776,12 @@ static void genBodyOfTargetDataOp(
     Fortran::lower::AbstractConverter &converter,
     Fortran::lower::SymMap &symTable,
     Fortran::semantics::SemanticsContext &semaCtx,
-    Fortran::lower::pft::Evaluation &eval, bool genNested,
-    mlir::omp::TargetDataOp &dataOp, llvm::ArrayRef<mlir::Type> useDeviceTypes,
+    Fortran::lower::pft::Evaluation &eval, mlir::omp::TargetDataOp &dataOp,
+    llvm::ArrayRef<mlir::Type> useDeviceTypes,
     llvm::ArrayRef<mlir::Location> useDeviceLocs,
     llvm::ArrayRef<const Fortran::semantics::Symbol *> useDeviceSymbols,
-    const mlir::Location &currentLocation) {
+    const mlir::Location &currentLocation, const ConstructQueue &queue,
+    ConstructQueue::iterator item) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   mlir::Region &region = dataOp.getRegion();
 
@@ -826,8 +834,13 @@ static void genBodyOfTargetDataOp(
 
   // Set the insertion point after the marker.
   firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp());
-  if (genNested)
+
+  if (ConstructQueue::iterator next = std::next(item); next != queue.end()) {
+    genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
+                   next);
+  } else {
     genNestedEvaluations(converter, eval);
+  }
 }
 
 // This functions creates a block for the body of the targetOp's region. It adds
@@ -836,12 +849,13 @@ static void
 genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
                   Fortran::lower::SymMap &symTable,
                   Fortran::semantics::SemanticsContext &semaCtx,
-                  Fortran::lower::pft::Evaluation &eval, bool genNested,
+                  Fortran::lower::pft::Evaluation &eval,
                   mlir::omp::TargetOp &targetOp,
                   llvm::ArrayRef<const Fortran::semantics::Symbol *> mapSyms,
                   llvm::ArrayRef<mlir::Location> mapSymLocs,
                   llvm::ArrayRef<mlir::Type> mapSymTypes,
-                  const mlir::Location &currentLocation) {
+                  const mlir::Location &currentLocation,
+                  const ConstructQueue &queue, ConstructQueue::iterator item) {
   assert(mapSymTypes.size() == mapSymLocs.size());
 
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -981,15 +995,22 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
 
   // Create the insertion point after the marker.
   firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp());
-  if (genNested)
+
+  if (ConstructQueue::iterator next = std::next(item); next != queue.end()) {
+    genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
+                   next);
+  } else {
     genNestedEvaluations(converter, eval);
+  }
 }
 
 template <typename OpTy, typename... Args>
-static OpTy genOpWithBody(OpWithBodyGenInfo &info, Args &&...args) {
+static OpTy genOpWithBody(const OpWithBodyGenInfo &info,
+                          const ConstructQueue &queue,
+                          ConstructQueue::iterator item, Args &&...args) {
   auto op = info.converter.getFirOpBuilder().create<OpTy>(
       info.loc, std::forward<Args>(args)...);
-  createBodyOfOp(*op, info);
+  createBodyOfOp(*op, info, queue, item);
   return op;
 }
 
@@ -1274,7 +1295,8 @@ static mlir::omp::BarrierOp
 genBarrierOp(Fortran::lower::AbstractConverter &converter,
              Fortran::lower::SymMap &symTable,
              Fortran::semantics::SemanticsContext &semaCtx,
-             Fortran::lower::pft::Evaluation &eval, mlir::Location loc) {
+             Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+             const ConstructQueue &queue, ConstructQueue::iterator item) {
   return converter.getFirOpBuilder().create<mlir::omp::BarrierOp>(loc);
 }
 
@@ -1282,8 +1304,9 @@ static mlir::omp::CriticalOp
 genCriticalOp(Fortran::lower::AbstractConverter &converter,
               Fortran::lower::SymMap &symTable,
               Fortran::semantics::SemanticsContext &semaCtx,
-              Fortran::lower::pft::Evaluation &eval, bool genNested,
-              mlir::Location loc, const List<Clause> &clauses,
+              Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+              const List<Clause> &clauses, const ConstructQueue &queue,
+              ConstructQueue::iterator item,
               const std::optional<Fortran::parser::Name> &name) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   mlir::FlatSymbolRefAttr nameAttr;
@@ -1306,17 +1329,17 @@ genCriticalOp(Fortran::lower::AbstractConverter &converter,
 
   return genOpWithBody<mlir::omp::CriticalOp>(
       OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
-                        llvm::omp::Directive::OMPD_critical)
-          .setGenNested(genNested),
-      nameAttr);
+                        llvm::omp::Directive::OMPD_critical),
+      queue, item, nameAttr);
 }
 
 static mlir::omp::DistributeOp
 genDistributeOp(Fortran::lower::AbstractConverter &converter,
                 Fortran::lower::SymMap &symTable,
                 Fortran::semantics::SemanticsContext &semaCtx,
-                Fortran::lower::pft::Evaluation &eval, bool genNested,
-                mlir::Location loc, const List<Clause> &clauses) {
+                Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+                const List<Clause> &clauses, const ConstructQueue &queue,
+                ConstructQueue::iterator item) {
   TODO(loc, "Distribute construct");
   return nullptr;
 }
@@ -1326,7 +1349,8 @@ genFlushOp(Fortran::lower::AbstractConverter &converter,
            Fortran::lower::SymMap &symTable,
            Fortran::semantics::SemanticsContext &semaCtx,
            Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
-           const ObjectList &objects, const List<Clause> &clauses) {
+           const ObjectList &objects, const List<Clause> &clauses,
+           const ConstructQueue &queue, ConstructQueue::iterator item) {
   llvm::SmallVector<mlir::Value> operandRange;
   genFlushClauses(converter, semaCtx, objects, clauses, loc, operandRange);
 
@@ -1338,12 +1362,13 @@ static mlir::omp::MasterOp
 genMasterOp(Fortran::lower::AbstractConverter &converter,
             Fortran::lower::SymMap &symTable,
             Fortran::semantics::SemanticsContext &semaCtx,
-            Fortran::lower::pft::Evaluation &eval, bool genNested,
-            mlir::Location loc) {
+            Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+            const List<Clause> &clauses, const ConstructQueue &queue,
+            ConstructQueue::iterator item) {
   return genOpWithBody<mlir::omp::MasterOp>(
       OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
-                        llvm::omp::Directive::OMPD_master)
-          .setGenNested(genNested));
+                        llvm::omp::Directive::OMPD_master),
+      queue, item);
 }
 
 static mlir::omp::OrderedOp
@@ -1351,7 +1376,8 @@ genOrderedOp(Fortran::lower::AbstractConverter &converter,
              Fortran::lower::SymMap &symTable,
              Fortran::semantics::SemanticsContext &semaCtx,
              Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
-             const List<Clause> &clauses) {
+             const List<Clause> &clauses, const ConstructQueue &queue,
+             ConstructQueue::iterator item) {
   TODO(loc, "OMPD_ordered");
   return nullptr;
 }
@@ -1360,25 +1386,25 @@ static mlir::omp::OrderedRegionOp
 genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
                    Fortran::lower::SymMap &symTable,
                    Fortran::semantics::SemanticsContext &semaCtx,
-                   Fortran::lower::pft::Evaluation &eval, bool genNested,
-                   mlir::Location loc, const List<Clause> &clauses) {
+                   Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+                   const List<Clause> &clauses, const ConstructQueue &queue,
+                   ConstructQueue::iterator item) {
   mlir::omp::OrderedRegionClauseOps clauseOps;
   genOrderedRegionClauses(converter, semaCtx, clauses, loc, clauseOps);
 
   return genOpWithBody<mlir::omp::OrderedRegionOp>(
       OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
-                        llvm::omp::Directive::OMPD_ordered)
-          .setGenNested(genNested),
-      clauseOps);
+                        llvm::omp::Directive::OMPD_ordered),
+      queue, item, clauseOps);
 }
 
 static mlir::omp::ParallelOp
 genParallelOp(Fortran::lower::AbstractConverter &converter,
               Fortran::lower::SymMap &symTable,
               Fortran::semantics::SemanticsContext &semaCtx,
-              Fortran::lower::pft::Evaluation &eval, bool genNested,
-              mlir::Location loc, const List<Clause> &clauses,
-              bool outerCombined = false) {
+              Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+              const List<Clause> &clauses, const ConstructQueue &queue,
+              ConstructQueue::iterator item, bool outerCombined = false) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   Fortran::lower::StatementContext stmtCtx;
   mlir::omp::ParallelClauseOps clauseOps;
@@ -1397,14 +1423,14 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
   OpWithBodyGenInfo genInfo =
       OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
                         llvm::omp::Directive::OMPD_parallel)
-          .setGenNested(genNested)
           .setOuterCombined(outerCombined)
           .setClauses(&clauses)
           .setReductions(&reductionSyms, &reductionTypes)
           .setGenRegionEntryCb(reductionCallback);
 
   if (!enableDelayedPrivatization)
-    return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps);
+    return genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item,
+                                                clauseOps);
 
   bool privatize = !outerCombined;
   DataSharingProcessor dsp(converter, semaCtx, clauses, eval,
@@ -1447,22 +1473,23 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
 
   // TODO Merge with the reduction CB.
   genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp);
-  return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps);
+  return genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps);
 }
 
 static mlir::omp::SectionOp
 genSectionOp(Fortran::lower::AbstractConverter &converter,
              Fortran::lower::SymMap &symTable,
              Fortran::semantics::SemanticsContext &semaCtx,
-             Fortran::lower::pft::Evaluation &eval, bool genNested...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/90098


More information about the llvm-branch-commits mailing list