[llvm-branch-commits] [flang] [Flang][OpenMP] Push genEval calls to individual operations, NFC (PR #77758)
    via llvm-branch-commits 
    llvm-branch-commits at lists.llvm.org
       
    Thu Jan 11 04:49:18 PST 2024
    
    
  
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Krzysztof Parzyszek (kparzysz)
<details>
<summary>Changes</summary>
Introduce `genNestedEvaluations` that will lower all evaluations nested in the given evaluation, accounting for a potential COLLAPSE directive.
Recursive lowering [2/5]
---
Full diff: https://github.com/llvm/llvm-project/pull/77758.diff
1 Files Affected:
- (modified) flang/lib/Lower/OpenMP.cpp (+66-62) 
``````````diff
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 99690b03eca1d3..496b4ba27a0533 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -110,6 +110,32 @@ static void gatherFuncAndVarSyms(
   }
 }
 
+static Fortran::lower::pft::Evaluation *
+getEvalPastCollapse(Fortran::lower::pft::Evaluation &eval, int collapseValue) {
+  if (collapseValue == 0)
+    return &eval;
+
+  Fortran::lower::pft::Evaluation *curEval = &eval.getFirstNestedEvaluation();
+  for (int i = 1; i < collapseValue; i++) {
+    // The nested evaluations should be DoConstructs (i.e. they should form
+    // a loop nest). Each DoConstruct is a tuple <NonLabelDoStmt, Block,
+    // EndDoStmt>.
+    assert(curEval->isA<Fortran::parser::DoConstruct>());
+    curEval = &*std::next(curEval->getNestedEvaluations().begin());
+  }
+  return curEval;
+}
+
+static void genNestedEvaluations(Fortran::lower::AbstractConverter &converter,
+                                 Fortran::lower::pft::Evaluation &eval,
+                                 int collapseValue = 0) {
+  Fortran::lower::pft::Evaluation *curEval =
+      getEvalPastCollapse(eval, collapseValue);
+
+  for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
+    converter.genEval(e);
+}
+
 //===----------------------------------------------------------------------===//
 // DataSharingProcessor
 //===----------------------------------------------------------------------===//
@@ -2944,7 +2970,7 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter,
 
 static void
 genOMP(Fortran::lower::AbstractConverter &converter,
-       Fortran::lower::pft::Evaluation &eval,
+       Fortran::lower::SymMap &symTable, Fortran::lower::pft::Evaluation &eval,
        Fortran::semantics::SemanticsContext &semanticsContext,
        const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) {
   std::visit(
@@ -3025,15 +3051,17 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
   createBodyOfOp<mlir::omp::SimdLoopOp>(simdLoopOp, converter, loc, eval,
                                         &loopOpClauseList, iv,
                                         /*outer=*/false, &dsp);
+
+  genNestedEvaluations(converter, eval,
+                       Fortran::lower::getCollapseValue(loopOpClauseList));
 }
 
-static void
-createWsLoop(Fortran::lower::AbstractConverter &converter,
-             Fortran::lower::pft::Evaluation &eval,
-             llvm::omp::Directive ompDirective,
-             const Fortran::parser::OmpClauseList &beginClauseList,
-             const Fortran::parser::OmpClauseList *endClauseList,
-             mlir::Location loc) {
+static void createWsLoop(Fortran::lower::AbstractConverter &converter,
+                         Fortran::lower::pft::Evaluation &eval,
+                         llvm::omp::Directive ompDirective,
+                         const Fortran::parser::OmpClauseList &beginClauseList,
+                         const Fortran::parser::OmpClauseList *endClauseList,
+                         mlir::Location loc) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   DataSharingProcessor dsp(converter, beginClauseList, eval);
   dsp.processStep1();
@@ -3107,9 +3135,13 @@ createWsLoop(Fortran::lower::AbstractConverter &converter,
   createBodyOfOp<mlir::omp::WsLoopOp>(wsLoopOp, converter, loc, eval,
                                       &beginClauseList, iv,
                                       /*outer=*/false, &dsp);
+
+  genNestedEvaluations(converter, eval,
+                       Fortran::lower::getCollapseValue(beginClauseList));
 }
 
 static void genOMP(Fortran::lower::AbstractConverter &converter,
+                   Fortran::lower::SymMap &symTable,
                    Fortran::lower::pft::Evaluation &eval,
                    Fortran::semantics::SemanticsContext &semanticsContext,
                    const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
@@ -3179,11 +3211,12 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
     createWsLoop(converter, eval, ompDirective, loopOpClauseList, endClauseList,
                  currentLocation);
   }
+  genOpenMPReduction(converter, loopOpClauseList);
 }
 
 static void
 genOMP(Fortran::lower::AbstractConverter &converter,
-       Fortran::lower::pft::Evaluation &eval,
+       Fortran::lower::SymMap &symTable, Fortran::lower::pft::Evaluation &eval,
        Fortran::semantics::SemanticsContext &semanticsContext,
        const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
   const auto &beginBlockDirective =
@@ -3298,11 +3331,14 @@ genOMP(Fortran::lower::AbstractConverter &converter,
     break;
   }
   }
+
+  genNestedEvaluations(converter, eval);
+  genOpenMPReduction(converter, beginClauseList);
 }
 
 static void
 genOMP(Fortran::lower::AbstractConverter &converter,
-       Fortran::lower::pft::Evaluation &eval,
+       Fortran::lower::SymMap &symTable, Fortran::lower::pft::Evaluation &eval,
        const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   mlir::Location currentLocation = converter.getCurrentLocation();
@@ -3336,11 +3372,12 @@ genOMP(Fortran::lower::AbstractConverter &converter,
   }();
   createBodyOfOp<mlir::omp::CriticalOp>(criticalOp, converter, currentLocation,
                                         eval);
+  genNestedEvaluations(converter, eval);
 }
 
 static void
 genOMP(Fortran::lower::AbstractConverter &converter,
-       Fortran::lower::pft::Evaluation &eval,
+       Fortran::lower::SymMap &symTable, Fortran::lower::pft::Evaluation &eval,
        const Fortran::parser::OpenMPSectionConstruct &sectionConstruct) {
   mlir::Location currentLocation = converter.getCurrentLocation();
   const Fortran::parser::OpenMPConstruct *parentOmpConstruct =
@@ -3359,14 +3396,17 @@ genOMP(Fortran::lower::AbstractConverter &converter,
               .t);
   // Currently only private/firstprivate clause is handled, and
   // all privatization is done within `omp.section` operations.
+  symTable.pushScope();
   genOpWithBody<mlir::omp::SectionOp>(converter, eval, currentLocation,
                                       /*outerCombined=*/false,
                                       &sectionsClauseList);
+  genNestedEvaluations(converter, eval);
+  symTable.popScope();
 }
 
 static void
 genOMP(Fortran::lower::AbstractConverter &converter,
-       Fortran::lower::pft::Evaluation &eval,
+       Fortran::lower::SymMap &symTable, Fortran::lower::pft::Evaluation &eval,
        const Fortran::parser::OpenMPSectionsConstruct &sectionsConstruct) {
   mlir::Location currentLocation = converter.getCurrentLocation();
   llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
@@ -3406,11 +3446,13 @@ genOMP(Fortran::lower::AbstractConverter &converter,
                                        /*reduction_vars=*/mlir::ValueRange(),
                                        /*reductions=*/nullptr, allocateOperands,
                                        allocatorOperands, nowaitClauseOperand);
+
+  genNestedEvaluations(converter, eval);
 }
 
 static void
 genOMP(Fortran::lower::AbstractConverter &converter,
-       Fortran::lower::pft::Evaluation &eval,
+       Fortran::lower::SymMap &symTable, Fortran::lower::pft::Evaluation &eval,
        const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) {
   std::visit(
       Fortran::common::visitors{
@@ -3504,6 +3546,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
 }
 
 static void genOMP(Fortran::lower::AbstractConverter &converter,
+                   Fortran::lower::SymMap &symTable,
                    Fortran::semantics::SemanticsContext &semanticsContext,
                    Fortran::lower::pft::Evaluation &eval,
                    const Fortran::parser::OpenMPConstruct &ompConstruct) {
@@ -3511,17 +3554,18 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
       Fortran::common::visitors{
           [&](const Fortran::parser::OpenMPStandaloneConstruct
                   &standaloneConstruct) {
-            genOMP(converter, eval, semanticsContext, standaloneConstruct);
+            genOMP(converter, symTable, eval, semanticsContext,
+                   standaloneConstruct);
           },
           [&](const Fortran::parser::OpenMPSectionsConstruct
                   &sectionsConstruct) {
-            genOMP(converter, eval, sectionsConstruct);
+            genOMP(converter, symTable, eval, sectionsConstruct);
           },
           [&](const Fortran::parser::OpenMPSectionConstruct &sectionConstruct) {
-            genOMP(converter, eval, sectionConstruct);
+            genOMP(converter, symTable, eval, sectionConstruct);
           },
           [&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
-            genOMP(converter, eval, semanticsContext, loopConstruct);
+            genOMP(converter, symTable, eval, semanticsContext, loopConstruct);
           },
           [&](const Fortran::parser::OpenMPDeclarativeAllocate
                   &execAllocConstruct) {
@@ -3536,14 +3580,14 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
             TODO(converter.getCurrentLocation(), "OpenMPAllocatorsConstruct");
           },
           [&](const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
-            genOMP(converter, eval, semanticsContext, blockConstruct);
+            genOMP(converter, symTable, eval, semanticsContext, blockConstruct);
           },
           [&](const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) {
-            genOMP(converter, eval, atomicConstruct);
+            genOMP(converter, symTable, eval, atomicConstruct);
           },
           [&](const Fortran::parser::OpenMPCriticalConstruct
                   &criticalConstruct) {
-            genOMP(converter, eval, criticalConstruct);
+            genOMP(converter, symTable, eval, criticalConstruct);
           },
       },
       ompConstruct.u);
@@ -3607,47 +3651,8 @@ void Fortran::lower::genOpenMPConstruct(
     Fortran::semantics::SemanticsContext &semanticsContext,
     Fortran::lower::pft::Evaluation &eval,
     const Fortran::parser::OpenMPConstruct &omp) {
-
   symTable.pushScope();
-  genOMP(converter, semanticsContext, eval, omp);
-
-  const Fortran::parser::OpenMPLoopConstruct *ompLoop =
-      std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u);
-  const Fortran::parser::OpenMPBlockConstruct *ompBlock =
-      std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u);
-
-  // If loop is part of an OpenMP Construct then the OpenMP dialect
-  // workshare loop operation has already been created. Only the
-  // body needs to be created here and the do_loop can be skipped.
-  // Skip the number of collapsed loops, which is 1 when there is a
-  // no collapse requested.
-
-  Fortran::lower::pft::Evaluation *curEval = &eval;
-  const Fortran::parser::OmpClauseList *loopOpClauseList = nullptr;
-  if (ompLoop) {
-    loopOpClauseList = &std::get<Fortran::parser::OmpClauseList>(
-        std::get<Fortran::parser::OmpBeginLoopDirective>(ompLoop->t).t);
-    int64_t collapseValue = Fortran::lower::getCollapseValue(*loopOpClauseList);
-
-    curEval = &curEval->getFirstNestedEvaluation();
-    for (int64_t i = 1; i < collapseValue; i++) {
-      curEval = &*std::next(curEval->getNestedEvaluations().begin());
-    }
-  }
-
-  for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
-    converter.genEval(e);
-
-  if (ompLoop) {
-    genOpenMPReduction(converter, *loopOpClauseList);
-  } else if (ompBlock) {
-    const auto &blockStart =
-        std::get<Fortran::parser::OmpBeginBlockDirective>(ompBlock->t);
-    const auto &blockClauses =
-        std::get<Fortran::parser::OmpClauseList>(blockStart.t);
-    genOpenMPReduction(converter, blockClauses);
-  }
-
+  genOMP(converter, symTable, semanticsContext, eval, omp);
   symTable.popScope();
 }
 
@@ -3656,8 +3661,7 @@ void Fortran::lower::genOpenMPDeclarativeConstruct(
     Fortran::lower::pft::Evaluation &eval,
     const Fortran::parser::OpenMPDeclarativeConstruct &omp) {
   genOMP(converter, eval, omp);
-  for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations())
-    converter.genEval(e);
+  genNestedEvaluations(converter, eval);
 }
 
 void Fortran::lower::genOpenMPSymbolProperties(
``````````
</details>
https://github.com/llvm/llvm-project/pull/77758
    
    
More information about the llvm-branch-commits
mailing list