[flang-commits] [flang] [flang][OpenMP][Lower] fix statement context cleanup insertion point (PR #133891)

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Mon Apr 7 09:31:34 PDT 2025


https://github.com/tblah updated https://github.com/llvm/llvm-project/pull/133891

>From f842579fa10cee4eb28654ef45a80934734ddf44 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Thu, 27 Mar 2025 17:13:36 +0000
Subject: [PATCH 1/3] [flang][OpenMP][Lower] fix statement context cleanup
 insertion point

The statement context is used for lowering clauses for OpenMP
operations using generalised helpers from flang lowering. The statement
context stores closures which generate code for cleaning up temporary
values generated by the lowering helper. These closures are run when the
statement context is destroyed. Keeping the statement context local to
the clause or operation being lowered without any special handling was
not correct because any cleanup code would be generated at the insertion
point when that statement context went out of scope (which would in
general be inside of the newly created container operation). It would be
better to generate the cleanup code after the newly created operation
(clause processing is synchronous even for deferred tasks).

Currently supported clauses are mostly populated with simple scalar values
that require no cleanup. Even the simple array sections added by #132994
needed no cleanup because indexing the right values of the array did not
create any temporaries. Supporting array sections with vector indexing
will generate hlfir.destroy operations for cleanup. This patch fixes
where those will be created. Those hlfir.destroy operations don't
generate any FIR (or LLVM) code, but the issue still exists
theoretically.

I wasn't able to find any currently supported clause that generates
cleanup code which could be used to test this PR. It is probably NFC (no
functional change) for the current lowering. This will be tested in the
PR adding vector subscripting of array sections.
---
 flang/lib/Lower/OpenMP/OpenMP.cpp | 328 +++++++++++++++---------------
 1 file changed, 167 insertions(+), 161 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index ab90b4609e855..d81e618949502 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1763,7 +1763,6 @@ static void genTaskClauses(lower::AbstractConverter &converter,
   cp.processPriority(stmtCtx, clauseOps);
   cp.processUntied(clauseOps);
   cp.processDetach(clauseOps);
-  // TODO Support delayed privatization.
 
   cp.processTODO<clause::Affinity, clause::InReduction>(
       loc, llvm::omp::Directive::OMPD_task);
@@ -1914,12 +1913,11 @@ static mlir::omp::LoopNestOp genLoopNestOp(
       queue, item, clauseOps);
 }
 
-static void genLoopOp(lower::AbstractConverter &converter,
-                      lower::SymMap &symTable,
-                      semantics::SemanticsContext &semaCtx,
-                      lower::pft::Evaluation &eval, mlir::Location loc,
-                      const ConstructQueue &queue,
-                      ConstructQueue::const_iterator item) {
+static mlir::Operation *
+genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+          semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+          mlir::Location loc, const ConstructQueue &queue,
+          ConstructQueue::const_iterator item) {
   mlir::omp::LoopOperands loopClauseOps;
   llvm::SmallVector<const semantics::Symbol *> loopReductionSyms;
   genLoopClauses(converter, semaCtx, item->clauses, loc, loopClauseOps,
@@ -1946,14 +1944,15 @@ static void genLoopOp(lower::AbstractConverter &converter,
   genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item,
                 loopNestClauseOps, iv, {{loopOp, loopArgs}},
                 llvm::omp::Directive::OMPD_loop, dsp);
+  return loopOp;
 }
 
 static mlir::omp::MaskedOp
 genMaskedOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+            lower::StatementContext &stmtCtx,
             semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
             mlir::Location loc, const ConstructQueue &queue,
             ConstructQueue::const_iterator item) {
-  lower::StatementContext stmtCtx;
   mlir::omp::MaskedOperands clauseOps;
   genMaskedClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps);
 
@@ -2157,13 +2156,13 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   return sectionsOp;
 }
 
-static void genScopeOp(lower::AbstractConverter &converter,
-                       lower::SymMap &symTable,
-                       semantics::SemanticsContext &semaCtx,
-                       lower::pft::Evaluation &eval, mlir::Location loc,
-                       const ConstructQueue &queue,
-                       ConstructQueue::const_iterator item) {
+static mlir::Operation *
+genScopeOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+           semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+           mlir::Location loc, const ConstructQueue &queue,
+           ConstructQueue::const_iterator item) {
   TODO(loc, "Scope construct");
+  return nullptr;
 }
 
 static mlir::omp::SingleOp
@@ -2183,11 +2182,11 @@ genSingleOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 
 static mlir::omp::TargetOp
 genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+            lower::StatementContext &stmtCtx,
             semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
             mlir::Location loc, const ConstructQueue &queue,
             ConstructQueue::const_iterator item) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  lower::StatementContext stmtCtx;
   bool isTargetDevice =
       llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp())
           .getIsTargetDevice();
@@ -2366,13 +2365,11 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   return targetOp;
 }
 
-static mlir::omp::TargetDataOp
-genTargetDataOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
-                semantics::SemanticsContext &semaCtx,
-                lower::pft::Evaluation &eval, mlir::Location loc,
-                const ConstructQueue &queue,
-                ConstructQueue::const_iterator item) {
-  lower::StatementContext stmtCtx;
+static mlir::omp::TargetDataOp genTargetDataOp(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
+    lower::pft::Evaluation &eval, mlir::Location loc,
+    const ConstructQueue &queue, ConstructQueue::const_iterator item) {
   mlir::omp::TargetDataOperands clauseOps;
   llvm::SmallVector<const semantics::Symbol *> useDeviceAddrSyms,
       useDevicePtrSyms;
@@ -2402,10 +2399,10 @@ genTargetDataOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 template <typename OpTy>
 static OpTy genTargetEnterExitUpdateDataOp(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
-    semantics::SemanticsContext &semaCtx, mlir::Location loc,
-    const ConstructQueue &queue, ConstructQueue::const_iterator item) {
+    lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
+    mlir::Location loc, const ConstructQueue &queue,
+    ConstructQueue::const_iterator item) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  lower::StatementContext stmtCtx;
 
   // GCC 9.3.0 emits a (probably) bogus warning about an unused variable.
   [[maybe_unused]] llvm::omp::Directive directive;
@@ -2428,10 +2425,10 @@ static OpTy genTargetEnterExitUpdateDataOp(
 
 static mlir::omp::TaskOp
 genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+          lower::StatementContext &stmtCtx,
           semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
           mlir::Location loc, const ConstructQueue &queue,
           ConstructQueue::const_iterator item) {
-  lower::StatementContext stmtCtx;
   mlir::omp::TaskOperands clauseOps;
   genTaskClauses(converter, semaCtx, symTable, stmtCtx, item->clauses, loc,
                  clauseOps);
@@ -2498,13 +2495,11 @@ genTaskyieldOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   return converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>(loc);
 }
 
-static mlir::omp::WorkshareOp
-genWorkshareOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
-               semantics::SemanticsContext &semaCtx,
-               lower::pft::Evaluation &eval, mlir::Location loc,
-               const ConstructQueue &queue,
-               ConstructQueue::const_iterator item) {
-  lower::StatementContext stmtCtx;
+static mlir::omp::WorkshareOp genWorkshareOp(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
+    lower::pft::Evaluation &eval, mlir::Location loc,
+    const ConstructQueue &queue, ConstructQueue::const_iterator item) {
   mlir::omp::WorkshareOperands clauseOps;
   genWorkshareClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
                       clauseOps);
@@ -2518,11 +2513,10 @@ genWorkshareOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 
 static mlir::omp::TeamsOp
 genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+           lower::StatementContext &stmtCtx,
            semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
            mlir::Location loc, const ConstructQueue &queue,
            ConstructQueue::const_iterator item) {
-  lower::StatementContext stmtCtx;
-
   mlir::omp::TeamsOperands clauseOps;
   llvm::SmallVector<const semantics::Symbol *> reductionSyms;
   genTeamsClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps,
@@ -2546,15 +2540,11 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 // also be a leaf of a composite construct
 //===----------------------------------------------------------------------===//
 
-static void genStandaloneDistribute(lower::AbstractConverter &converter,
-                                    lower::SymMap &symTable,
-                                    semantics::SemanticsContext &semaCtx,
-                                    lower::pft::Evaluation &eval,
-                                    mlir::Location loc,
-                                    const ConstructQueue &queue,
-                                    ConstructQueue::const_iterator item) {
-  lower::StatementContext stmtCtx;
-
+static mlir::Operation *genStandaloneDistribute(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
+    lower::pft::Evaluation &eval, mlir::Location loc,
+    const ConstructQueue &queue, ConstructQueue::const_iterator item) {
   mlir::omp::DistributeOperands distributeClauseOps;
   genDistributeClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
                        distributeClauseOps);
@@ -2578,16 +2568,14 @@ static void genStandaloneDistribute(lower::AbstractConverter &converter,
   genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item,
                 loopNestClauseOps, iv, {{distributeOp, distributeArgs}},
                 llvm::omp::Directive::OMPD_distribute, dsp);
+  return distributeOp;
 }
 
-static void genStandaloneDo(lower::AbstractConverter &converter,
-                            lower::SymMap &symTable,
-                            semantics::SemanticsContext &semaCtx,
-                            lower::pft::Evaluation &eval, mlir::Location loc,
-                            const ConstructQueue &queue,
-                            ConstructQueue::const_iterator item) {
-  lower::StatementContext stmtCtx;
-
+static mlir::Operation *genStandaloneDo(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
+    lower::pft::Evaluation &eval, mlir::Location loc,
+    const ConstructQueue &queue, ConstructQueue::const_iterator item) {
   mlir::omp::WsloopOperands wsloopClauseOps;
   llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
   genWsloopClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
@@ -2614,17 +2602,14 @@ static void genStandaloneDo(lower::AbstractConverter &converter,
   genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item,
                 loopNestClauseOps, iv, {{wsloopOp, wsloopArgs}},
                 llvm::omp::Directive::OMPD_do, dsp);
+  return wsloopOp;
 }
 
-static void genStandaloneParallel(lower::AbstractConverter &converter,
-                                  lower::SymMap &symTable,
-                                  semantics::SemanticsContext &semaCtx,
-                                  lower::pft::Evaluation &eval,
-                                  mlir::Location loc,
-                                  const ConstructQueue &queue,
-                                  ConstructQueue::const_iterator item) {
-  lower::StatementContext stmtCtx;
-
+static mlir::Operation *genStandaloneParallel(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
+    lower::pft::Evaluation &eval, mlir::Location loc,
+    const ConstructQueue &queue, ConstructQueue::const_iterator item) {
   mlir::omp::ParallelOperands parallelClauseOps;
   llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
   genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
@@ -2644,17 +2629,18 @@ static void genStandaloneParallel(lower::AbstractConverter &converter,
   parallelArgs.priv.vars = parallelClauseOps.privateVars;
   parallelArgs.reduction.syms = parallelReductionSyms;
   parallelArgs.reduction.vars = parallelClauseOps.reductionVars;
-  genParallelOp(converter, symTable, semaCtx, eval, loc, queue, item,
-                parallelClauseOps, parallelArgs,
-                enableDelayedPrivatization ? &dsp.value() : nullptr);
-}
-
-static void genStandaloneSimd(lower::AbstractConverter &converter,
-                              lower::SymMap &symTable,
-                              semantics::SemanticsContext &semaCtx,
-                              lower::pft::Evaluation &eval, mlir::Location loc,
-                              const ConstructQueue &queue,
-                              ConstructQueue::const_iterator item) {
+  return genParallelOp(converter, symTable, semaCtx, eval, loc, queue, item,
+                       parallelClauseOps, parallelArgs,
+                       enableDelayedPrivatization ? &dsp.value() : nullptr);
+}
+
+static mlir::Operation *genStandaloneSimd(lower::AbstractConverter &converter,
+                                          lower::SymMap &symTable,
+                                          semantics::SemanticsContext &semaCtx,
+                                          lower::pft::Evaluation &eval,
+                                          mlir::Location loc,
+                                          const ConstructQueue &queue,
+                                          ConstructQueue::const_iterator item) {
   mlir::omp::SimdOperands simdClauseOps;
   llvm::SmallVector<const semantics::Symbol *> simdReductionSyms;
   genSimdClauses(converter, semaCtx, item->clauses, loc, simdClauseOps,
@@ -2681,29 +2667,27 @@ static void genStandaloneSimd(lower::AbstractConverter &converter,
   genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item,
                 loopNestClauseOps, iv, {{simdOp, simdArgs}},
                 llvm::omp::Directive::OMPD_simd, dsp);
+  return simdOp;
 }
 
-static void genStandaloneTaskloop(lower::AbstractConverter &converter,
-                                  lower::SymMap &symTable,
-                                  semantics::SemanticsContext &semaCtx,
-                                  lower::pft::Evaluation &eval,
-                                  mlir::Location loc,
-                                  const ConstructQueue &queue,
-                                  ConstructQueue::const_iterator item) {
+static mlir::Operation *genStandaloneTaskloop(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+    mlir::Location loc, const ConstructQueue &queue,
+    ConstructQueue::const_iterator item) {
   TODO(loc, "Taskloop construct");
+  return nullptr;
 }
 
 //===----------------------------------------------------------------------===//
 // Code generation functions for composite constructs
 //===----------------------------------------------------------------------===//
 
-static void genCompositeDistributeParallelDo(
+static mlir::Operation *genCompositeDistributeParallelDo(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
-    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
-    mlir::Location loc, const ConstructQueue &queue,
-    ConstructQueue::const_iterator item) {
-  lower::StatementContext stmtCtx;
-
+    lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
+    lower::pft::Evaluation &eval, mlir::Location loc,
+    const ConstructQueue &queue, ConstructQueue::const_iterator item) {
   assert(std::distance(item, queue.end()) == 3 && "Invalid leaf constructs");
   ConstructQueue::const_iterator distributeItem = item;
   ConstructQueue::const_iterator parallelItem = std::next(distributeItem);
@@ -2762,15 +2746,14 @@ static void genCompositeDistributeParallelDo(
                 loopNestClauseOps, iv,
                 {{distributeOp, distributeArgs}, {wsloopOp, wsloopArgs}},
                 llvm::omp::Directive::OMPD_distribute_parallel_do, dsp);
+  return distributeOp;
 }
 
-static void genCompositeDistributeParallelDoSimd(
+static mlir::Operation *genCompositeDistributeParallelDoSimd(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
-    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
-    mlir::Location loc, const ConstructQueue &queue,
-    ConstructQueue::const_iterator item) {
-  lower::StatementContext stmtCtx;
-
+    lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
+    lower::pft::Evaluation &eval, mlir::Location loc,
+    const ConstructQueue &queue, ConstructQueue::const_iterator item) {
   assert(std::distance(item, queue.end()) == 4 && "Invalid leaf constructs");
   ConstructQueue::const_iterator distributeItem = item;
   ConstructQueue::const_iterator parallelItem = std::next(distributeItem);
@@ -2854,17 +2837,14 @@ static void genCompositeDistributeParallelDoSimd(
                  {simdOp, simdArgs}},
                 llvm::omp::Directive::OMPD_distribute_parallel_do_simd,
                 simdItemDSP);
+  return distributeOp;
 }
 
-static void genCompositeDistributeSimd(lower::AbstractConverter &converter,
-                                       lower::SymMap &symTable,
-                                       semantics::SemanticsContext &semaCtx,
-                                       lower::pft::Evaluation &eval,
-                                       mlir::Location loc,
-                                       const ConstructQueue &queue,
-                                       ConstructQueue::const_iterator item) {
-  lower::StatementContext stmtCtx;
-
+static mlir::Operation *genCompositeDistributeSimd(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
+    lower::pft::Evaluation &eval, mlir::Location loc,
+    const ConstructQueue &queue, ConstructQueue::const_iterator item) {
   assert(std::distance(item, queue.end()) == 2 && "Invalid leaf constructs");
   ConstructQueue::const_iterator distributeItem = item;
   ConstructQueue::const_iterator simdItem = std::next(distributeItem);
@@ -2911,16 +2891,14 @@ static void genCompositeDistributeSimd(lower::AbstractConverter &converter,
                 loopNestClauseOps, iv,
                 {{distributeOp, distributeArgs}, {simdOp, simdArgs}},
                 llvm::omp::Directive::OMPD_distribute_simd, dsp);
+  return distributeOp;
 }
 
-static void genCompositeDoSimd(lower::AbstractConverter &converter,
-                               lower::SymMap &symTable,
-                               semantics::SemanticsContext &semaCtx,
-                               lower::pft::Evaluation &eval, mlir::Location loc,
-                               const ConstructQueue &queue,
-                               ConstructQueue::const_iterator item) {
-  lower::StatementContext stmtCtx;
-
+static mlir::Operation *genCompositeDoSimd(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
+    lower::pft::Evaluation &eval, mlir::Location loc,
+    const ConstructQueue &queue, ConstructQueue::const_iterator item) {
   assert(std::distance(item, queue.end()) == 2 && "Invalid leaf constructs");
   ConstructQueue::const_iterator doItem = item;
   ConstructQueue::const_iterator simdItem = std::next(doItem);
@@ -2970,30 +2948,29 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter,
                 loopNestClauseOps, iv,
                 {{wsloopOp, wsloopArgs}, {simdOp, simdArgs}},
                 llvm::omp::Directive::OMPD_do_simd, dsp);
+  return wsloopOp;
 }
 
-static void genCompositeTaskloopSimd(lower::AbstractConverter &converter,
-                                     lower::SymMap &symTable,
-                                     semantics::SemanticsContext &semaCtx,
-                                     lower::pft::Evaluation &eval,
-                                     mlir::Location loc,
-                                     const ConstructQueue &queue,
-                                     ConstructQueue::const_iterator item) {
+static mlir::Operation *genCompositeTaskloopSimd(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
+    lower::pft::Evaluation &eval, mlir::Location loc,
+    const ConstructQueue &queue, ConstructQueue::const_iterator item) {
   assert(std::distance(item, queue.end()) == 2 && "Invalid leaf constructs");
   TODO(loc, "Composite TASKLOOP SIMD");
+  return nullptr;
 }
 
 //===----------------------------------------------------------------------===//
 // Dispatch
 //===----------------------------------------------------------------------===//
 
-static bool genOMPCompositeDispatch(lower::AbstractConverter &converter,
-                                    lower::SymMap &symTable,
-                                    semantics::SemanticsContext &semaCtx,
-                                    lower::pft::Evaluation &eval,
-                                    mlir::Location loc,
-                                    const ConstructQueue &queue,
-                                    ConstructQueue::const_iterator item) {
+static bool genOMPCompositeDispatch(
+    lower::AbstractConverter &converter, lower::SymMap &symTable,
+    lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
+    lower::pft::Evaluation &eval, mlir::Location loc,
+    const ConstructQueue &queue, ConstructQueue::const_iterator item,
+    mlir::Operation *&newOp) {
   using llvm::omp::Directive;
   using lower::omp::matchLeafSequence;
 
@@ -3002,20 +2979,21 @@ static bool genOMPCompositeDispatch(lower::AbstractConverter &converter,
   // correct. Consider per-leaf privatization of composite constructs once
   // delayed privatization is supported by all participating ops.
   if (matchLeafSequence(item, queue, Directive::OMPD_distribute_parallel_do))
-    genCompositeDistributeParallelDo(converter, symTable, semaCtx, eval, loc,
-                                     queue, item);
+    newOp = genCompositeDistributeParallelDo(converter, symTable, stmtCtx,
+                                             semaCtx, eval, loc, queue, item);
   else if (matchLeafSequence(item, queue,
                              Directive::OMPD_distribute_parallel_do_simd))
-    genCompositeDistributeParallelDoSimd(converter, symTable, semaCtx, eval,
-                                         loc, queue, item);
+    newOp = genCompositeDistributeParallelDoSimd(
+        converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
   else if (matchLeafSequence(item, queue, Directive::OMPD_distribute_simd))
-    genCompositeDistributeSimd(converter, symTable, semaCtx, eval, loc, queue,
-                               item);
+    newOp = genCompositeDistributeSimd(converter, symTable, stmtCtx, semaCtx,
+                                       eval, loc, queue, item);
   else if (matchLeafSequence(item, queue, Directive::OMPD_do_simd))
-    genCompositeDoSimd(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp = genCompositeDoSimd(converter, symTable, stmtCtx, semaCtx, eval, loc,
+                               queue, item);
   else if (matchLeafSequence(item, queue, Directive::OMPD_taskloop_simd))
-    genCompositeTaskloopSimd(converter, symTable, semaCtx, eval, loc, queue,
-                             item);
+    newOp = genCompositeTaskloopSimd(converter, symTable, stmtCtx, semaCtx,
+                                     eval, loc, queue, item);
   else
     return false;
 
@@ -3030,46 +3008,64 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
                            ConstructQueue::const_iterator item) {
   assert(item != queue.end());
 
+  lower::StatementContext stmtCtx;
+  mlir::Operation *newOp = nullptr;
+
+  // Generate cleanup code for the stmtCtx after newOp
+  auto finalizeStmtCtx = [&]() {
+    if (newOp) {
+      fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+      fir::FirOpBuilder::InsertionGuard guard(builder);
+      builder.setInsertionPointAfter(newOp);
+      stmtCtx.finalizeAndPop();
+    }
+  };
+
   bool loopLeaf = llvm::omp::getDirectiveAssociation(item->id) ==
                   llvm::omp::Association::Loop;
   if (loopLeaf) {
     symTable.pushScope();
-    if (genOMPCompositeDispatch(converter, symTable, semaCtx, eval, loc, queue,
-                                item)) {
+    if (genOMPCompositeDispatch(converter, symTable, stmtCtx, semaCtx, eval,
+                                loc, queue, item, newOp)) {
       symTable.popScope();
+      finalizeStmtCtx();
       return;
     }
   }
 
   switch (llvm::omp::Directive dir = item->id) {
   case llvm::omp::Directive::OMPD_barrier:
-    genBarrierOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp = genBarrierOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_distribute:
-    genStandaloneDistribute(converter, symTable, semaCtx, eval, loc, queue,
-                            item);
+    newOp = genStandaloneDistribute(converter, symTable, stmtCtx, semaCtx, eval,
+                                    loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_do:
-    genStandaloneDo(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp = genStandaloneDo(converter, symTable, stmtCtx, semaCtx, eval, loc,
+                            queue, item);
     break;
   case llvm::omp::Directive::OMPD_loop:
-    genLoopOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp = genLoopOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_masked:
-    genMaskedOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp = genMaskedOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
+                        item);
     break;
   case llvm::omp::Directive::OMPD_master:
-    genMasterOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp = genMasterOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_ordered:
     // Block-associated "ordered" construct.
-    genOrderedRegionOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp = genOrderedRegionOp(converter, symTable, semaCtx, eval, loc, queue,
+                               item);
     break;
   case llvm::omp::Directive::OMPD_parallel:
-    genStandaloneParallel(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp = genStandaloneParallel(converter, symTable, stmtCtx, semaCtx, eval,
+                                  loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_scan:
-    genScanOp(converter, symTable, semaCtx, loc, queue, item);
+    newOp = genScanOp(converter, symTable, semaCtx, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_section:
     llvm_unreachable("genOMPDispatch: OMPD_section");
@@ -3082,49 +3078,57 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
     // in genBodyOfOp
     break;
   case llvm::omp::Directive::OMPD_simd:
-    genStandaloneSimd(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp =
+        genStandaloneSimd(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_scope:
-    genScopeOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp = genScopeOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_single:
-    genSingleOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp = genSingleOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_target:
-    genTargetOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp = genTargetOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
+                        item);
     break;
   case llvm::omp::Directive::OMPD_target_data:
-    genTargetDataOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp = genTargetDataOp(converter, symTable, stmtCtx, semaCtx, eval, loc,
+                            queue, item);
     break;
   case llvm::omp::Directive::OMPD_target_enter_data:
-    genTargetEnterExitUpdateDataOp<mlir::omp::TargetEnterDataOp>(
-        converter, symTable, semaCtx, loc, queue, item);
+    newOp = genTargetEnterExitUpdateDataOp<mlir::omp::TargetEnterDataOp>(
+        converter, symTable, stmtCtx, semaCtx, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_target_exit_data:
-    genTargetEnterExitUpdateDataOp<mlir::omp::TargetExitDataOp>(
-        converter, symTable, semaCtx, loc, queue, item);
+    newOp = genTargetEnterExitUpdateDataOp<mlir::omp::TargetExitDataOp>(
+        converter, symTable, stmtCtx, semaCtx, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_target_update:
-    genTargetEnterExitUpdateDataOp<mlir::omp::TargetUpdateOp>(
-        converter, symTable, semaCtx, loc, queue, item);
+    newOp = genTargetEnterExitUpdateDataOp<mlir::omp::TargetUpdateOp>(
+        converter, symTable, stmtCtx, semaCtx, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_task:
-    genTaskOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp = genTaskOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
+                      item);
     break;
   case llvm::omp::Directive::OMPD_taskgroup:
-    genTaskgroupOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp =
+        genTaskgroupOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_taskloop:
-    genStandaloneTaskloop(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp = genStandaloneTaskloop(converter, symTable, semaCtx, eval, loc,
+                                  queue, item);
     break;
   case llvm::omp::Directive::OMPD_taskwait:
-    genTaskwaitOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp = genTaskwaitOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_taskyield:
-    genTaskyieldOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp =
+        genTaskyieldOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_teams:
-    genTeamsOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp = genTeamsOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
+                       item);
     break;
   case llvm::omp::Directive::OMPD_tile:
   case llvm::omp::Directive::OMPD_unroll:
@@ -3132,7 +3136,8 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
                   llvm::omp::getOpenMPDirectiveName(dir) + ")");
   // case llvm::omp::Directive::OMPD_workdistribute:
   case llvm::omp::Directive::OMPD_workshare:
-    genWorkshareOp(converter, symTable, semaCtx, eval, loc, queue, item);
+    newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc,
+                           queue, item);
     break;
   default:
     // Combined and composite constructs should have been split into a sequence
@@ -3142,6 +3147,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
     break;
   }
 
+  finalizeStmtCtx();
   if (loopLeaf)
     symTable.popScope();
 }

>From f6ae92c3685f30c62adf2b4a0663456b215ebe3e Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Mon, 7 Apr 2025 16:19:57 +0000
Subject: [PATCH 2/3] Standardise on returning the concrete op type

---
 flang/lib/Lower/OpenMP/OpenMP.cpp | 33 +++++++++++++++----------------
 1 file changed, 16 insertions(+), 17 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index d81e618949502..0968df4f88e28 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1913,7 +1913,7 @@ static mlir::omp::LoopNestOp genLoopNestOp(
       queue, item, clauseOps);
 }
 
-static mlir::Operation *
+static mlir::omp::LoopOp
 genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
           semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
           mlir::Location loc, const ConstructQueue &queue,
@@ -2540,7 +2540,7 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 // also be a leaf of a composite construct
 //===----------------------------------------------------------------------===//
 
-static mlir::Operation *genStandaloneDistribute(
+static mlir::omp::DistributeOp genStandaloneDistribute(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
     lower::pft::Evaluation &eval, mlir::Location loc,
@@ -2571,7 +2571,7 @@ static mlir::Operation *genStandaloneDistribute(
   return distributeOp;
 }
 
-static mlir::Operation *genStandaloneDo(
+static mlir::omp::WsloopOp genStandaloneDo(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
     lower::pft::Evaluation &eval, mlir::Location loc,
@@ -2605,7 +2605,7 @@ static mlir::Operation *genStandaloneDo(
   return wsloopOp;
 }
 
-static mlir::Operation *genStandaloneParallel(
+static mlir::omp::ParallelOp genStandaloneParallel(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
     lower::pft::Evaluation &eval, mlir::Location loc,
@@ -2634,13 +2634,12 @@ static mlir::Operation *genStandaloneParallel(
                        enableDelayedPrivatization ? &dsp.value() : nullptr);
 }
 
-static mlir::Operation *genStandaloneSimd(lower::AbstractConverter &converter,
-                                          lower::SymMap &symTable,
-                                          semantics::SemanticsContext &semaCtx,
-                                          lower::pft::Evaluation &eval,
-                                          mlir::Location loc,
-                                          const ConstructQueue &queue,
-                                          ConstructQueue::const_iterator item) {
+static mlir::omp::SimdOp
+genStandaloneSimd(lower::AbstractConverter &converter, lower::SymMap &symTable,
+                  semantics::SemanticsContext &semaCtx,
+                  lower::pft::Evaluation &eval, mlir::Location loc,
+                  const ConstructQueue &queue,
+                  ConstructQueue::const_iterator item) {
   mlir::omp::SimdOperands simdClauseOps;
   llvm::SmallVector<const semantics::Symbol *> simdReductionSyms;
   genSimdClauses(converter, semaCtx, item->clauses, loc, simdClauseOps,
@@ -2670,7 +2669,7 @@ static mlir::Operation *genStandaloneSimd(lower::AbstractConverter &converter,
   return simdOp;
 }
 
-static mlir::Operation *genStandaloneTaskloop(
+static mlir::omp::TaskloopOp genStandaloneTaskloop(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
     mlir::Location loc, const ConstructQueue &queue,
@@ -2683,7 +2682,7 @@ static mlir::Operation *genStandaloneTaskloop(
 // Code generation functions for composite constructs
 //===----------------------------------------------------------------------===//
 
-static mlir::Operation *genCompositeDistributeParallelDo(
+static mlir::omp::DistributeOp genCompositeDistributeParallelDo(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
     lower::pft::Evaluation &eval, mlir::Location loc,
@@ -2749,7 +2748,7 @@ static mlir::Operation *genCompositeDistributeParallelDo(
   return distributeOp;
 }
 
-static mlir::Operation *genCompositeDistributeParallelDoSimd(
+static mlir::omp::DistributeOp genCompositeDistributeParallelDoSimd(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
     lower::pft::Evaluation &eval, mlir::Location loc,
@@ -2840,7 +2839,7 @@ static mlir::Operation *genCompositeDistributeParallelDoSimd(
   return distributeOp;
 }
 
-static mlir::Operation *genCompositeDistributeSimd(
+static mlir::omp::DistributeOp genCompositeDistributeSimd(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
     lower::pft::Evaluation &eval, mlir::Location loc,
@@ -2894,7 +2893,7 @@ static mlir::Operation *genCompositeDistributeSimd(
   return distributeOp;
 }
 
-static mlir::Operation *genCompositeDoSimd(
+static mlir::omp::WsloopOp genCompositeDoSimd(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
     lower::pft::Evaluation &eval, mlir::Location loc,
@@ -2951,7 +2950,7 @@ static mlir::Operation *genCompositeDoSimd(
   return wsloopOp;
 }
 
-static mlir::Operation *genCompositeTaskloopSimd(
+static mlir::omp::TaskloopOp genCompositeTaskloopSimd(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     lower::StatementContext &stmtCtx, semantics::SemanticsContext &semaCtx,
     lower::pft::Evaluation &eval, mlir::Location loc,

>From 670b95b3661d433d05917b9f0484501ba1f1b836 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Mon, 7 Apr 2025 16:30:45 +0000
Subject: [PATCH 3/3] Add test

---
 flang/test/Lower/OpenMP/clause-cleanup.f90 | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)
 create mode 100644 flang/test/Lower/OpenMP/clause-cleanup.f90

diff --git a/flang/test/Lower/OpenMP/clause-cleanup.f90 b/flang/test/Lower/OpenMP/clause-cleanup.f90
new file mode 100644
index 0000000000000..79de44cf42c72
--- /dev/null
+++ b/flang/test/Lower/OpenMP/clause-cleanup.f90
@@ -0,0 +1,17 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+subroutine test1(a)
+integer :: a(:)
+
+!$omp parallel num_threads(count(a .eq. 1))
+print *, "don't optimize me"
+!$omp end parallel
+end subroutine
+
+! CHECK:     %[[EXPR:.*]] = hlfir.elemental {{.*}} -> !hlfir.expr<?x!fir.logical<4>>
+! CHECK:     %[[COUNT:.*]] = hlfir.count %[[EXPR]]
+! CHECK:     omp.parallel num_threads(%[[COUNT]] : i32) {
+! CHECK-NOT:   hlfir.destroy %[[EXPR]]
+! CHECK:     omp.terminator
+! CHECK:    }
+! CHECK:    hlfir.destroy %[[EXPR]]



More information about the flang-commits mailing list