[flang-commits] [flang] [Flang][OpenMP] Finalize Statement context of clauses before op creation (PR #71679)

Kiran Chandramohan via flang-commits flang-commits at lists.llvm.org
Thu Nov 9 05:51:16 PST 2023


https://github.com/kiranchandramohan updated https://github.com/llvm/llvm-project/pull/71679

>From 33ca18cca35022d01cf21aa0ba0d1f15af9a5ae9 Mon Sep 17 00:00:00 2001
From: Kiran Chandramohan <kiran.chandramohan at arm.com>
Date: Wed, 8 Nov 2023 14:10:53 +0000
Subject: [PATCH 1/2] [Flang][OpenMP] Finalize Statement context of clauses
 before op creation

Currently, the statement context used while processing clauses is
finalized after the OpenMP operation has been created. This places the
finalization code inside the OpenMP region, which can lead to
multiple finalizations in a parallel region. This fix finalizes the
statement context before the OpenMP operation is created.

Note: The fix only does this for the if clause; it can be extended to
other clauses if the approach is acceptable.

Fixes #71342
---
 flang/lib/Lower/OpenMP.cpp              | 25 +++++++++----------------
 flang/test/Lower/OpenMP/parallel-if.f90 | 12 ++++++++++++
 2 files changed, 21 insertions(+), 16 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/parallel-if.f90

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index c2568c629e521c1..315d1ea656ed0e7 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -561,8 +561,7 @@ class ClauseProcessor {
   bool processDepend(llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
                      llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
   bool
-  processIf(Fortran::lower::StatementContext &stmtCtx,
-            Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
+  processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
             mlir::Value &result) const;
   bool
   processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
@@ -1672,9 +1671,9 @@ bool ClauseProcessor::processDepend(
 }
 
 bool ClauseProcessor::processIf(
-    Fortran::lower::StatementContext &stmtCtx,
     Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
     mlir::Value &result) const {
+  Fortran::lower::StatementContext stmtCtx;
   bool found = false;
   findRepeatableClause<ClauseTy::If>(
       [&](const ClauseTy::If *ifClause,
@@ -2305,8 +2304,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
 
   ClauseProcessor cp(converter, clauseList);
-  cp.processIf(stmtCtx,
-               Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
+  cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
                ifClauseOperand);
   cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
   cp.processProcBind(procBindKindAttr);
@@ -2359,8 +2357,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
       dependOperands;
 
   ClauseProcessor cp(converter, clauseList);
-  cp.processIf(stmtCtx,
-               Fortran::parser::OmpIfClause::DirectiveNameModifier::Task,
+  cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Task,
                ifClauseOperand);
   cp.processAllocate(allocatorOperands, allocateOperands);
   cp.processDefault();
@@ -2417,8 +2414,7 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
 
   ClauseProcessor cp(converter, clauseList);
-  cp.processIf(stmtCtx,
-               Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData,
+  cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData,
                ifClauseOperand);
   cp.processDevice(stmtCtx, deviceOperand);
   cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
@@ -2463,7 +2459,7 @@ genEnterExitDataOp(Fortran::lower::AbstractConverter &converter,
   }
 
   ClauseProcessor cp(converter, clauseList);
-  cp.processIf(stmtCtx, directiveName, ifClauseOperand);
+  cp.processIf(directiveName, ifClauseOperand);
   cp.processDevice(stmtCtx, deviceOperand);
   cp.processNowait(nowaitAttr);
   cp.processMap(currentLocation, directive, semanticsContext, stmtCtx,
@@ -2587,8 +2583,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
 
   ClauseProcessor cp(converter, clauseList);
-  cp.processIf(stmtCtx,
-               Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
+  cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
                ifClauseOperand);
   cp.processDevice(stmtCtx, deviceOperand);
   cp.processThreadLimit(stmtCtx, threadLimitOperand);
@@ -2742,8 +2737,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
 
   ClauseProcessor cp(converter, clauseList);
-  cp.processIf(stmtCtx,
-               Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams,
+  cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams,
                ifClauseOperand);
   cp.processAllocate(allocatorOperands, allocateOperands);
   cp.processDefault();
@@ -3018,8 +3012,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
     llvm::SmallVector<mlir::Value> alignedVars, nontemporalVars;
     mlir::Value ifClauseOperand;
     mlir::IntegerAttr simdlenClauseOperand, safelenClauseOperand;
-    cp.processIf(stmtCtx,
-                 Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd,
+    cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd,
                  ifClauseOperand);
     cp.processSimdlen(simdlenClauseOperand);
     cp.processSafelen(safelenClauseOperand);
diff --git a/flang/test/Lower/OpenMP/parallel-if.f90 b/flang/test/Lower/OpenMP/parallel-if.f90
new file mode 100644
index 000000000000000..661bc0e46598d95
--- /dev/null
+++ b/flang/test/Lower/OpenMP/parallel-if.f90
@@ -0,0 +1,12 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+!CHECK-LABEL: func @_QPtest1
+subroutine test1(a)
+integer :: a(:,:)
+!CHECK: hlfir.destroy
+!CHECK: omp.parallel if
+!$omp parallel if(any(a .eq. 1))
+!CHECK-NOT: hlfir.destroy
+  print *, "Hello"
+!$omp end parallel
+end subroutine

>From 01ea250aad75411a3bf635b3f3605ae686867b92 Mon Sep 17 00:00:00 2001
From: Kiran Chandramohan <kiran.chandramohan at arm.com>
Date: Thu, 9 Nov 2023 13:48:18 +0000
Subject: [PATCH 2/2] [Flang][OpenMP] Move stmtCtx to helper function

---
 flang/lib/Lower/OpenMP.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 315d1ea656ed0e7..fd6698d8da3264c 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1115,7 +1115,6 @@ genDependKindAttr(fir::FirOpBuilder &firOpBuilder,
 
 static mlir::Value getIfClauseOperand(
     Fortran::lower::AbstractConverter &converter,
-    Fortran::lower::StatementContext &stmtCtx,
     const Fortran::parser::OmpClause::If *ifClause,
     Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
     mlir::Location clauseLocation) {
@@ -1126,6 +1125,7 @@ static mlir::Value getIfClauseOperand(
   if (directive && directive.value() != directiveName)
     return nullptr;
 
+  Fortran::lower::StatementContext stmtCtx;
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
   mlir::Value ifVal = fir::getBase(
@@ -1673,13 +1673,12 @@ bool ClauseProcessor::processDepend(
 bool ClauseProcessor::processIf(
     Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
     mlir::Value &result) const {
-  Fortran::lower::StatementContext stmtCtx;
   bool found = false;
   findRepeatableClause<ClauseTy::If>(
       [&](const ClauseTy::If *ifClause,
           const Fortran::parser::CharBlock &source) {
         mlir::Location clauseLocation = converter.genLocation(source);
-        mlir::Value operand = getIfClauseOperand(converter, stmtCtx, ifClause,
+        mlir::Value operand = getIfClauseOperand(converter, ifClause,
                                                  directiveName, clauseLocation);
         // Assume that, at most, a single 'if' clause will be applicable to the
         // given directive.



More information about the flang-commits mailing list