[flang-commits] [flang] [Flang][OpenMP] NFC: Share DataSharingProcessor creation logic for all loop directives (PR #97565)

Sergio Afonso via flang-commits flang-commits at lists.llvm.org
Fri Jul 5 02:17:04 PDT 2024


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/97565

>From c8bf447e9ce66efd202c2cc5b88a1c554aca0861 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 3 Jul 2024 12:32:12 +0100
Subject: [PATCH] [Flang][OpenMP] NFC: Share DataSharingProcessor creation
 logic for all loop directives

This patch moves the logic associated with the creation of a
`DataSharingProcessor` instance for loop-associated OpenMP leaf constructs to
the `genOMPDispatch` function, avoiding code duplication for standalone and
composite loop constructs.

This also prevents privatization-related allocations from later being made
inside loop wrappers when support for composite constructs is implemented.
---
 flang/lib/Lower/OpenMP/OpenMP.cpp | 94 ++++++++++++++++---------------
 1 file changed, 48 insertions(+), 46 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index e1efe0d1bd2d4..17804ff58edc0 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1316,12 +1316,9 @@ static mlir::omp::DistributeOp
 genDistributeOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
                 semantics::SemanticsContext &semaCtx,
                 lower::pft::Evaluation &eval, mlir::Location loc,
-                const ConstructQueue &queue, ConstructQueue::iterator item) {
+                const ConstructQueue &queue, ConstructQueue::iterator item,
+                DataSharingProcessor &dsp) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  symTable.pushScope();
-  DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
-                           lower::omp::isLastItemInQueue(item, queue));
-  dsp.processStep1();
 
   lower::StatementContext stmtCtx;
   mlir::omp::LoopNestClauseOps loopClauseOps;
@@ -1359,7 +1356,6 @@ genDistributeOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
                      .setGenRegionEntryCb(ivCallback),
                  queue, item);
 
-  symTable.popScope();
   return distributeOp;
 }
 
@@ -1580,12 +1576,8 @@ static mlir::omp::SimdOp
 genSimdOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
           semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
           mlir::Location loc, const ConstructQueue &queue,
-          ConstructQueue::iterator item) {
+          ConstructQueue::iterator item, DataSharingProcessor &dsp) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  symTable.pushScope();
-  DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
-                           lower::omp::isLastItemInQueue(item, queue));
-  dsp.processStep1();
 
   lower::StatementContext stmtCtx;
   mlir::omp::LoopNestClauseOps loopClauseOps;
@@ -1622,7 +1614,6 @@ genSimdOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
                      .setGenRegionEntryCb(ivCallback),
                  queue, item);
 
-  symTable.popScope();
   return simdOp;
 }
 
@@ -1861,7 +1852,8 @@ static mlir::omp::TaskloopOp
 genTaskloopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
               semantics::SemanticsContext &semaCtx,
               lower::pft::Evaluation &eval, mlir::Location loc,
-              const ConstructQueue &queue, ConstructQueue::iterator item) {
+              const ConstructQueue &queue, ConstructQueue::iterator item,
+              DataSharingProcessor &dsp) {
   TODO(loc, "Taskloop construct");
 }
 
@@ -1904,12 +1896,8 @@ static mlir::omp::WsloopOp
 genWsloopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
             semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
             mlir::Location loc, const ConstructQueue &queue,
-            ConstructQueue::iterator item) {
+            ConstructQueue::iterator item, DataSharingProcessor &dsp) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  symTable.pushScope();
-  DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
-                           lower::omp::isLastItemInQueue(item, queue));
-  dsp.processStep1();
 
   lower::StatementContext stmtCtx;
   mlir::omp::LoopNestClauseOps loopClauseOps;
@@ -1950,7 +1938,7 @@ genWsloopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
                      .setDataSharingProcessor(&dsp)
                      .setGenRegionEntryCb(ivCallback),
                  queue, item);
-  symTable.popScope();
+
   return wsloopOp;
 }
 
@@ -1962,7 +1950,7 @@ static void genCompositeDistributeParallelDo(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
     mlir::Location loc, const ConstructQueue &queue,
-    ConstructQueue::iterator item) {
+    ConstructQueue::iterator item, DataSharingProcessor &dsp) {
   TODO(loc, "Composite DISTRIBUTE PARALLEL DO");
 }
 
@@ -1970,17 +1958,15 @@ static void genCompositeDistributeParallelDoSimd(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
     mlir::Location loc, const ConstructQueue &queue,
-    ConstructQueue::iterator item) {
+    ConstructQueue::iterator item, DataSharingProcessor &dsp) {
   TODO(loc, "Composite DISTRIBUTE PARALLEL DO SIMD");
 }
 
-static void genCompositeDistributeSimd(lower::AbstractConverter &converter,
-                                       lower::SymMap &symTable,
-                                       semantics::SemanticsContext &semaCtx,
-                                       lower::pft::Evaluation &eval,
-                                       mlir::Location loc,
-                                       const ConstructQueue &queue,
-                                       ConstructQueue::iterator item) {
+static void genCompositeDistributeSimd(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+    mlir::Location loc, const ConstructQueue &queue,
+    ConstructQueue::iterator item, DataSharingProcessor &dsp) {
   TODO(loc, "Composite DISTRIBUTE SIMD");
 }
 
@@ -1989,7 +1975,8 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter,
                                semantics::SemanticsContext &semaCtx,
                                lower::pft::Evaluation &eval, mlir::Location loc,
                                const ConstructQueue &queue,
-                               ConstructQueue::iterator item) {
+                               ConstructQueue::iterator item,
+                               DataSharingProcessor &dsp) {
   ClauseProcessor cp(converter, semaCtx, item->clauses);
   cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear,
                  clause::Safelen, clause::Simdlen>(loc,
@@ -2002,16 +1989,14 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter,
   // When support for vectorization is enabled, then we need to add handling of
   // if clause. Currently if clause can be skipped because we always assume
   // SIMD length = 1.
-  genWsloopOp(converter, symTable, semaCtx, eval, loc, queue, item);
+  genWsloopOp(converter, symTable, semaCtx, eval, loc, queue, item, dsp);
 }
 
-static void genCompositeTaskloopSimd(lower::AbstractConverter &converter,
-                                     lower::SymMap &symTable,
-                                     semantics::SemanticsContext &semaCtx,
-                                     lower::pft::Evaluation &eval,
-                                     mlir::Location loc,
-                                     const ConstructQueue &queue,
-                                     ConstructQueue::iterator item) {
+static void genCompositeTaskloopSimd(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+    mlir::Location loc, const ConstructQueue &queue,
+    ConstructQueue::iterator item, DataSharingProcessor &dsp) {
   TODO(loc, "Composite TASKLOOP SIMD");
 }
 
@@ -2027,15 +2012,27 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
                            ConstructQueue::iterator item) {
   assert(item != queue.end());
 
+  std::optional<DataSharingProcessor> loopDsp;
+  bool loopLeaf = llvm::omp::getDirectiveAssociation(item->id) ==
+                  llvm::omp::Association::Loop;
+  if (loopLeaf) {
+    symTable.pushScope();
+    loopDsp.emplace(converter, semaCtx, item->clauses, eval,
+                    /*shouldCollectPreDeterminedSymbols=*/true,
+                    enableDelayedPrivatization, &symTable);
+    loopDsp->processStep1();
+  }
+
   switch (llvm::omp::Directive dir = item->id) {
   case llvm::omp::Directive::OMPD_barrier:
     genBarrierOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_distribute:
-    genDistributeOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    genDistributeOp(converter, symTable, semaCtx, eval, loc, queue, item,
+                    *loopDsp);
     break;
   case llvm::omp::Directive::OMPD_do:
-    genWsloopOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    genWsloopOp(converter, symTable, semaCtx, eval, loc, queue, item, *loopDsp);
     break;
   case llvm::omp::Directive::OMPD_loop:
   case llvm::omp::Directive::OMPD_masked:
@@ -2058,7 +2055,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
     genSectionsOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_simd:
-    genSimdOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    genSimdOp(converter, symTable, semaCtx, eval, loc, queue, item, *loopDsp);
     break;
   case llvm::omp::Directive::OMPD_single:
     genSingleOp(converter, symTable, semaCtx, eval, loc, queue, item);
@@ -2088,7 +2085,8 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
     genTaskgroupOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_taskloop:
-    genTaskloopOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    genTaskloopOp(converter, symTable, semaCtx, eval, loc, queue, item,
+                  *loopDsp);
     break;
   case llvm::omp::Directive::OMPD_taskwait:
     genTaskwaitOp(converter, symTable, semaCtx, eval, loc, queue, item);
@@ -2114,26 +2112,30 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
   // Composite constructs
   case llvm::omp::Directive::OMPD_distribute_parallel_do:
     genCompositeDistributeParallelDo(converter, symTable, semaCtx, eval, loc,
-                                     queue, item);
+                                     queue, item, *loopDsp);
     break;
   case llvm::omp::Directive::OMPD_distribute_parallel_do_simd:
     genCompositeDistributeParallelDoSimd(converter, symTable, semaCtx, eval,
-                                         loc, queue, item);
+                                         loc, queue, item, *loopDsp);
     break;
   case llvm::omp::Directive::OMPD_distribute_simd:
     genCompositeDistributeSimd(converter, symTable, semaCtx, eval, loc, queue,
-                               item);
+                               item, *loopDsp);
     break;
   case llvm::omp::Directive::OMPD_do_simd:
-    genCompositeDoSimd(converter, symTable, semaCtx, eval, loc, queue, item);
+    genCompositeDoSimd(converter, symTable, semaCtx, eval, loc, queue, item,
+                       *loopDsp);
     break;
   case llvm::omp::Directive::OMPD_taskloop_simd:
     genCompositeTaskloopSimd(converter, symTable, semaCtx, eval, loc, queue,
-                             item);
+                             item, *loopDsp);
     break;
   default:
     break;
   }
+
+  if (loopLeaf)
+    symTable.popScope();
 }
 
 //===----------------------------------------------------------------------===//



More information about the flang-commits mailing list