[llvm-branch-commits] [flang] [Flang][OpenMP] Push genEval closer to leaf lowering functions (PR #77760)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jan 11 04:51:03 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-openmp

Author: Krzysztof Parzyszek (kparzysz)

<details>
<summary>Changes</summary>

Recursive lowering [4/5]

---

Patch is 24.49 KiB, truncated to 20.00 KiB below; the full version is available at https://github.com/llvm/llvm-project/pull/77760.diff


1 File Affected:

- (modified) flang/lib/Lower/OpenMP.cpp (+135-98) 


``````````diff
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index aa315d7ab280e2..c0f5619344c605 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -2178,7 +2178,7 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
 template <typename Op>
 static void createBodyOfOp(
     Op &op, Fortran::lower::AbstractConverter &converter, mlir::Location &loc,
-    Fortran::lower::pft::Evaluation &eval,
+    Fortran::lower::pft::Evaluation &eval, bool genNested,
     const Fortran::parser::OmpClauseList *clauses = nullptr,
     const llvm::SmallVector<const Fortran::semantics::Symbol *> &args = {},
     bool outerCombined = false, DataSharingProcessor *dsp = nullptr) {
@@ -2248,11 +2248,15 @@ static void createBodyOfOp(
     if (clauses)
       ClauseProcessor(converter, *clauses).processCopyin();
   }
+
+  if (genNested)
+    genNestedEvaluations(converter, eval);
 }
 
 static void genBodyOfTargetDataOp(
     Fortran::lower::AbstractConverter &converter,
-    Fortran::lower::pft::Evaluation &eval, mlir::omp::DataOp &dataOp,
+    Fortran::lower::pft::Evaluation &eval, bool genNested,
+    mlir::omp::DataOp &dataOp,
     const llvm::SmallVector<mlir::Type> &useDeviceTypes,
     const llvm::SmallVector<mlir::Location> &useDeviceLocs,
     const llvm::SmallVector<const Fortran::semantics::Symbol *>
@@ -2310,43 +2314,48 @@ static void genBodyOfTargetDataOp(
 
   // Set the insertion point after the marker.
   firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp());
+  if (genNested)
+    genNestedEvaluations(converter, eval);
 }
 
 template <typename OpTy, typename... Args>
 static OpTy genOpWithBody(Fortran::lower::AbstractConverter &converter,
-                          Fortran::lower::pft::Evaluation &eval,
+                          Fortran::lower::pft::Evaluation &eval, bool genNested,
                           mlir::Location currentLocation, bool outerCombined,
                           const Fortran::parser::OmpClauseList *clauseList,
                           Args &&...args) {
   auto op = converter.getFirOpBuilder().create<OpTy>(
       currentLocation, std::forward<Args>(args)...);
-  createBodyOfOp<OpTy>(op, converter, currentLocation, eval, clauseList,
+  createBodyOfOp<OpTy>(op, converter, currentLocation, eval, genNested,
+                       clauseList,
                        /*args=*/{}, outerCombined);
   return op;
 }
 
 static mlir::omp::MasterOp
 genMasterOp(Fortran::lower::AbstractConverter &converter,
-            Fortran::lower::pft::Evaluation &eval,
+            Fortran::lower::pft::Evaluation &eval, bool genNested,
             mlir::Location currentLocation) {
-  return genOpWithBody<mlir::omp::MasterOp>(converter, eval, currentLocation,
-                                            /*outerCombined=*/false,
-                                            /*clauseList=*/nullptr,
-                                            /*resultTypes=*/mlir::TypeRange());
+  return genOpWithBody<mlir::omp::MasterOp>(
+      converter, eval, genNested, currentLocation,
+      /*outerCombined=*/false,
+      /*clauseList=*/nullptr,
+      /*resultTypes=*/mlir::TypeRange());
 }
 
 static mlir::omp::OrderedRegionOp
 genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
-                   Fortran::lower::pft::Evaluation &eval,
+                   Fortran::lower::pft::Evaluation &eval, bool genNested,
                    mlir::Location currentLocation) {
   return genOpWithBody<mlir::omp::OrderedRegionOp>(
-      converter, eval, currentLocation, /*outerCombined=*/false,
+      converter, eval, genNested, currentLocation,
+      /*outerCombined=*/false,
       /*clauseList=*/nullptr, /*simd=*/false);
 }
 
 static mlir::omp::ParallelOp
 genParallelOp(Fortran::lower::AbstractConverter &converter,
-              Fortran::lower::pft::Evaluation &eval,
+              Fortran::lower::pft::Evaluation &eval, bool genNested,
               mlir::Location currentLocation,
               const Fortran::parser::OmpClauseList &clauseList,
               bool outerCombined = false) {
@@ -2368,8 +2377,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
     cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
 
   return genOpWithBody<mlir::omp::ParallelOp>(
-      converter, eval, currentLocation, outerCombined, &clauseList,
-      /*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
+      converter, eval, genNested, currentLocation, outerCombined,
+      &clauseList, /*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
       numThreadsClauseOperand, allocateOperands, allocatorOperands,
       reductionVars,
       reductionDeclSymbols.empty()
@@ -2381,19 +2390,20 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
 
 static mlir::omp::SectionOp
 genSectionOp(Fortran::lower::AbstractConverter &converter,
-             Fortran::lower::pft::Evaluation &eval,
+             Fortran::lower::pft::Evaluation &eval, bool genNested,
              mlir::Location currentLocation,
              const Fortran::parser::OmpClauseList &sectionsClauseList) {
   // Currently only private/firstprivate clause is handled, and
   // all privatization is done within `omp.section` operations.
-  return genOpWithBody<mlir::omp::SectionOp>(converter, eval, currentLocation,
+  return genOpWithBody<mlir::omp::SectionOp>(converter, eval, genNested,
+                                             currentLocation,
                                              /*outerCombined=*/false,
                                              &sectionsClauseList);
 }
 
 static mlir::omp::SingleOp
 genSingleOp(Fortran::lower::AbstractConverter &converter,
-            Fortran::lower::pft::Evaluation &eval,
+            Fortran::lower::pft::Evaluation &eval, bool genNested,
             mlir::Location currentLocation,
             const Fortran::parser::OmpClauseList &beginClauseList,
             const Fortran::parser::OmpClauseList &endClauseList) {
@@ -2408,13 +2418,15 @@ genSingleOp(Fortran::lower::AbstractConverter &converter,
   ClauseProcessor(converter, endClauseList).processNowait(nowaitAttr);
 
   return genOpWithBody<mlir::omp::SingleOp>(
-      converter, eval, currentLocation, /*outerCombined=*/false,
-      &beginClauseList, allocateOperands, allocatorOperands, nowaitAttr);
+      converter, eval, genNested, currentLocation,
+      /*outerCombined=*/false, &beginClauseList, allocateOperands,
+      allocatorOperands, nowaitAttr);
 }
 
 static mlir::omp::TaskOp
 genTaskOp(Fortran::lower::AbstractConverter &converter,
-          Fortran::lower::pft::Evaluation &eval, mlir::Location currentLocation,
+          Fortran::lower::pft::Evaluation &eval, bool genNested,
+          mlir::Location currentLocation,
           const Fortran::parser::OmpClauseList &clauseList) {
   Fortran::lower::StatementContext stmtCtx;
   mlir::Value ifClauseOperand, finalClauseOperand, priorityClauseOperand;
@@ -2439,8 +2451,9 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
       currentLocation, llvm::omp::Directive::OMPD_task);
 
   return genOpWithBody<mlir::omp::TaskOp>(
-      converter, eval, currentLocation, /*outerCombined=*/false, &clauseList,
-      ifClauseOperand, finalClauseOperand, untiedAttr, mergeableAttr,
+      converter, eval, genNested, currentLocation,
+      /*outerCombined=*/false, &clauseList, ifClauseOperand, finalClauseOperand,
+      untiedAttr, mergeableAttr,
       /*in_reduction_vars=*/mlir::ValueRange(),
       /*in_reductions=*/nullptr, priorityClauseOperand,
       dependTypeOperands.empty()
@@ -2452,7 +2465,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
 
 static mlir::omp::TaskGroupOp
 genTaskGroupOp(Fortran::lower::AbstractConverter &converter,
-               Fortran::lower::pft::Evaluation &eval,
+               Fortran::lower::pft::Evaluation &eval, bool genNested,
                mlir::Location currentLocation,
                const Fortran::parser::OmpClauseList &clauseList) {
   llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
@@ -2461,7 +2474,8 @@ genTaskGroupOp(Fortran::lower::AbstractConverter &converter,
   cp.processTODO<Fortran::parser::OmpClause::TaskReduction>(
       currentLocation, llvm::omp::Directive::OMPD_taskgroup);
   return genOpWithBody<mlir::omp::TaskGroupOp>(
-      converter, eval, currentLocation, /*outerCombined=*/false, &clauseList,
+      converter, eval, genNested, currentLocation,
+      /*outerCombined=*/false, &clauseList,
       /*task_reduction_vars=*/mlir::ValueRange(),
       /*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
 }
@@ -2470,7 +2484,7 @@ static mlir::omp::DataOp
 genDataOp(Fortran::lower::AbstractConverter &converter,
           Fortran::lower::pft::Evaluation &eval,
           Fortran::semantics::SemanticsContext &semanticsContext,
-          mlir::Location currentLocation,
+          bool genNested, mlir::Location currentLocation,
           const Fortran::parser::OmpClauseList &clauseList) {
   Fortran::lower::StatementContext stmtCtx;
   mlir::Value ifClauseOperand, deviceOperand;
@@ -2494,8 +2508,9 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
   auto dataOp = converter.getFirOpBuilder().create<mlir::omp::DataOp>(
       currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
       deviceAddrOperands, mapOperands);
-  genBodyOfTargetDataOp(converter, eval, dataOp, useDeviceTypes, useDeviceLocs,
-                        useDeviceSymbols, currentLocation);
+  genBodyOfTargetDataOp(converter, eval, genNested, dataOp,
+                        useDeviceTypes, useDeviceLocs, useDeviceSymbols,
+                        currentLocation);
   return dataOp;
 }
 
@@ -2556,7 +2571,8 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
 // all the symbols present in mapSymbols as block arguments to this block.
 static void genBodyOfTargetOp(
     Fortran::lower::AbstractConverter &converter,
-    Fortran::lower::pft::Evaluation &eval, mlir::omp::TargetOp &targetOp,
+    Fortran::lower::pft::Evaluation &eval, bool genNested,
+    mlir::omp::TargetOp &targetOp,
     const llvm::SmallVector<mlir::Type> &mapSymTypes,
     const llvm::SmallVector<mlir::Location> &mapSymLocs,
     const llvm::SmallVector<const Fortran::semantics::Symbol *> &mapSymbols,
@@ -2686,7 +2702,6 @@ static void genBodyOfTargetOp(
 
   // Create blocks for unstructured regions. This has to be done since
   // blocks are initially allocated with the function as the parent region.
-  // the parent region of blocks.
   if (eval.lowerAsUnstructured()) {
     Fortran::lower::createEmptyRegionBlocks<mlir::omp::TerminatorOp,
                                             mlir::omp::YieldOp>(
@@ -2697,13 +2712,15 @@ static void genBodyOfTargetOp(
 
   // Create the insertion point after the marker.
   firOpBuilder.setInsertionPointAfter(undefMarker.getDefiningOp());
+  if (genNested)
+    genNestedEvaluations(converter, eval);
 }
 
 static mlir::omp::TargetOp
 genTargetOp(Fortran::lower::AbstractConverter &converter,
             Fortran::lower::pft::Evaluation &eval,
             Fortran::semantics::SemanticsContext &semanticsContext,
-            mlir::Location currentLocation,
+            bool genNested, mlir::Location currentLocation,
             const Fortran::parser::OmpClauseList &clauseList,
             llvm::omp::Directive directive, bool outerCombined = false) {
   Fortran::lower::StatementContext stmtCtx;
@@ -2803,15 +2820,15 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
       currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
       nowaitAttr, mapOperands);
 
-  genBodyOfTargetOp(converter, eval, targetOp, mapSymTypes, mapSymLocs,
-                    mapSymbols, currentLocation);
+  genBodyOfTargetOp(converter, eval, genNested, targetOp,
+                    mapSymTypes, mapSymLocs, mapSymbols, currentLocation);
 
   return targetOp;
 }
 
 static mlir::omp::TeamsOp
 genTeamsOp(Fortran::lower::AbstractConverter &converter,
-           Fortran::lower::pft::Evaluation &eval,
+           Fortran::lower::pft::Evaluation &eval, bool genNested,
            mlir::Location currentLocation,
            const Fortran::parser::OmpClauseList &clauseList,
            bool outerCombined = false) {
@@ -2832,7 +2849,8 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
       currentLocation, llvm::omp::Directive::OMPD_teams);
 
   return genOpWithBody<mlir::omp::TeamsOp>(
-      converter, eval, currentLocation, outerCombined, &clauseList,
+      converter, eval, genNested, currentLocation, outerCombined,
+      &clauseList,
       /*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand,
       threadLimitClauseOperand, allocateOperands, allocatorOperands,
       reductionVars,
@@ -2916,6 +2934,7 @@ static void
 genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter,
                        Fortran::lower::pft::Evaluation &eval,
                        Fortran::semantics::SemanticsContext &semanticsContext,
+                       bool genNested,
                        const Fortran::parser::OpenMPSimpleStandaloneConstruct
                            &simpleStandaloneConstruct) {
   const auto &directive =
@@ -2943,7 +2962,8 @@ genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter,
     firOpBuilder.create<mlir::omp::TaskyieldOp>(currentLocation);
     break;
   case llvm::omp::Directive::OMPD_target_data:
-    genDataOp(converter, eval, semanticsContext, currentLocation, opClauseList);
+    genDataOp(converter, eval, semanticsContext, genNested,
+              currentLocation, opClauseList);
     break;
   case llvm::omp::Directive::OMPD_target_enter_data:
     genEnterExitUpdateDataOp<mlir::omp::EnterDataOp>(
@@ -2990,6 +3010,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
           [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct
                   &simpleStandaloneConstruct) {
             genOmpSimpleStandalone(converter, eval, semanticsContext,
+                                   /*genNested=*/true,
                                    simpleStandaloneConstruct);
           },
           [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) {
@@ -3060,12 +3081,12 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
       /*alignment_values=*/nullptr, ifClauseOperand, nontemporalVars,
       orderClauseOperand, simdlenClauseOperand, safelenClauseOperand,
       /*inclusive=*/firOpBuilder.getUnitAttr());
-  createBodyOfOp<mlir::omp::SimdLoopOp>(simdLoopOp, converter, loc, eval,
-                                        &loopOpClauseList, iv,
-                                        /*outer=*/false, &dsp);
 
-  genNestedEvaluations(converter, eval,
-                       Fortran::lower::getCollapseValue(loopOpClauseList));
+  auto *nestedEval = getEvalPastCollapse(
+      eval, Fortran::lower::getCollapseValue(loopOpClauseList));
+  createBodyOfOp<mlir::omp::SimdLoopOp>(simdLoopOp, converter, loc, *nestedEval,
+                                        /*genNested=*/true, &loopOpClauseList,
+                                        iv, /*outer=*/false, &dsp);
 }
 
 static void createWsLoop(Fortran::lower::AbstractConverter &converter,
@@ -3144,12 +3165,11 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
       wsLoopOp.setNowaitAttr(nowaitClauseOperand);
   }
 
-  createBodyOfOp<mlir::omp::WsLoopOp>(wsLoopOp, converter, loc, eval,
-                                      &beginClauseList, iv,
+  auto *nestedEval = getEvalPastCollapse(
+      eval, Fortran::lower::getCollapseValue(beginClauseList));
+  createBodyOfOp<mlir::omp::WsLoopOp>(wsLoopOp, converter, loc, *nestedEval,
+                                      /*genNested=*/true, &beginClauseList, iv,
                                       /*outer=*/false, &dsp);
-
-  genNestedEvaluations(converter, eval,
-                       Fortran::lower::getCollapseValue(beginClauseList));
 }
 
 static void genOMP(Fortran::lower::AbstractConverter &converter,
@@ -3186,13 +3206,15 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
     if ((llvm::omp::allTargetSet & llvm::omp::loopConstructSet)
             .test(ompDirective)) {
       validDirective = true;
-      genTargetOp(converter, eval, semanticsContext, currentLocation,
-                  loopOpClauseList, ompDirective, /*outerCombined=*/true);
+      genTargetOp(converter, eval, semanticsContext, /*genNested=*/false,
+                  currentLocation, loopOpClauseList, ompDirective,
+                  /*outerCombined=*/true);
     }
     if ((llvm::omp::allTeamsSet & llvm::omp::loopConstructSet)
             .test(ompDirective)) {
       validDirective = true;
-      genTeamsOp(converter, eval, currentLocation, loopOpClauseList,
+      genTeamsOp(converter, eval, /*genNested=*/false, currentLocation,
+                 loopOpClauseList,
                  /*outerCombined=*/true);
     }
     if (llvm::omp::allDistributeSet.test(ompDirective)) {
@@ -3202,7 +3224,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
     if ((llvm::omp::allParallelSet & llvm::omp::loopConstructSet)
             .test(ompDirective)) {
       validDirective = true;
-      genParallelOp(converter, eval, currentLocation, loopOpClauseList,
+      genParallelOp(converter, eval, /*genNested=*/false, currentLocation,
+                    loopOpClauseList,
                     /*outerCombined=*/true);
     }
   }
@@ -3273,76 +3296,89 @@ genOMP(Fortran::lower::AbstractConverter &converter,
       TODO(clauseLocation, "OpenMP Block construct clause");
   }
 
+  bool singleDirective = true;
   mlir::Location currentLocation = converter.genLocation(directive.source);
   switch (directive.v) {
   case llvm::omp::Directive::OMPD_master:
-    genMasterOp(converter, eval, currentLocation);
+    genMasterOp(converter, eval, /*genNested=*/true, currentLocation);
     break;
   case llvm::omp::Directive::OMPD_ordered:
-    genOrderedRegionOp(converter, eval, currentLocation);
+    genOrderedRegionOp(converter, eval, /*genNested=*/true, currentLocation);
     break;
   case llvm::omp::Directive::OMPD_parallel:
-    genParallelOp(converter, eval, currentLocation, beginClauseList);
+    genParallelOp(converter, eval, /*genNested=*/true, currentLocation,
+                  beginClauseList);
     break;
   case llvm::omp::Directive::OMPD_single:
-    genSingleOp(converter, eval, currentLocation, beginClauseList,
-                endClauseList);
+    genSingleOp(converter, eval, /*genNested=*/true, currentLocation,
+                beginClauseList, endClauseList);
     break;
   case llvm::omp::Directive::OMPD_target:
-    genTargetOp(converter, eval, semanticsContext, currentLocation,
-                beginClauseList, directive.v);
+    genTargetOp(converter, eval, semanticsContext, /*genNested=*/true,
+                currentLocation, beginClauseList, directive.v);
     break;
   case llvm::omp::Directive::OMPD_target_data:
-    genDataOp(converter, eval, semanticsContext, currentLocation,
-              beginClauseList);
+    genDataOp(converter, eval, semanticsContext, /*genNested=*/true,
+              currentLocation, beginClauseList);
     break;
   case llvm::omp::Directive::OMPD_task:
-    genTaskOp(converter, eval, currentLocation, beginClauseList);
+    genTaskOp(converter, eval, /*genNested=*/true, currentLocation,
+              beginClauseList);
     break;
   case llvm::omp::Directive::OMPD_taskgroup:
-    genTaskGroupOp(converter, eval, currentLocation, beginClauseList);
+    genTaskGroupOp(converter, eval, /*genNested=*/true, currentLocation,
+                   beginClauseList);
     break;
   case llvm::omp::Directive::OMPD_teams:
-    genTeamsOp(converter, eval, currentLocation, beginClauseList,
+    genTeamsOp(converter, eval, /*genNested=*/true, currentLocation,
+               beginClauseList,
                /*outerCombined=*/false);
     break;
   case llvm::omp::Directive::OMPD_workshare:
     TODO(currentLocation, "Workshare construct");
     break;
-  default: {
-    // Codegen for combined directives
-    bool combinedDirective = false;
-    if ((llvm::omp::allTargetSet & llvm::omp::blockConstructSet)
-            .test(directive.v)) {
-      genTargetOp(converter, eval, semanticsContext, currentLocation,
-                  beginClause...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/77760


More information about the llvm-branch-commits mailing list