[flang-commits] [flang] [flang][openacc] Split OpenACC context from function context (PR #71591)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Tue Nov 7 14:23:39 PST 2023


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/71591

The function context was used to finalize some aspects of the implicit region created for the declare directive. This led to issues in the lowering code generation when the function context was finalized. This patch splits the OpenACC-related finalization into a dedicated OpenACC context that can be managed separately from the function context.

>From 04b339de09a59b52913be33c67fde8c520dfed64 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 7 Nov 2023 14:20:36 -0800
Subject: [PATCH] [flang][openacc] Split OpenACC context from function context

---
 flang/include/flang/Lower/Bridge.h             |  3 +++
 flang/lib/Lower/Bridge.cpp                     |  5 ++++-
 flang/lib/Lower/OpenACC.cpp                    | 16 ++++++++--------
 flang/test/Lower/OpenACC/HLFIR/acc-declare.f90 | 15 +++++++++++++++
 4 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/flang/include/flang/Lower/Bridge.h b/flang/include/flang/Lower/Bridge.h
index b4ee77a0b166ec9..5821f1c29d0a6d6 100644
--- a/flang/include/flang/Lower/Bridge.h
+++ b/flang/include/flang/Lower/Bridge.h
@@ -108,6 +108,8 @@ class LoweringBridge {
 
   Fortran::lower::StatementContext &fctCtx() { return functionContext; }
 
+  Fortran::lower::StatementContext &openAccCtx() { return openAccContext; }
+
   bool validModule() { return getModule(); }
 
   //===--------------------------------------------------------------------===//
@@ -138,6 +140,7 @@ class LoweringBridge {
 
   Fortran::semantics::SemanticsContext &semanticsContext;
   Fortran::lower::StatementContext functionContext;
+  Fortran::lower::StatementContext openAccContext;
   const Fortran::common::IntrinsicTypeDefaultKinds &defaultKinds;
   const Fortran::evaluate::IntrinsicProcTable &intrinsics;
   const Fortran::evaluate::TargetCharacteristics &targetCharacteristics;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 8eb5e6865b83252..c1f813b553c8d53 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2384,7 +2384,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
   void genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &accDecl) {
     genOpenACCDeclarativeConstruct(*this, bridge.getSemanticsContext(),
-                                   bridge.fctCtx(), accDecl, accRoutineInfos);
+                                   bridge.openAccCtx(), accDecl,
+                                   accRoutineInfos);
     for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
       genFIR(e);
   }
@@ -4200,6 +4201,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   void startNewFunction(Fortran::lower::pft::FunctionLikeUnit &funit) {
     assert(!builder && "expected nullptr");
     bridge.fctCtx().pushScope();
+    bridge.openAccCtx().pushScope();
     const Fortran::semantics::Scope &scope = funit.getScope();
     LLVM_DEBUG(llvm::dbgs() << "\n[bridge - startNewFunction]";
                if (auto *sym = scope.symbol()) llvm::dbgs() << " " << *sym;
@@ -4439,6 +4441,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// Finish translation of a function.
   void endNewFunction(Fortran::lower::pft::FunctionLikeUnit &funit) {
     setCurrentPosition(Fortran::lower::pft::stmtSourceLoc(funit.endStmt));
+    bridge.openAccCtx().finalizeAndPop();
     if (funit.isMainProgram()) {
       bridge.fctCtx().finalizeAndPop();
       genExitRoutine();
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 1c045b4273e2607..a097e9a35f3e37c 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2989,7 +2989,7 @@ genGlobalCtorsWithModifier(Fortran::lower::AbstractConverter &converter,
 static void
 genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
                      Fortran::semantics::SemanticsContext &semanticsContext,
-                     Fortran::lower::StatementContext &fctCtx,
+                     Fortran::lower::StatementContext &openAccCtx,
                      mlir::Location loc,
                      const Fortran::parser::AccClauseList &accClauseList) {
   llvm::SmallVector<mlir::Value> dataClauseOperands, copyEntryOperands,
@@ -3102,9 +3102,9 @@ genDeclareInFunction(Fortran::lower::AbstractConverter &converter,
                         {});
     builder.setInsertionPointToEnd(&declareOp.getRegion().back());
   }
-  fctCtx.attachCleanup([&builder, declareOp, loc, createEntryOperands,
-                        copyEntryOperands, copyoutEntryOperands,
-                        deviceResidentEntryOperands]() {
+  openAccCtx.attachCleanup([&builder, declareOp, loc, createEntryOperands,
+                            copyEntryOperands, copyoutEntryOperands,
+                            deviceResidentEntryOperands]() {
     auto parentOp = builder.getBlock()->getParentOp();
     if (mlir::isa<mlir::acc::DeclareOp>(parentOp)) {
       builder.create<mlir::acc::TerminatorOp>(loc);
@@ -3164,7 +3164,7 @@ genDeclareInModule(Fortran::lower::AbstractConverter &converter,
 
 static void genACC(Fortran::lower::AbstractConverter &converter,
                    Fortran::semantics::SemanticsContext &semanticsContext,
-                   Fortran::lower::StatementContext &fctCtx,
+                   Fortran::lower::StatementContext &openAccCtx,
                    const Fortran::parser::OpenACCStandaloneDeclarativeConstruct
                        &declareConstruct) {
 
@@ -3182,7 +3182,7 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
     auto funcOp =
         builder.getBlock()->getParent()->getParentOfType<mlir::func::FuncOp>();
     if (funcOp)
-      genDeclareInFunction(converter, semanticsContext, fctCtx,
+      genDeclareInFunction(converter, semanticsContext, openAccCtx,
                            directiveLocation, accClauseList);
     else if (moduleOp)
       genDeclareInModule(converter, moduleOp, accClauseList);
@@ -3449,7 +3449,7 @@ void Fortran::lower::genOpenACCConstruct(
 void Fortran::lower::genOpenACCDeclarativeConstruct(
     Fortran::lower::AbstractConverter &converter,
     Fortran::semantics::SemanticsContext &semanticsContext,
-    Fortran::lower::StatementContext &fctCtx,
+    Fortran::lower::StatementContext &openAccCtx,
     const Fortran::parser::OpenACCDeclarativeConstruct &accDeclConstruct,
     Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) {
 
@@ -3457,7 +3457,7 @@ void Fortran::lower::genOpenACCDeclarativeConstruct(
       common::visitors{
           [&](const Fortran::parser::OpenACCStandaloneDeclarativeConstruct
                   &standaloneDeclarativeConstruct) {
-            genACC(converter, semanticsContext, fctCtx,
+            genACC(converter, semanticsContext, openAccCtx,
                    standaloneDeclarativeConstruct);
           },
           [&](const Fortran::parser::OpenACCRoutineConstruct
diff --git a/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90 b/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90
index 6ebdd39802fef0f..92daa0314bcd97e 100644
--- a/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90
+++ b/flang/test/Lower/OpenACC/HLFIR/acc-declare.f90
@@ -226,6 +226,21 @@ subroutine acc_declare_deviceptr2()
 ! HLFIR: %[[DEVICEPTR:.*]] = acc.deviceptr varPtr(%[[DECL]]#1 : !fir.ref<!fir.array<100xf32>>)   bounds(%{{.*}}) -> !fir.ref<!fir.array<100xf32>> {name = "dataparam"}
 ! ALL: acc.declare dataOperands(%[[DEVICEPTR]] : !fir.ref<!fir.array<100xf32>>)
 
+  function acc_declare_in_func()
+    real :: a(1024)
+    !$acc declare device_resident(a)
+  end function acc_declare_in_func
+! Regression check - the acc.delete emitted by the declare cleanup must precede the load and return of the function result.
+! ALL-LABEL: func.func @_QMacc_declarePacc_declare_in_func() -> f32 {
+! HLFIR: %[[DEVICE_RESIDENT:.*]] = acc.declare_device_resident varPtr(%{{.*}}#1 : !fir.ref<!fir.array<1024xf32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<1024xf32>> {name = "a"}
+! HLFIR: acc.declare dataOperands(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>) {
+! HLFIR:   acc.terminator
+! HLFIR: }
+! HLFIR: acc.delete accPtr(%[[DEVICE_RESIDENT]] : !fir.ref<!fir.array<1024xf32>>) bounds(%{{.*}}) {dataClause = #acc<data_clause acc_declare_device_resident>, name = "a"}
+! HLFIR: %[[LOAD:.*]] = fir.load %{{.*}}#1 : !fir.ref<f32>
+! HLFIR: return %[[LOAD]] : f32
+! ALL: }
+
 end module
 
 module acc_declare_allocatable_test



More information about the flang-commits mailing list