[flang-commits] [flang] [flang][openacc] Allow acc routine at the top level (PR #69936)
via flang-commits
flang-commits at lists.llvm.org
Mon Oct 23 09:04:38 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-openmp
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
Some compilers allow `!$acc routine(<name>)` to be placed at the top level, outside of any program unit. For compatibility, this patch enables the use of `acc routine` at this level. Such top-level `acc routine` directives must name the routine they apply to.
---
Full diff: https://github.com/llvm/llvm-project/pull/69936.diff
15 Files Affected:
- (modified) flang/docs/OpenACC.md (+1)
- (modified) flang/include/flang/Lower/OpenACC.h (+6)
- (modified) flang/include/flang/Lower/PFTBuilder.h (+14-2)
- (modified) flang/include/flang/Parser/parse-tree.h (+3-1)
- (modified) flang/lib/Lower/Bridge.cpp (+13)
- (modified) flang/lib/Lower/OpenACC.cpp (+11-11)
- (modified) flang/lib/Lower/PFTBuilder.cpp (+24)
- (modified) flang/lib/Parser/program-parsers.cpp (+6-1)
- (modified) flang/lib/Semantics/program-tree.cpp (+4)
- (modified) flang/lib/Semantics/program-tree.h (+1)
- (modified) flang/lib/Semantics/resolve-directives.cpp (+19-10)
- (modified) flang/lib/Semantics/resolve-directives.h (+3-2)
- (modified) flang/lib/Semantics/resolve-names.cpp (+7-1)
- (added) flang/test/Lower/OpenACC/acc-routine02.f90 (+21)
- (added) flang/test/Semantics/OpenACC/acc-routine-validity02.f90 (+17)
``````````diff
diff --git a/flang/docs/OpenACC.md b/flang/docs/OpenACC.md
index e29c5f89f9b2482..41c974f837421b9 100644
--- a/flang/docs/OpenACC.md
+++ b/flang/docs/OpenACC.md
@@ -23,3 +23,4 @@ local:
warning instead of an error as other compiler accepts it.
* The `if` clause accepts scalar integer expression in addition to scalar
logical expression.
+* A named `!$acc routine(<name>)` directive can be placed at the top level.
diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index c73af0a6eb0f874..409956f0ecb309f 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -37,6 +37,7 @@ namespace Fortran {
namespace parser {
struct OpenACCConstruct;
struct OpenACCDeclarativeConstruct;
+struct OpenACCRoutineConstruct;
} // namespace parser
namespace semantics {
@@ -71,6 +72,13 @@ void genOpenACCDeclarativeConstruct(AbstractConverter &,
StatementContext &,
const parser::OpenACCDeclarativeConstruct &,
AccRoutineInfoMappingList &);
+/// Lower an OpenACC routine directive. A named routine is looked up in the
+/// given module; an unnamed one applies to the converter builder's function.
+void genOpenACCRoutineConstruct(AbstractConverter &,
+ Fortran::semantics::SemanticsContext &,
+ mlir::ModuleOp &,
+ const parser::OpenACCRoutineConstruct &,
+ AccRoutineInfoMappingList &);
void finalizeOpenACCRoutineAttachment(mlir::ModuleOp &,
AccRoutineInfoMappingList &);
diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h
index 5927fc1915ae34d..6f68dc7c9f525f1 100644
--- a/flang/include/flang/Lower/PFTBuilder.h
+++ b/flang/include/flang/Lower/PFTBuilder.h
@@ -135,6 +135,7 @@ using Constructs =
using Directives =
std::tuple<parser::CompilerDirective, parser::OpenACCConstruct,
+ parser::OpenACCRoutineConstruct,
parser::OpenACCDeclarativeConstruct, parser::OpenMPConstruct,
parser::OpenMPDeclarativeConstruct, parser::OmpEndLoopDirective>;
@@ -360,7 +361,8 @@ using ProgramVariant =
ReferenceVariant<parser::MainProgram, parser::FunctionSubprogram,
parser::SubroutineSubprogram, parser::Module,
parser::Submodule, parser::SeparateModuleSubprogram,
- parser::BlockData, parser::CompilerDirective>;
+ parser::BlockData, parser::CompilerDirective,
+ parser::OpenACCRoutineConstruct>;
/// A program is a list of program units.
/// These units can be function like, module like, or block data.
struct ProgramUnit : ProgramVariant {
@@ -763,10 +765,20 @@ struct CompilerDirectiveUnit : public ProgramUnit {
CompilerDirectiveUnit(const CompilerDirectiveUnit &) = delete;
};
+/// A top-level OpenACC routine directive, i.e. one outside any program unit.
+struct OpenACCDirectiveUnit : public ProgramUnit {
+ OpenACCDirectiveUnit(const parser::OpenACCRoutineConstruct &directive,
+ const PftNode &parent)
+ : ProgramUnit{directive, parent}, routine{directive} {}
+ OpenACCDirectiveUnit(OpenACCDirectiveUnit &&) = default;
+ OpenACCDirectiveUnit(const OpenACCDirectiveUnit &) = delete;
+ const parser::OpenACCRoutineConstruct &routine;
+};
+
/// A Program is the top-level root of the PFT.
struct Program {
using Units = std::variant<FunctionLikeUnit, ModuleLikeUnit, BlockDataUnit,
- CompilerDirectiveUnit>;
+ CompilerDirectiveUnit, OpenACCDirectiveUnit>;
Program(semantics::CommonBlockList &&commonBlocks)
: commonBlocks{std::move(commonBlocks)} {}
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index a51921f2562b8a9..83c8db936934a03 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -262,6 +262,7 @@ struct PauseStmt;
struct OpenACCConstruct;
struct AccEndCombinedDirective;
struct OpenACCDeclarativeConstruct;
+struct OpenACCRoutineConstruct;
struct OpenMPConstruct;
struct OpenMPDeclarativeConstruct;
struct OmpEndLoopDirective;
@@ -558,7 +559,8 @@ struct ProgramUnit {
common::Indirection<FunctionSubprogram>,
common::Indirection<SubroutineSubprogram>, common::Indirection<Module>,
common::Indirection<Submodule>, common::Indirection<BlockData>,
- common::Indirection<CompilerDirective>>
+ common::Indirection<CompilerDirective>,
+ common::Indirection<OpenACCRoutineConstruct>>
u;
};
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index f26a1aaf0236fa5..ccd54e349786e51 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -316,6 +316,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
globalOmpRequiresSymbol = b.symTab.symbol();
},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
+ [&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {},
},
u);
}
@@ -328,6 +329,21 @@ class FirConverter : public Fortran::lower::AbstractConverter {
[&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerMod(m); },
[&](Fortran::lower::pft::BlockDataUnit &b) {},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
+ [&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
+ // A top-level routine directive has no enclosing function, so lower it
+ // with a temporary builder rooted at the module. The builder is a local
+ // object destroyed when this lambda returns; `builder` must not outlive it.
+ // NOTE(review): only the named form is expected here; the unnamed form
+ // falls back to builder.getFunction(), which presumably has nothing to
+ // return for a module-rooted builder.
+ fir::FirOpBuilder moduleBuilder(bridge.getModule(),
+ bridge.getKindMap());
+ builder = &moduleBuilder;
+ Fortran::lower::genOpenACCRoutineConstruct(
+ *this, bridge.getSemanticsContext(), bridge.getModule(),
+ d.routine, accRoutineInfos);
+ builder = nullptr;
+ },
},
u);
}
@@ -2320,6 +2336,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
genFIR(e);
}
+ void genFIR(const Fortran::parser::OpenACCRoutineConstruct &acc) {
+ // Nothing to do: routine directives inside a program unit are lowered by
+ // genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &), and
+ // top-level ones are lowered directly from their OpenACCDirectiveUnit.
+ }
+
void genFIR(const Fortran::parser::OpenMPConstruct &omp) {
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
localSymbols.pushScope();
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 4fafcebc30d116a..bfef5e2226ed4be 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3143,29 +3143,26 @@ static void attachRoutineInfo(mlir::func::FuncOp func,
mlir::acc::RoutineInfoAttr::get(func.getContext(), routines));
}
-static void
-genACC(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semanticsContext,
- const Fortran::parser::OpenACCRoutineConstruct &routineConstruct,
- Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) {
+void Fortran::lower::genOpenACCRoutineConstruct(
+ Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semanticsContext, mlir::ModuleOp &mod,
+ const Fortran::parser::OpenACCRoutineConstruct &routineConstruct,
+ Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::Location loc = converter.genLocation(routineConstruct.source);
std::optional<Fortran::parser::Name> name =
std::get<std::optional<Fortran::parser::Name>>(routineConstruct.t);
const auto &clauses =
std::get<Fortran::parser::AccClauseList>(routineConstruct.t);
-
- mlir::ModuleOp mod = builder.getModule();
mlir::func::FuncOp funcOp;
std::string funcName;
if (name) {
funcName = converter.mangleName(*name->symbol);
- funcOp = builder.getNamedFunction(funcName);
+ funcOp = builder.getNamedFunction(mod, funcName);
} else {
funcOp = builder.getFunction();
funcName = funcOp.getName();
}
-
bool hasSeq = false, hasGang = false, hasWorker = false, hasVector = false,
hasNohost = false;
std::optional<std::string> bindName = std::nullopt;
@@ -3381,8 +3378,11 @@ void Fortran::lower::genOpenACCDeclarativeConstruct(
},
[&](const Fortran::parser::OpenACCRoutineConstruct
&routineConstruct) {
- genACC(converter, semanticsContext, routineConstruct,
- accRoutineInfos);
+ fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+ mlir::ModuleOp mod = builder.getModule();
+ Fortran::lower::genOpenACCRoutineConstruct(
+ converter, semanticsContext, mod, routineConstruct,
+ accRoutineInfos);
},
},
accDeclConstruct.u);
diff --git a/flang/lib/Lower/PFTBuilder.cpp b/flang/lib/Lower/PFTBuilder.cpp
index 97afdaf49b672a9..0946a85dcaddd88 100644
--- a/flang/lib/Lower/PFTBuilder.cpp
+++ b/flang/lib/Lower/PFTBuilder.cpp
@@ -241,6 +241,17 @@ class PFTBuilder {
return enterConstructOrDirective(directive);
}
+ bool Pre(const parser::OpenACCRoutineConstruct &directive) {
+ assert(pftParentStack.size() > 0 &&
+ "At least the Program must be a parent");
+ if (pftParentStack.back().isA<lower::pft::Program>()) {
+ addUnit(
+ lower::pft::OpenACCDirectiveUnit(directive, pftParentStack.back()));
+ return false;
+ }
+ return enterConstructOrDirective(directive);
+ }
+
private:
/// Initialize a new module-like unit and make it the builder's focus.
template <typename A>
@@ -1133,6 +1144,9 @@ class PFTDumper {
[&](const lower::pft::CompilerDirectiveUnit &unit) {
dumpCompilerDirectiveUnit(outputStream, unit);
},
+ [&](const lower::pft::OpenACCDirectiveUnit &unit) {
+ dumpOpenACCDirectiveUnit(outputStream, unit);
+ },
},
unit);
}
@@ -1280,6 +1294,16 @@ class PFTDumper {
outputStream << "\nEnd CompilerDirective\n\n";
}
+ void
+ dumpOpenACCDirectiveUnit(llvm::raw_ostream &outputStream,
+ const lower::pft::OpenACCDirectiveUnit &directive) {
+ outputStream << getNodeIndex(directive) << " ";
+ outputStream << "OpenACCDirective: !$acc ";
+ outputStream << directive.get<Fortran::parser::OpenACCRoutineConstruct>()
+ .source.ToString();
+ outputStream << "\nEnd OpenACCDirective\n\n";
+ }
+
template <typename T>
std::size_t getNodeIndex(const T &node) {
auto addr = static_cast<const void *>(&node);
diff --git a/flang/lib/Parser/program-parsers.cpp b/flang/lib/Parser/program-parsers.cpp
index 521ae43097adc6a..e24559bf14f7c92 100644
--- a/flang/lib/Parser/program-parsers.cpp
+++ b/flang/lib/Parser/program-parsers.cpp
@@ -46,6 +46,10 @@ static constexpr auto normalProgramUnit{StartNewSubprogram{} >> programUnit /
static constexpr auto globalCompilerDirective{
construct<ProgramUnit>(indirect(compilerDirective))};
+static constexpr auto globalOpenACCCompilerDirective{
+ construct<ProgramUnit>(indirect(skipStuffBeforeStatement >>
+ "!$ACC "_sptok >> Parser<OpenACCRoutineConstruct>{}))};
+
// R501 program -> program-unit [program-unit]...
// This is the top-level production for the Fortran language.
// F'2018 6.3.1 defines a program unit as a sequence of one or more lines,
@@ -58,7 +62,8 @@ TYPE_PARSER(
"nonstandard usage: empty source file"_port_en_US,
skipStuffBeforeStatement >> !nextCh >>
pure<std::list<ProgramUnit>>()) ||
- some(globalCompilerDirective || normalProgramUnit) /
+ some(globalCompilerDirective || globalOpenACCCompilerDirective ||
+ normalProgramUnit) /
skipStuffBeforeStatement))
// R504 specification-part ->
diff --git a/flang/lib/Semantics/program-tree.cpp b/flang/lib/Semantics/program-tree.cpp
index cd631da93d698be..bf773f3810c847b 100644
--- a/flang/lib/Semantics/program-tree.cpp
+++ b/flang/lib/Semantics/program-tree.cpp
@@ -200,6 +200,10 @@ ProgramTree ProgramTree::Build(const parser::CompilerDirective &) {
DIE("ProgramTree::Build() called for CompilerDirective");
}
+ProgramTree ProgramTree::Build(const parser::OpenACCRoutineConstruct &) {
+ DIE("ProgramTree::Build() called for OpenACCRoutineConstruct");
+}
+
const parser::ParentIdentifier &ProgramTree::GetParentId() const {
const auto *stmt{
std::get<const parser::Statement<parser::SubmoduleStmt> *>(stmt_)};
diff --git a/flang/lib/Semantics/program-tree.h b/flang/lib/Semantics/program-tree.h
index 4bf6567c6adfe8b..d49b0405d8b122d 100644
--- a/flang/lib/Semantics/program-tree.h
+++ b/flang/lib/Semantics/program-tree.h
@@ -43,6 +43,7 @@ class ProgramTree {
static ProgramTree Build(const parser::Submodule &);
static ProgramTree Build(const parser::BlockData &);
static ProgramTree Build(const parser::CompilerDirective &);
+ static ProgramTree Build(const parser::OpenACCRoutineConstruct &);
ENUM_CLASS(Kind, // kind of node
Program, Function, Subroutine, MpSubprogram, Module, Submodule, BlockData)
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index f7720fcf43e5768..2a6523f579f2773 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -61,6 +61,9 @@ template <typename T> class DirectiveAttributeVisitor {
? std::nullopt
: std::make_optional<DirContext>(dirContext_.back());
}
+ void PushContext(const parser::CharBlock &source, T dir, Scope &scope) {
+ dirContext_.emplace_back(source, dir, scope);
+ }
void PushContext(const parser::CharBlock &source, T dir) {
dirContext_.emplace_back(source, dir, context_.FindScope(source));
}
@@ -115,8 +118,8 @@ template <typename T> class DirectiveAttributeVisitor {
class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
public:
- explicit AccAttributeVisitor(SemanticsContext &context)
- : DirectiveAttributeVisitor(context) {}
+ explicit AccAttributeVisitor(SemanticsContext &context, Scope *topScope)
+ : DirectiveAttributeVisitor(context), topScope_(topScope) {}
template <typename A> void Walk(const A &x) { parser::Walk(x, *this); }
template <typename A> bool Pre(const A &) { return true; }
@@ -281,6 +284,7 @@ class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
const llvm::acc::Clause clause, const parser::AccObjectList &objectList);
void AddRoutineInfoToSymbol(
Symbol &, const parser::OpenACCRoutineConstruct &);
+ Scope *topScope_;
};
// Data-sharing and Data-mapping attributes for data-refs in OpenMP construct
@@ -802,10 +806,6 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCDeclarativeConstruct &x) {
const auto &declDir{
std::get<parser::AccDeclarativeDirective>(declConstruct->t)};
PushContext(declDir.source, llvm::acc::Directive::ACCD_declare);
- } else if (const auto *routineConstruct{
- std::get_if<parser::OpenACCRoutineConstruct>(&x.u)}) {
- const auto &verbatim{std::get<parser::Verbatim>(routineConstruct->t)};
- PushContext(verbatim.source, llvm::acc::Directive::ACCD_routine);
}
ClearDataSharingAttributeObjects();
return true;
@@ -994,6 +994,13 @@ void AccAttributeVisitor::AddRoutineInfoToSymbol(
}
bool AccAttributeVisitor::Pre(const parser::OpenACCRoutineConstruct &x) {
+ const auto &verbatim{std::get<parser::Verbatim>(x.t)};
+ if (topScope_) {
+ PushContext(
+ verbatim.source, llvm::acc::Directive::ACCD_routine, *topScope_);
+ } else {
+ PushContext(verbatim.source, llvm::acc::Directive::ACCD_routine);
+ }
const auto &optName{std::get<std::optional<parser::Name>>(x.t)};
if (optName) {
if (Symbol *sym = ResolveFctName(*optName)) {
@@ -1005,7 +1012,17 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCRoutineConstruct &x) {
(*optName).source);
}
} else {
- AddRoutineInfoToSymbol(*currScope().symbol(), x);
+ if (Symbol *scopeSymbol{currScope().symbol()}) {
+ AddRoutineInfoToSymbol(*scopeSymbol, x);
+ } else {
+ // An unnamed directive is attached to the current scope's symbol. A
+ // scope without one (e.g. the global scope of a top-level directive)
+ // has nothing to attach it to, so diagnose instead of dropping it.
+ context_.Say(verbatim.source,
+ "ROUTINE directive without name must appear within the specification "
+ "part of a subroutine or function definition, or within an interface "
+ "body for a subroutine or function in an interface block"_err_en_US);
+ }
}
return true;
}
@@ -2190,10 +2207,10 @@ void OmpAttributeVisitor::CheckMultipleAppearances(
}
}
-void ResolveAccParts(
- SemanticsContext &context, const parser::ProgramUnit &node) {
+void ResolveAccParts(SemanticsContext &context, const parser::ProgramUnit &node,
+ Scope *topScope) {
if (context.IsEnabled(common::LanguageFeature::OpenACC)) {
- AccAttributeVisitor{context}.Walk(node);
+ AccAttributeVisitor{context, topScope}.Walk(node);
}
}
diff --git a/flang/lib/Semantics/resolve-directives.h b/flang/lib/Semantics/resolve-directives.h
index 839165aaf30eb81..4aef8ad6c40081a 100644
--- a/flang/lib/Semantics/resolve-directives.h
+++ b/flang/lib/Semantics/resolve-directives.h
@@ -16,11 +16,12 @@ struct ProgramUnit;
} // namespace Fortran::parser
namespace Fortran::semantics {
-
+class Scope;
class SemanticsContext;
// Name resolution for OpenACC and OpenMP directives
-void ResolveAccParts(SemanticsContext &, const parser::ProgramUnit &);
+void ResolveAccParts(
+ SemanticsContext &, const parser::ProgramUnit &, Scope *topScope = {});
void ResolveOmpParts(SemanticsContext &, const parser::ProgramUnit &);
void ResolveOmpTopLevelParts(SemanticsContext &, const parser::Program &);
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 90c14806afbf82d..9f7a59b0b454a8a 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8323,6 +8323,11 @@ bool ResolveNamesVisitor::Pre(const parser::ProgramUnit &x) {
// TODO: global directives
return true;
}
+ if (std::holds_alternative<
+ common::Indirection<parser::OpenACCRoutineConstruct>>(x.u)) {
+ ResolveAccParts(context(), x, &topScope_);
+ return false;
+ }
auto root{ProgramTree::Build(x)};
SetScope(topScope_);
ResolveSpecificationParts(root);
@@ -8335,7 +8340,8 @@ bool ResolveNamesVisitor::Pre(const parser::ProgramUnit &x) {
template <typename A> std::set<SourceName> GetUses(const A &x) {
std::set<SourceName> uses;
- if constexpr (!std::is_same_v<A, parser::CompilerDirective>) {
+ if constexpr (!std::is_same_v<A, parser::CompilerDirective> &&
+ !std::is_same_v<A, parser::OpenACCRoutineConstruct>) {
const auto &spec{std::get<parser::SpecificationPart>(x.t)};
const auto &unitUses{std::get<
std::list<parser::Statement<common::Indirection<parser::UseStmt>>>>(
diff --git a/flang/test/Lower/OpenACC/acc-routine02.f90 b/flang/test/Lower/OpenACC/acc-routine02.f90
new file mode 100644
index 000000000000000..d93ece88235ed98
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-routine02.f90
@@ -0,0 +1,21 @@
+! This test checks lowering of a top-level OpenACC routine directive.
+
+! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+subroutine sub1(a, n)
+ integer :: n
+ real :: a(n)
+end subroutine sub1
+
+!$acc routine(sub1)
+
+program test
+ integer, parameter :: N = 10
+ real :: a(N)
+ call sub1(a, N)
+end program
+
+! CHECK-LABEL: acc.routine @acc_routine_0 func(@_QPsub1)
+
+! CHECK: func.func @_QPsub1(%{{.*}}: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "a"}, %{{.*}}: !fir.ref<i32> {fir.bindc_name = "n"}) attributes {acc.routine_info = #acc.routine_info<[@acc_routine_0]>}
diff --git a/flang/test/Semantics/OpenACC/acc-routine-validity02.f90 b/flang/test/Semantics/OpenACC/acc-routine-validity02.f90
new file mode 100644
index 000000000000000..9410a5b17745dd4
--- /dev/null
+++ b/flang/test/Semantics/OpenACC/acc-routine-validity02.f90
@@ -0,0 +1,17 @@
+! RUN: %python %S/../test_errors.py %s %flang -fopenacc
+
+! Check acc routine in the top level.
+
+subroutine sub1(a, n)
+ integer :: n
+ real :: a(n)
+end subroutine sub1
+
+!$acc routine(sub1)
+
+!dir$ value=1
+program test
+ integer, parameter :: N = 10
+ real :: a(N)
+ call sub1(a, N)
+end program
``````````
</details>
https://github.com/llvm/llvm-project/pull/69936
More information about the flang-commits
mailing list