[flang-commits] [flang] [Flang][OpenMP] NFC: Share DataSharingProcessor creation logic for all loop directives (PR #97565)

Sergio Afonso via flang-commits <flang-commits at lists.llvm.org>
Thu Jul 4 07:52:00 PDT 2024


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/97565

From c2a4e9ce35b1ab9529e517bf65b3fc8414eec587 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 3 Jul 2024 12:32:12 +0100
Subject: [PATCH] [Flang][OpenMP] NFC: Share DataSharingProcessor creation
 logic for all loop directives

This patch moves the logic associated with the creation of a
`DataSharingProcessor` instance for loop-associated OpenMP leaf constructs to
the `genOMPDispatch` function, avoiding code duplication for standalone and
composite loop constructs.

This also prevents privatization-related allocations from later being made
inside of loop wrappers once support for composite constructs is implemented.
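For context, the shape of the shared creation logic is roughly the following.
This is a minimal, self-contained sketch of the pattern only; `Scope`,
`Processor`, `Association`, and `dispatch` are hypothetical stand-ins rather
than flang's real types (the actual code is in the diff below).

#include <iostream>
#include <optional>
#include <vector>

// Hypothetical stand-in for lower::SymMap: a stack of symbol scopes.
struct Scope {
  std::vector<int> scopes;
  void pushScope() { scopes.push_back(0); }
  void popScope() { scopes.pop_back(); }
};

// Hypothetical stand-in for DataSharingProcessor.
struct Processor {
  explicit Processor(int numClauses) : numClauses(numClauses) {}
  void processStep1() {
    std::cout << "privatizing symbols for " << numClauses << " clauses\n";
  }
  int numClauses;
};

enum class Association { Loop, None };

// The dispatch function owns the processor and the enclosing scope for every
// loop-associated leaf construct; the per-construct codegen functions only
// receive a reference, so none of them duplicates the setup/teardown.
void dispatch(Scope &symTable, Association assoc, int numClauses) {
  std::optional<Processor> loopDsp;
  bool loopLeaf = assoc == Association::Loop;
  if (loopLeaf) {
    symTable.pushScope();
    loopDsp.emplace(numClauses);
    loopDsp->processStep1();
  }

  // ... switch on the directive and call gen*Op(..., *loopDsp) ...

  if (loopLeaf)
    symTable.popScope();
}

int main() {
  Scope symTable;
  dispatch(symTable, Association::Loop, 2);  // sets up privatization
  dispatch(symTable, Association::None, 0);  // no processor created
}

Passing the processor down by reference, instead of constructing one in each
gen*Op function, is what lets a single instance serve every leaf of a
composite construct.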
---
 flang/lib/Lower/OpenMP/OpenMP.cpp | 94 ++++++++++++++++---------------
 1 file changed, 48 insertions(+), 46 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 5b2cf7930e4a3..4ee25413ad6ac 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1342,12 +1342,9 @@ static mlir::omp::DistributeOp
 genDistributeOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
                 semantics::SemanticsContext &semaCtx,
                 lower::pft::Evaluation &eval, mlir::Location loc,
-                const ConstructQueue &queue, ConstructQueue::iterator item) {
+                const ConstructQueue &queue, ConstructQueue::iterator item,
+                DataSharingProcessor &dsp) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  symTable.pushScope();
-  DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
-                           lower::omp::isLastItemInQueue(item, queue));
-  dsp.processStep1();
 
   lower::StatementContext stmtCtx;
   mlir::omp::LoopNestClauseOps loopClauseOps;
@@ -1385,7 +1382,6 @@ genDistributeOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
                      .setGenRegionEntryCb(ivCallback),
                  queue, item);
 
-  symTable.popScope();
   return distributeOp;
 }
 
@@ -1613,12 +1609,8 @@ static mlir::omp::SimdOp
 genSimdOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
           semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
           mlir::Location loc, const ConstructQueue &queue,
-          ConstructQueue::iterator item) {
+          ConstructQueue::iterator item, DataSharingProcessor &dsp) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  symTable.pushScope();
-  DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
-                           lower::omp::isLastItemInQueue(item, queue));
-  dsp.processStep1();
 
   lower::StatementContext stmtCtx;
   mlir::omp::LoopNestClauseOps loopClauseOps;
@@ -1655,7 +1647,6 @@ genSimdOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
                      .setGenRegionEntryCb(ivCallback),
                  queue, item);
 
-  symTable.popScope();
   return simdOp;
 }
 
@@ -1894,7 +1885,8 @@ static mlir::omp::TaskloopOp
 genTaskloopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
               semantics::SemanticsContext &semaCtx,
               lower::pft::Evaluation &eval, mlir::Location loc,
-              const ConstructQueue &queue, ConstructQueue::iterator item) {
+              const ConstructQueue &queue, ConstructQueue::iterator item,
+              DataSharingProcessor &dsp) {
   TODO(loc, "Taskloop construct");
 }
 
@@ -1938,12 +1930,8 @@ static mlir::omp::WsloopOp
 genWsloopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
             semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
             mlir::Location loc, const ConstructQueue &queue,
-            ConstructQueue::iterator item) {
+            ConstructQueue::iterator item, DataSharingProcessor &dsp) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  symTable.pushScope();
-  DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
-                           lower::omp::isLastItemInQueue(item, queue));
-  dsp.processStep1();
 
   lower::StatementContext stmtCtx;
   mlir::omp::LoopNestClauseOps loopClauseOps;
@@ -1985,7 +1973,7 @@ genWsloopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
                      .setReductions(&reductionSyms, &reductionTypes)
                      .setGenRegionEntryCb(ivCallback),
                  queue, item);
-  symTable.popScope();
+
   return wsloopOp;
 }
 
@@ -1997,7 +1985,7 @@ static void genCompositeDistributeParallelDo(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
     mlir::Location loc, const ConstructQueue &queue,
-    ConstructQueue::iterator item) {
+    ConstructQueue::iterator item, DataSharingProcessor &dsp) {
   TODO(loc, "Composite DISTRIBUTE PARALLEL DO");
 }
 
@@ -2005,17 +1993,15 @@ static void genCompositeDistributeParallelDoSimd(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
     mlir::Location loc, const ConstructQueue &queue,
-    ConstructQueue::iterator item) {
+    ConstructQueue::iterator item, DataSharingProcessor &dsp) {
   TODO(loc, "Composite DISTRIBUTE PARALLEL DO SIMD");
 }
 
-static void genCompositeDistributeSimd(lower::AbstractConverter &converter,
-                                       lower::SymMap &symTable,
-                                       semantics::SemanticsContext &semaCtx,
-                                       lower::pft::Evaluation &eval,
-                                       mlir::Location loc,
-                                       const ConstructQueue &queue,
-                                       ConstructQueue::iterator item) {
+static void genCompositeDistributeSimd(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+    mlir::Location loc, const ConstructQueue &queue,
+    ConstructQueue::iterator item, DataSharingProcessor &dsp) {
   TODO(loc, "Composite DISTRIBUTE SIMD");
 }
 
@@ -2024,7 +2010,8 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter,
                                semantics::SemanticsContext &semaCtx,
                                lower::pft::Evaluation &eval, mlir::Location loc,
                                const ConstructQueue &queue,
-                               ConstructQueue::iterator item) {
+                               ConstructQueue::iterator item,
+                               DataSharingProcessor &dsp) {
   ClauseProcessor cp(converter, semaCtx, item->clauses);
   cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear,
                  clause::Safelen, clause::Simdlen>(loc,
@@ -2037,16 +2024,14 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter,
   // When support for vectorization is enabled, then we need to add handling of
   // if clause. Currently if clause can be skipped because we always assume
   // SIMD length = 1.
-  genWsloopOp(converter, symTable, semaCtx, eval, loc, queue, item);
+  genWsloopOp(converter, symTable, semaCtx, eval, loc, queue, item, dsp);
 }
 
-static void genCompositeTaskloopSimd(lower::AbstractConverter &converter,
-                                     lower::SymMap &symTable,
-                                     semantics::SemanticsContext &semaCtx,
-                                     lower::pft::Evaluation &eval,
-                                     mlir::Location loc,
-                                     const ConstructQueue &queue,
-                                     ConstructQueue::iterator item) {
+static void genCompositeTaskloopSimd(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+    mlir::Location loc, const ConstructQueue &queue,
+    ConstructQueue::iterator item, DataSharingProcessor &dsp) {
   TODO(loc, "Composite TASKLOOP SIMD");
 }
 
@@ -2062,15 +2047,27 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
                            ConstructQueue::iterator item) {
   assert(item != queue.end());
 
+  std::optional<DataSharingProcessor> loopDsp;
+  bool loopLeaf = llvm::omp::getDirectiveAssociation(item->id) ==
+                  llvm::omp::Association::Loop;
+  if (loopLeaf) {
+    symTable.pushScope();
+    loopDsp.emplace(converter, semaCtx, item->clauses, eval,
+                    /*shouldCollectPreDeterminedSymbols=*/true,
+                    enableDelayedPrivatization, &symTable);
+    loopDsp->processStep1();
+  }
+
   switch (llvm::omp::Directive dir = item->id) {
   case llvm::omp::Directive::OMPD_barrier:
     genBarrierOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_distribute:
-    genDistributeOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    genDistributeOp(converter, symTable, semaCtx, eval, loc, queue, item,
+                    *loopDsp);
     break;
   case llvm::omp::Directive::OMPD_do:
-    genWsloopOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    genWsloopOp(converter, symTable, semaCtx, eval, loc, queue, item, *loopDsp);
     break;
   case llvm::omp::Directive::OMPD_loop:
   case llvm::omp::Directive::OMPD_masked:
@@ -2094,7 +2091,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
     genSectionsOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_simd:
-    genSimdOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    genSimdOp(converter, symTable, semaCtx, eval, loc, queue, item, *loopDsp);
     break;
   case llvm::omp::Directive::OMPD_single:
     genSingleOp(converter, symTable, semaCtx, eval, loc, queue, item);
@@ -2124,7 +2121,8 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
     genTaskgroupOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_taskloop:
-    genTaskloopOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    genTaskloopOp(converter, symTable, semaCtx, eval, loc, queue, item,
+                  *loopDsp);
     break;
   case llvm::omp::Directive::OMPD_taskwait:
     genTaskwaitOp(converter, symTable, semaCtx, eval, loc, queue, item);
@@ -2150,26 +2148,30 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
   // Composite constructs
   case llvm::omp::Directive::OMPD_distribute_parallel_do:
     genCompositeDistributeParallelDo(converter, symTable, semaCtx, eval, loc,
-                                     queue, item);
+                                     queue, item, *loopDsp);
     break;
   case llvm::omp::Directive::OMPD_distribute_parallel_do_simd:
     genCompositeDistributeParallelDoSimd(converter, symTable, semaCtx, eval,
-                                         loc, queue, item);
+                                         loc, queue, item, *loopDsp);
     break;
   case llvm::omp::Directive::OMPD_distribute_simd:
     genCompositeDistributeSimd(converter, symTable, semaCtx, eval, loc, queue,
-                               item);
+                               item, *loopDsp);
     break;
   case llvm::omp::Directive::OMPD_do_simd:
-    genCompositeDoSimd(converter, symTable, semaCtx, eval, loc, queue, item);
+    genCompositeDoSimd(converter, symTable, semaCtx, eval, loc, queue, item,
+                       *loopDsp);
     break;
   case llvm::omp::Directive::OMPD_taskloop_simd:
     genCompositeTaskloopSimd(converter, symTable, semaCtx, eval, loc, queue,
-                             item);
+                             item, *loopDsp);
     break;
   default:
     break;
   }
+
+  if (loopLeaf)
+    symTable.popScope();
 }
 
 //===----------------------------------------------------------------------===//
