[llvm-branch-commits] [flang] [llvm] [Frontend][OpenMP] Refactor getLeafConstructs, add getCompoundConstruct (PR #87247)
Krzysztof Parzyszek via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Apr 17 07:57:01 PDT 2024
https://github.com/kparzysz updated https://github.com/llvm/llvm-project/pull/87247
>From f725face892cef4faf9f17d4b549541bdbcd7e08 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Fri, 29 Mar 2024 09:20:41 -0500
Subject: [PATCH 1/3] [flang][OpenMP] Move clause/object conversion to happen
early, in genOMP
This removes the last use of genObjectList2, so that function has now been removed as well.
---
flang/lib/Lower/OpenMP/ClauseProcessor.h | 5 +-
flang/lib/Lower/OpenMP/DataSharingProcessor.h | 5 +-
flang/lib/Lower/OpenMP/OpenMP.cpp | 424 +++++++++---------
flang/lib/Lower/OpenMP/Utils.cpp | 30 +-
flang/lib/Lower/OpenMP/Utils.h | 6 +-
5 files changed, 218 insertions(+), 252 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index db7a1b8335f818..f4d659b70cfee7 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -49,9 +49,8 @@ class ClauseProcessor {
public:
ClauseProcessor(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauses)
- : converter(converter), semaCtx(semaCtx),
- clauses(makeClauses(clauses, semaCtx)) {}
+ const List<Clause> &clauses)
+ : converter(converter), semaCtx(semaCtx), clauses(clauses) {}
// 'Unique' clauses: They can appear at most once in the clause list.
bool processCollapse(
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.h b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
index c11ee299c5d085..ef7b14327278e3 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.h
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
@@ -78,13 +78,12 @@ class DataSharingProcessor {
public:
DataSharingProcessor(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &opClauseList,
+ const List<Clause> &clauses,
Fortran::lower::pft::Evaluation &eval,
bool useDelayedPrivatization = false,
Fortran::lower::SymMap *symTable = nullptr)
: hasLastPrivateOp(false), converter(converter),
- firOpBuilder(converter.getFirOpBuilder()),
- clauses(omp::makeClauses(opClauseList, semaCtx)), eval(eval),
+ firOpBuilder(converter.getFirOpBuilder()), clauses(clauses), eval(eval),
useDelayedPrivatization(useDelayedPrivatization), symTable(symTable) {}
// Privatisation is split into two steps.
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index edae453972d3d9..23dc25ac1ae9a1 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -17,6 +17,7 @@
#include "DataSharingProcessor.h"
#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
+#include "Utils.h"
#include "flang/Common/idioms.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
@@ -310,14 +311,15 @@ static void getDeclareTargetInfo(
} else if (const auto *clauseList{
Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
spec.u)}) {
- if (clauseList->v.empty()) {
+ List<Clause> clauses = makeClauses(*clauseList, semaCtx);
+ if (clauses.empty()) {
// Case: declare target, implicit capture of function
symbolAndClause.emplace_back(
mlir::omp::DeclareTargetCaptureClause::to,
eval.getOwningProcedure()->getSubprogramSymbol());
}
- ClauseProcessor cp(converter, semaCtx, *clauseList);
+ ClauseProcessor cp(converter, semaCtx, clauses);
cp.processDeviceType(clauseOps);
cp.processEnter(symbolAndClause);
cp.processLink(symbolAndClause);
@@ -597,14 +599,11 @@ static void removeStoreOp(mlir::Operation *reductionOp, mlir::Value symVal) {
// TODO: Generate the reduction operation during lowering instead of creating
// and removing operations since this is not a robust approach. Also, removing
// ops in the builder (instead of a rewriter) is probably not the best approach.
-static void
-genOpenMPReduction(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauseList) {
+static void genOpenMPReduction(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const List<Clause> &clauses) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- List<Clause> clauses{makeClauses(clauseList, semaCtx)};
-
for (const Clause &clause : clauses) {
if (const auto &reductionClause =
std::get_if<clause::Reduction>(&clause.u)) {
@@ -812,7 +811,7 @@ struct OpWithBodyGenInfo {
return *this;
}
- OpWithBodyGenInfo &setClauses(const Fortran::parser::OmpClauseList *value) {
+ OpWithBodyGenInfo &setClauses(const List<Clause> *value) {
clauses = value;
return *this;
}
@@ -848,7 +847,7 @@ struct OpWithBodyGenInfo {
/// [in] is this an outer operation - prevents privatization.
bool outerCombined = false;
/// [in] list of clauses to process.
- const Fortran::parser::OmpClauseList *clauses = nullptr;
+ const List<Clause> *clauses = nullptr;
/// [in] if provided, processes the construct's data-sharing attributes.
DataSharingProcessor *dsp = nullptr;
/// [in] if provided, list of reduction symbols
@@ -1226,36 +1225,33 @@ static OpTy genOpWithBody(OpWithBodyGenInfo &info, Args &&...args) {
// Code generation functions for clauses
//===----------------------------------------------------------------------===//
-static void genCriticalDeclareClauses(
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
- mlir::omp::CriticalClauseOps &clauseOps, llvm::StringRef name) {
+static void
+genCriticalDeclareClauses(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const List<Clause> &clauses, mlir::Location loc,
+ mlir::omp::CriticalClauseOps &clauseOps,
+ llvm::StringRef name) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processHint(clauseOps);
clauseOps.nameAttr =
mlir::StringAttr::get(converter.getFirOpBuilder().getContext(), name);
}
-static void genFlushClauses(
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx,
- const std::optional<Fortran::parser::OmpObjectList> &objects,
- const std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>
- &clauses,
- mlir::Location loc, llvm::SmallVectorImpl<mlir::Value> &operandRange) {
- if (objects)
- genObjectList2(*objects, converter, operandRange);
-
- if (clauses && clauses->size() > 0)
+static void genFlushClauses(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ const ObjectList &objects,
+ const List<Clause> &clauses, mlir::Location loc,
+ llvm::SmallVectorImpl<mlir::Value> &operandRange) {
+ genObjectList(objects, converter, operandRange);
+
+ if (clauses.size() > 0)
TODO(converter.getCurrentLocation(), "Handle OmpMemoryOrderClause");
}
static void
genOrderedRegionClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauses,
- mlir::Location loc,
+ const List<Clause> &clauses, mlir::Location loc,
mlir::omp::OrderedRegionClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processTODO<clause::Simd>(loc, llvm::omp::Directive::OMPD_ordered);
@@ -1264,9 +1260,9 @@ genOrderedRegionClauses(Fortran::lower::AbstractConverter &converter,
static void genParallelClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
- bool processReduction, mlir::omp::ParallelClauseOps &clauseOps,
+ Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses,
+ mlir::Location loc, bool processReduction,
+ mlir::omp::ParallelClauseOps &clauseOps,
llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &reductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
@@ -1286,8 +1282,7 @@ static void genParallelClauses(
static void genSectionsClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauses,
- mlir::Location loc,
+ const List<Clause> &clauses, mlir::Location loc,
bool clausesFromBeginSections,
mlir::omp::SectionsClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
@@ -1304,9 +1299,8 @@ static void genSimdLoopClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::StatementContext &stmtCtx,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
- mlir::omp::SimdLoopClauseOps &clauseOps,
+ Fortran::lower::pft::Evaluation &eval, const List<Clause> &clauses,
+ mlir::Location loc, mlir::omp::SimdLoopClauseOps &clauseOps,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processCollapse(loc, eval, clauseOps, iv);
@@ -1324,9 +1318,8 @@ static void genSimdLoopClauses(
static void genSingleClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &beginClauses,
- const Fortran::parser::OmpClauseList &endClauses,
- mlir::Location loc,
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses, mlir::Location loc,
mlir::omp::SingleClauseOps &clauseOps) {
ClauseProcessor bcp(converter, semaCtx, beginClauses);
bcp.processAllocate(clauseOps);
@@ -1340,9 +1333,8 @@ static void genSingleClauses(Fortran::lower::AbstractConverter &converter,
static void genTargetClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
- bool processHostOnlyClauses, bool processReduction,
+ Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses,
+ mlir::Location loc, bool processHostOnlyClauses, bool processReduction,
mlir::omp::TargetClauseOps &clauseOps,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &mapSyms,
llvm::SmallVectorImpl<mlir::Location> &mapSymLocs,
@@ -1368,9 +1360,8 @@ static void genTargetClauses(
static void genTargetDataClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
- mlir::omp::TargetDataClauseOps &clauseOps,
+ Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses,
+ mlir::Location loc, mlir::omp::TargetDataClauseOps &clauseOps,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms) {
@@ -1401,9 +1392,8 @@ static void genTargetDataClauses(
static void genTargetEnterExitUpdateDataClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::OmpClauseList &clauses, mlir::Location loc,
- llvm::omp::Directive directive,
+ Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses,
+ mlir::Location loc, llvm::omp::Directive directive,
mlir::omp::TargetEnterExitUpdateDataClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processDepend(clauseOps);
@@ -1422,8 +1412,7 @@ static void genTargetEnterExitUpdateDataClauses(
static void genTaskClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::OmpClauseList &clauses,
- mlir::Location loc,
+ const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
@@ -1442,8 +1431,7 @@ static void genTaskClauses(Fortran::lower::AbstractConverter &converter,
static void genTaskgroupClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauses,
- mlir::Location loc,
+ const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskgroupClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
@@ -1453,8 +1441,7 @@ static void genTaskgroupClauses(Fortran::lower::AbstractConverter &converter,
static void genTaskwaitClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
- const Fortran::parser::OmpClauseList &clauses,
- mlir::Location loc,
+ const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskwaitClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processTODO<clause::Depend, clause::Nowait>(
@@ -1464,8 +1451,7 @@ static void genTaskwaitClauses(Fortran::lower::AbstractConverter &converter,
static void genTeamsClauses(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::OmpClauseList &clauses,
- mlir::Location loc,
+ const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TeamsClauseOps &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
@@ -1482,9 +1468,8 @@ static void genWsloopClauses(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::StatementContext &stmtCtx,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &beginClauses,
- const Fortran::parser::OmpClauseList *endClauses, mlir::Location loc,
+ Fortran::lower::pft::Evaluation &eval, const List<Clause> &beginClauses,
+ const List<Clause> &endClauses, mlir::Location loc,
mlir::omp::WsloopClauseOps &clauseOps,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
@@ -1501,8 +1486,8 @@ static void genWsloopClauses(
if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr();
- if (endClauses) {
- ClauseProcessor ecp(converter, semaCtx, *endClauses);
+ if (!endClauses.empty()) {
+ ClauseProcessor ecp(converter, semaCtx, endClauses);
ecp.processNowait(clauseOps);
}
@@ -1525,8 +1510,7 @@ static mlir::omp::CriticalOp
genCriticalOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList,
+ mlir::Location loc, const List<Clause> &clauses,
const std::optional<Fortran::parser::Name> &name) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::FlatSymbolRefAttr nameAttr;
@@ -1537,7 +1521,7 @@ genCriticalOp(Fortran::lower::AbstractConverter &converter,
auto global = mod.lookupSymbol<mlir::omp::CriticalDeclareOp>(nameStr);
if (!global) {
mlir::omp::CriticalClauseOps clauseOps;
- genCriticalDeclareClauses(converter, semaCtx, clauseList, loc, clauseOps,
+ genCriticalDeclareClauses(converter, semaCtx, clauses, loc, clauseOps,
nameStr);
mlir::OpBuilder modBuilder(mod.getBodyRegion());
@@ -1556,8 +1540,7 @@ static mlir::omp::DistributeOp
genDistributeOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::Location loc, const List<Clause> &clauses) {
TODO(loc, "Distribute construct");
return nullptr;
}
@@ -1566,12 +1549,9 @@ static mlir::omp::FlushOp
genFlushOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const std::optional<Fortran::parser::OmpObjectList> &objectList,
- const std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>
- &clauseList) {
+ const ObjectList &objects, const List<Clause> &clauses) {
llvm::SmallVector<mlir::Value> operandRange;
- genFlushClauses(converter, semaCtx, objectList, clauseList, loc,
- operandRange);
+ genFlushClauses(converter, semaCtx, objects, clauses, loc, operandRange);
return converter.getFirOpBuilder().create<mlir::omp::FlushOp>(
converter.getCurrentLocation(), operandRange);
@@ -1591,7 +1571,7 @@ static mlir::omp::OrderedOp
genOrderedOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ const List<Clause> &clauses) {
TODO(loc, "OMPD_ordered");
return nullptr;
}
@@ -1600,10 +1580,9 @@ static mlir::omp::OrderedRegionOp
genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::Location loc, const List<Clause> &clauses) {
mlir::omp::OrderedRegionClauseOps clauseOps;
- genOrderedRegionClauses(converter, semaCtx, clauseList, loc, clauseOps);
+ genOrderedRegionClauses(converter, semaCtx, clauses, loc, clauseOps);
return genOpWithBody<mlir::omp::OrderedRegionOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, eval).setGenNested(genNested),
@@ -1615,8 +1594,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList,
+ mlir::Location loc, const List<Clause> &clauses,
bool outerCombined = false) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
@@ -1624,7 +1602,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> privateSyms;
llvm::SmallVector<mlir::Type> reductionTypes;
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
- genParallelClauses(converter, semaCtx, stmtCtx, clauseList, loc,
+ genParallelClauses(converter, semaCtx, stmtCtx, clauses, loc,
/*processReduction=*/!outerCombined, clauseOps,
reductionTypes, reductionSyms);
@@ -1637,7 +1615,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
OpWithBodyGenInfo(converter, semaCtx, loc, eval)
.setGenNested(genNested)
.setOuterCombined(outerCombined)
- .setClauses(&clauseList)
+ .setClauses(&clauses)
.setReductions(&reductionSyms, &reductionTypes)
.setGenRegionEntryCb(reductionCallback);
@@ -1645,7 +1623,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps);
bool privatize = !outerCombined;
- DataSharingProcessor dsp(converter, semaCtx, clauseList, eval,
+ DataSharingProcessor dsp(converter, semaCtx, clauses, eval,
/*useDelayedPrivatization=*/true, &symTable);
if (privatize)
@@ -1692,14 +1670,13 @@ static mlir::omp::SectionOp
genSectionOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::Location loc, const List<Clause> &clauses) {
// Currently only private/firstprivate clause is handled, and
// all privatization is done within `omp.section` operations.
return genOpWithBody<mlir::omp::SectionOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, eval)
.setGenNested(genNested)
- .setClauses(&clauseList));
+ .setClauses(&clauses));
}
static mlir::omp::SectionsOp
@@ -1716,18 +1693,17 @@ static mlir::omp::SimdLoopOp
genSimdLoopOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
- DataSharingProcessor dsp(converter, semaCtx, clauseList, eval);
+ const List<Clause> &clauses) {
+ DataSharingProcessor dsp(converter, semaCtx, clauses, eval);
dsp.processStep1();
Fortran::lower::StatementContext stmtCtx;
mlir::omp::SimdLoopClauseOps clauseOps;
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
- genSimdLoopClauses(converter, semaCtx, stmtCtx, eval, clauseList, loc,
- clauseOps, iv);
+ genSimdLoopClauses(converter, semaCtx, stmtCtx, eval, clauses, loc, clauseOps,
+ iv);
- auto *nestedEval =
- getCollapsedLoopEval(eval, Fortran::lower::getCollapseValue(clauseList));
+ auto *nestedEval = getCollapsedLoopEval(eval, getCollapseValue(clauses));
auto ivCallback = [&](mlir::Operation *op) {
return genLoopVars(op, converter, loc, iv);
@@ -1735,7 +1711,7 @@ genSimdLoopOp(Fortran::lower::AbstractConverter &converter,
return genOpWithBody<mlir::omp::SimdLoopOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
- .setClauses(&clauseList)
+ .setClauses(&clauses)
.setDataSharingProcessor(&dsp)
.setGenRegionEntryCb(ivCallback),
clauseOps);
@@ -1745,17 +1721,16 @@ static mlir::omp::SingleOp
genSingleOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList &endClauseList) {
+ mlir::Location loc, const List<Clause> &beginClauses,
+ const List<Clause> &endClauses) {
mlir::omp::SingleClauseOps clauseOps;
- genSingleClauses(converter, semaCtx, beginClauseList, endClauseList, loc,
+ genSingleClauses(converter, semaCtx, beginClauses, endClauses, loc,
clauseOps);
return genOpWithBody<mlir::omp::SingleOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, eval)
.setGenNested(genNested)
- .setClauses(&beginClauseList),
+ .setClauses(&beginClauses),
clauseOps);
}
@@ -1763,8 +1738,7 @@ static mlir::omp::TargetOp
genTargetOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList,
+ mlir::Location loc, const List<Clause> &clauses,
bool outerCombined = false) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
@@ -1777,7 +1751,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSyms;
llvm::SmallVector<mlir::Location> mapSymLocs;
llvm::SmallVector<mlir::Type> mapSymTypes;
- genTargetClauses(converter, semaCtx, stmtCtx, clauseList, loc,
+ genTargetClauses(converter, semaCtx, stmtCtx, clauses, loc,
processHostOnlyClauses, /*processReduction=*/outerCombined,
clauseOps, mapSyms, mapSymLocs, mapSymTypes);
@@ -1875,14 +1849,13 @@ static mlir::omp::TargetDataOp
genTargetDataOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::Location loc, const List<Clause> &clauses) {
Fortran::lower::StatementContext stmtCtx;
mlir::omp::TargetDataClauseOps clauseOps;
llvm::SmallVector<mlir::Type> useDeviceTypes;
llvm::SmallVector<mlir::Location> useDeviceLocs;
llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSyms;
- genTargetDataClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps,
+ genTargetDataClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps,
useDeviceTypes, useDeviceLocs, useDeviceSyms);
auto targetDataOp =
@@ -1894,11 +1867,11 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter,
return targetDataOp;
}
-template <typename OpTy>
-static OpTy genTargetEnterExitUpdateDataOp(
- Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semaCtx, mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+template <typename OpTy> static OpTy
+genTargetEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ mlir::Location loc,
+ const List<Clause> &clauses) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
@@ -1915,8 +1888,8 @@ static OpTy genTargetEnterExitUpdateDataOp(
}
mlir::omp::TargetEnterExitUpdateDataClauseOps clauseOps;
- genTargetEnterExitUpdateDataClauses(converter, semaCtx, stmtCtx, clauseList,
- loc, directive, clauseOps);
+ genTargetEnterExitUpdateDataClauses(converter, semaCtx, stmtCtx, clauses, loc,
+ directive, clauseOps);
return firOpBuilder.create<OpTy>(loc, clauseOps);
}
@@ -1925,16 +1898,15 @@ static mlir::omp::TaskOp
genTaskOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::Location loc, const List<Clause> &clauses) {
Fortran::lower::StatementContext stmtCtx;
mlir::omp::TaskClauseOps clauseOps;
- genTaskClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps);
+ genTaskClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps);
return genOpWithBody<mlir::omp::TaskOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, eval)
.setGenNested(genNested)
- .setClauses(&clauseList),
+ .setClauses(&clauses),
clauseOps);
}
@@ -1942,15 +1914,14 @@ static mlir::omp::TaskgroupOp
genTaskgroupOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::Location loc, const List<Clause> &clauses) {
mlir::omp::TaskgroupClauseOps clauseOps;
- genTaskgroupClauses(converter, semaCtx, clauseList, loc, clauseOps);
+ genTaskgroupClauses(converter, semaCtx, clauses, loc, clauseOps);
return genOpWithBody<mlir::omp::TaskgroupOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, eval)
.setGenNested(genNested)
- .setClauses(&clauseList),
+ .setClauses(&clauses),
clauseOps);
}
@@ -1958,7 +1929,7 @@ static mlir::omp::TaskloopOp
genTaskloopOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ const List<Clause> &clauses) {
TODO(loc, "Taskloop construct");
}
@@ -1966,9 +1937,9 @@ static mlir::omp::TaskwaitOp
genTaskwaitOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const Fortran::parser::OmpClauseList &clauseList) {
+ const List<Clause> &clauses) {
mlir::omp::TaskwaitClauseOps clauseOps;
- genTaskwaitClauses(converter, semaCtx, clauseList, loc, clauseOps);
+ genTaskwaitClauses(converter, semaCtx, clauses, loc, clauseOps);
return converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>(loc,
clauseOps);
}
@@ -1984,17 +1955,17 @@ static mlir::omp::TeamsOp
genTeamsOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location loc, const Fortran::parser::OmpClauseList &clauseList,
+ mlir::Location loc, const List<Clause> &clauses,
bool outerCombined = false) {
Fortran::lower::StatementContext stmtCtx;
mlir::omp::TeamsClauseOps clauseOps;
- genTeamsClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps);
+ genTeamsClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps);
return genOpWithBody<mlir::omp::TeamsOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, eval)
.setGenNested(genNested)
.setOuterCombined(outerCombined)
- .setClauses(&clauseList),
+ .setClauses(&clauses),
clauseOps);
}
@@ -2002,9 +1973,8 @@ static mlir::omp::WsloopOp
genWsloopOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList *endClauseList) {
- DataSharingProcessor dsp(converter, semaCtx, beginClauseList, eval);
+ const List<Clause> &beginClauses, const List<Clause> &endClauses) {
+ DataSharingProcessor dsp(converter, semaCtx, beginClauses, eval);
dsp.processStep1();
Fortran::lower::StatementContext stmtCtx;
@@ -2012,12 +1982,10 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
llvm::SmallVector<mlir::Type> reductionTypes;
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
- genWsloopClauses(converter, semaCtx, stmtCtx, eval, beginClauseList,
- endClauseList, loc, clauseOps, iv, reductionTypes,
- reductionSyms);
+ genWsloopClauses(converter, semaCtx, stmtCtx, eval, beginClauses, endClauses,
+ loc, clauseOps, iv, reductionTypes, reductionSyms);
- auto *nestedEval = getCollapsedLoopEval(
- eval, Fortran::lower::getCollapseValue(beginClauseList));
+ auto *nestedEval = getCollapsedLoopEval(eval, getCollapseValue(beginClauses));
auto ivCallback = [&](mlir::Operation *op) {
return genLoopAndReductionVars(op, converter, loc, iv, reductionSyms,
@@ -2026,7 +1994,7 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter,
return genOpWithBody<mlir::omp::WsloopOp>(
OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
- .setClauses(&beginClauseList)
+ .setClauses(&beginClauses)
.setDataSharingProcessor(&dsp)
.setReductions(&reductionSyms, &reductionTypes)
.setGenRegionEntryCb(ivCallback),
@@ -2041,8 +2009,8 @@ static void genCompositeDistributeParallelDo(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses, mlir::Location loc) {
TODO(loc, "Composite DISTRIBUTE PARALLEL DO");
}
@@ -2050,8 +2018,8 @@ static void genCompositeDistributeParallelDoSimd(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses, mlir::Location loc) {
TODO(loc, "Composite DISTRIBUTE PARALLEL DO SIMD");
}
@@ -2059,8 +2027,8 @@ static void genCompositeDistributeSimd(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) {
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses, mlir::Location loc) {
TODO(loc, "Composite DISTRIBUTE SIMD");
}
@@ -2068,10 +2036,10 @@ static void
genCompositeDoSimd(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList *endClauseList,
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses,
mlir::Location loc) {
- ClauseProcessor cp(converter, semaCtx, beginClauseList);
+ ClauseProcessor cp(converter, semaCtx, beginClauses);
cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear,
clause::Order, clause::Safelen, clause::Simdlen>(
loc, llvm::omp::OMPD_do_simd);
@@ -2083,15 +2051,15 @@ genCompositeDoSimd(Fortran::lower::AbstractConverter &converter,
// When support for vectorization is enabled, then we need to add handling of
// if clause. Currently if clause can be skipped because we always assume
// SIMD length = 1.
- genWsloopOp(converter, semaCtx, eval, loc, beginClauseList, endClauseList);
+ genWsloopOp(converter, semaCtx, eval, loc, beginClauses, endClauses);
}
static void
genCompositeTaskloopSimd(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OmpClauseList &beginClauseList,
- const Fortran::parser::OmpClauseList *endClauseList,
+ const List<Clause> &beginClauses,
+ const List<Clause> &endClauses,
mlir::Location loc) {
TODO(loc, "Composite TASKLOOP SIMD");
}
@@ -2201,8 +2169,9 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
const auto &directive =
std::get<Fortran::parser::OmpSimpleStandaloneDirective>(
simpleStandaloneConstruct.t);
- const auto &clauseList =
- std::get<Fortran::parser::OmpClauseList>(simpleStandaloneConstruct.t);
+ List<Clause> clauses = makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(simpleStandaloneConstruct.t),
+ semaCtx);
mlir::Location currentLocation = converter.genLocation(directive.source);
switch (directive.v) {
@@ -2212,29 +2181,29 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
genBarrierOp(converter, semaCtx, eval, currentLocation);
break;
case llvm::omp::Directive::OMPD_taskwait:
- genTaskwaitOp(converter, semaCtx, eval, currentLocation, clauseList);
+ genTaskwaitOp(converter, semaCtx, eval, currentLocation, clauses);
break;
case llvm::omp::Directive::OMPD_taskyield:
genTaskyieldOp(converter, semaCtx, eval, currentLocation);
break;
case llvm::omp::Directive::OMPD_target_data:
genTargetDataOp(converter, semaCtx, eval, /*genNested=*/true,
- currentLocation, clauseList);
+ currentLocation, clauses);
break;
case llvm::omp::Directive::OMPD_target_enter_data:
genTargetEnterExitUpdateDataOp<mlir::omp::TargetEnterDataOp>(
- converter, semaCtx, currentLocation, clauseList);
+ converter, semaCtx, currentLocation, clauses);
break;
case llvm::omp::Directive::OMPD_target_exit_data:
genTargetEnterExitUpdateDataOp<mlir::omp::TargetExitDataOp>(
- converter, semaCtx, currentLocation, clauseList);
+ converter, semaCtx, currentLocation, clauses);
break;
case llvm::omp::Directive::OMPD_target_update:
genTargetEnterExitUpdateDataOp<mlir::omp::TargetUpdateOp>(
- converter, semaCtx, currentLocation, clauseList);
+ converter, semaCtx, currentLocation, clauses);
break;
case llvm::omp::Directive::OMPD_ordered:
- genOrderedOp(converter, semaCtx, eval, currentLocation, clauseList);
+ genOrderedOp(converter, semaCtx, eval, currentLocation, clauses);
break;
}
}
@@ -2251,8 +2220,14 @@ genOMP(Fortran::lower::AbstractConverter &converter,
const auto &clauseList =
std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>(
flushConstruct.t);
+ ObjectList objects =
+ objectList ? makeObjects(*objectList, semaCtx) : ObjectList{};
+ List<Clause> clauses =
+ clauseList ? makeList(*clauseList,
+ [&](auto &&s) { return makeClause(s.v, semaCtx); })
+ : List<Clause>{};
mlir::Location currentLocation = converter.genLocation(verbatim.source);
- genFlushOp(converter, semaCtx, eval, currentLocation, objectList, clauseList);
+ genFlushOp(converter, semaCtx, eval, currentLocation, objects, clauses);
}
static void
@@ -2357,44 +2332,44 @@ genOMP(Fortran::lower::AbstractConverter &converter,
converter.genLocation(beginBlockDirective.source);
const auto origDirective =
std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t).v;
- const auto &beginClauseList =
- std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
- const auto &endClauseList =
- std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t);
+ List<Clause> beginClauses = makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t), semaCtx);
+ List<Clause> endClauses = makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t), semaCtx);
assert(llvm::omp::blockConstructSet.test(origDirective) &&
"Expected block construct");
- for (const Fortran::parser::OmpClause &clause : beginClauseList.v) {
+ for (const Clause &clause : beginClauses) {
mlir::Location clauseLocation = converter.genLocation(clause.source);
- if (!std::get_if<Fortran::parser::OmpClause::If>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::NumThreads>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::ProcBind>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Allocate>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Default>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Final>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Priority>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Depend>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Private>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Firstprivate>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Shared>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Threads>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Map>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Simd>(&clause.u)) {
+ if (!std::get_if<clause::If>(&clause.u) &&
+ !std::get_if<clause::NumThreads>(&clause.u) &&
+ !std::get_if<clause::ProcBind>(&clause.u) &&
+ !std::get_if<clause::Allocate>(&clause.u) &&
+ !std::get_if<clause::Default>(&clause.u) &&
+ !std::get_if<clause::Final>(&clause.u) &&
+ !std::get_if<clause::Priority>(&clause.u) &&
+ !std::get_if<clause::Reduction>(&clause.u) &&
+ !std::get_if<clause::Depend>(&clause.u) &&
+ !std::get_if<clause::Private>(&clause.u) &&
+ !std::get_if<clause::Firstprivate>(&clause.u) &&
+ !std::get_if<clause::Copyin>(&clause.u) &&
+ !std::get_if<clause::Shared>(&clause.u) &&
+ !std::get_if<clause::Threads>(&clause.u) &&
+ !std::get_if<clause::Map>(&clause.u) &&
+ !std::get_if<clause::UseDevicePtr>(&clause.u) &&
+ !std::get_if<clause::UseDeviceAddr>(&clause.u) &&
+ !std::get_if<clause::ThreadLimit>(&clause.u) &&
+ !std::get_if<clause::NumTeams>(&clause.u) &&
+ !std::get_if<clause::Simd>(&clause.u)) {
TODO(clauseLocation, "OpenMP Block construct clause");
}
}
- for (const auto &clause : endClauseList.v) {
+ for (const Clause &clause : endClauses) {
mlir::Location clauseLocation = converter.genLocation(clause.source);
- if (!std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::Copyprivate>(&clause.u))
+ if (!std::get_if<clause::Nowait>(&clause.u) &&
+ !std::get_if<clause::Copyprivate>(&clause.u))
TODO(clauseLocation, "OpenMP Block construct clause");
}
@@ -2413,44 +2388,44 @@ genOMP(Fortran::lower::AbstractConverter &converter,
case llvm::omp::Directive::OMPD_ordered:
// 2.17.9 ORDERED construct.
genOrderedRegionOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_parallel:
// 2.6 PARALLEL construct.
genParallelOp(converter, symTable, semaCtx, eval, genNested,
- currentLocation, beginClauseList, outerCombined);
+ currentLocation, beginClauses, outerCombined);
break;
case llvm::omp::Directive::OMPD_single:
// 2.8.2 SINGLE construct.
genSingleOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList, endClauseList);
+ beginClauses, endClauses);
break;
case llvm::omp::Directive::OMPD_target:
// 2.12.5 TARGET construct.
genTargetOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList, outerCombined);
+ beginClauses, outerCombined);
break;
case llvm::omp::Directive::OMPD_target_data:
// 2.12.2 TARGET DATA construct.
genTargetDataOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_task:
// 2.10.1 TASK construct.
genTaskOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_taskgroup:
// 2.17.6 TASKGROUP construct.
genTaskgroupOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_teams:
// 2.7 TEAMS construct.
// FIXME Pass the outerCombined argument or rename it to better describe
// what it represents if it must always be `false` in this context.
genTeamsOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_workshare:
// 2.8.3 WORKSHARE construct.
@@ -2458,7 +2433,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
// implementation for this feature will come later. For the codes
// that use this construct, add a single construct for now.
genSingleOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList, endClauseList);
+ beginClauses, endClauses);
break;
default:
llvm_unreachable("Unexpected block construct");
@@ -2476,11 +2451,12 @@ genOMP(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) {
const auto &cd =
std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t);
- const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t);
+ List<Clause> clauses =
+ makeClauses(std::get<Fortran::parser::OmpClauseList>(cd.t), semaCtx);
const auto &name = std::get<std::optional<Fortran::parser::Name>>(cd.t);
mlir::Location currentLocation = converter.getCurrentLocation();
genCriticalOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
- clauseList, name);
+ clauses, name);
}
static void
@@ -2499,8 +2475,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
const auto &beginLoopDirective =
std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t);
- const auto &beginClauseList =
- std::get<Fortran::parser::OmpClauseList>(beginLoopDirective.t);
+ List<Clause> beginClauses = makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(beginLoopDirective.t), semaCtx);
mlir::Location currentLocation =
converter.genLocation(beginLoopDirective.source);
const auto origDirective =
@@ -2509,15 +2485,15 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
assert(llvm::omp::loopConstructSet.test(origDirective) &&
"Expected loop construct");
- const auto *endClauseList = [&]() {
- using RetTy = const Fortran::parser::OmpClauseList *;
+ List<Clause> endClauses = [&]() {
if (auto &endLoopDirective =
std::get<std::optional<Fortran::parser::OmpEndLoopDirective>>(
loopConstruct.t)) {
- return RetTy(
- &std::get<Fortran::parser::OmpClauseList>((*endLoopDirective).t));
+ return makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(endLoopDirective->t),
+ semaCtx);
}
- return RetTy();
+ return List<Clause>{};
}();
std::optional<llvm::omp::Directive> nextDir = origDirective;
@@ -2530,29 +2506,29 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
case llvm::omp::Directive::OMPD_distribute_parallel_do:
// 2.9.4.3 DISTRIBUTE PARALLEL Worksharing-Loop construct.
genCompositeDistributeParallelDo(converter, semaCtx, eval,
- beginClauseList, endClauseList,
+ beginClauses, endClauses,
currentLocation);
break;
case llvm::omp::Directive::OMPD_distribute_parallel_do_simd:
// 2.9.4.4 DISTRIBUTE PARALLEL Worksharing-Loop SIMD construct.
genCompositeDistributeParallelDoSimd(converter, semaCtx, eval,
- beginClauseList, endClauseList,
+ beginClauses, endClauses,
currentLocation);
break;
case llvm::omp::Directive::OMPD_distribute_simd:
// 2.9.4.2 DISTRIBUTE SIMD construct.
- genCompositeDistributeSimd(converter, semaCtx, eval, beginClauseList,
- endClauseList, currentLocation);
+ genCompositeDistributeSimd(converter, semaCtx, eval, beginClauses,
+ endClauses, currentLocation);
break;
case llvm::omp::Directive::OMPD_do_simd:
// 2.9.3.2 Worksharing-Loop SIMD construct.
- genCompositeDoSimd(converter, semaCtx, eval, beginClauseList,
- endClauseList, currentLocation);
+ genCompositeDoSimd(converter, semaCtx, eval, beginClauses,
+ endClauses, currentLocation);
break;
case llvm::omp::Directive::OMPD_taskloop_simd:
// 2.10.3 TASKLOOP SIMD construct.
- genCompositeTaskloopSimd(converter, semaCtx, eval, beginClauseList,
- endClauseList, currentLocation);
+ genCompositeTaskloopSimd(converter, semaCtx, eval, beginClauses,
+ endClauses, currentLocation);
break;
default:
llvm_unreachable("Unexpected composite construct");
@@ -2563,12 +2539,12 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
case llvm::omp::Directive::OMPD_distribute:
// 2.9.4.1 DISTRIBUTE construct.
genDistributeOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_do:
// 2.9.2 Worksharing-Loop construct.
- genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauseList,
- endClauseList);
+ genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauses,
+ endClauses);
break;
case llvm::omp::Directive::OMPD_parallel:
// 2.6 PARALLEL construct.
@@ -2577,24 +2553,24 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
// Maybe rename the argument if it represents something else or
// initialize it properly.
genParallelOp(converter, symTable, semaCtx, eval, genNested,
- currentLocation, beginClauseList,
+ currentLocation, beginClauses,
/*outerCombined=*/true);
break;
case llvm::omp::Directive::OMPD_simd:
// 2.9.3.1 SIMD construct.
genSimdLoopOp(converter, semaCtx, eval, currentLocation,
- beginClauseList);
- genOpenMPReduction(converter, semaCtx, beginClauseList);
+ beginClauses);
+ genOpenMPReduction(converter, semaCtx, beginClauses);
break;
case llvm::omp::Directive::OMPD_target:
// 2.12.5 TARGET construct.
genTargetOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList, /*outerCombined=*/true);
+ beginClauses, /*outerCombined=*/true);
break;
case llvm::omp::Directive::OMPD_taskloop:
// 2.10.2 TASKLOOP construct.
genTaskloopOp(converter, semaCtx, eval, currentLocation,
- beginClauseList);
+ beginClauses);
break;
case llvm::omp::Directive::OMPD_teams:
// 2.7 TEAMS construct.
@@ -2603,7 +2579,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
// Maybe rename the argument if it represents something else or
// initialize it properly.
genTeamsOp(converter, semaCtx, eval, genNested, currentLocation,
- beginClauseList, /*outerCombined=*/true);
+ beginClauses, /*outerCombined=*/true);
break;
case llvm::omp::Directive::OMPD_loop:
case llvm::omp::Directive::OMPD_masked:
@@ -2639,14 +2615,15 @@ genOMP(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OpenMPSectionsConstruct §ionsConstruct) {
const auto &beginSectionsDirective =
std::get<Fortran::parser::OmpBeginSectionsDirective>(sectionsConstruct.t);
- const auto &beginClauseList =
- std::get<Fortran::parser::OmpClauseList>(beginSectionsDirective.t);
+ List<Clause> beginClauses = makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(beginSectionsDirective.t),
+ semaCtx);
// Process clauses before optional omp.parallel, so that new variables are
// allocated outside of the parallel region
mlir::Location currentLocation = converter.getCurrentLocation();
mlir::omp::SectionsClauseOps clauseOps;
- genSectionsClauses(converter, semaCtx, beginClauseList, currentLocation,
+ genSectionsClauses(converter, semaCtx, beginClauses, currentLocation,
/*clausesFromBeginSections=*/true, clauseOps);
// Parallel wrapper of PARALLEL SECTIONS construct
@@ -2655,14 +2632,15 @@ genOMP(Fortran::lower::AbstractConverter &converter,
.v;
if (dir == llvm::omp::Directive::OMPD_parallel_sections) {
genParallelOp(converter, symTable, semaCtx, eval,
- /*genNested=*/false, currentLocation, beginClauseList,
+ /*genNested=*/false, currentLocation, beginClauses,
/*outerCombined=*/true);
} else {
const auto &endSectionsDirective =
std::get<Fortran::parser::OmpEndSectionsDirective>(sectionsConstruct.t);
- const auto &endClauseList =
- std::get<Fortran::parser::OmpClauseList>(endSectionsDirective.t);
- genSectionsClauses(converter, semaCtx, endClauseList, currentLocation,
+ List<Clause> endClauses = makeClauses(
+ std::get<Fortran::parser::OmpClauseList>(endSectionsDirective.t),
+ semaCtx);
+ genSectionsClauses(converter, semaCtx, endClauses, currentLocation,
/*clausesFromBeginSections=*/false, clauseOps);
}
@@ -2678,7 +2656,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
llvm::zip(sectionBlocks.v, eval.getNestedEvaluations())) {
symTable.pushScope();
genSectionOp(converter, semaCtx, neval, /*genNested=*/true, currentLocation,
- beginClauseList);
+ beginClauses);
symTable.popScope();
firOpBuilder.restoreInsertionPoint(ip);
}
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index b9c0660aa4da8e..da3f2be73e5095 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -36,6 +36,17 @@ namespace Fortran {
namespace lower {
namespace omp {
+int64_t getCollapseValue(const List<Clause> &clauses) {
+ auto iter = llvm::find_if(clauses, [](const Clause &clause) {
+ return clause.id == llvm::omp::Clause::OMPC_collapse;
+ });
+ if (iter != clauses.end()) {
+ const auto &collapse = std::get<clause::Collapse>(iter->u);
+ return evaluate::ToInt64(collapse.v).value();
+ }
+ return 1;
+}
+
void genObjectList(const ObjectList &objects,
Fortran::lower::AbstractConverter &converter,
llvm::SmallVectorImpl<mlir::Value> &operands) {
@@ -52,25 +63,6 @@ void genObjectList(const ObjectList &objects,
}
}
-void genObjectList2(const Fortran::parser::OmpObjectList &objectList,
- Fortran::lower::AbstractConverter &converter,
- llvm::SmallVectorImpl<mlir::Value> &operands) {
- auto addOperands = [&](Fortran::lower::SymbolRef sym) {
- const mlir::Value variable = converter.getSymbolAddress(sym);
- if (variable) {
- operands.push_back(variable);
- } else if (const auto *details =
- sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
- operands.push_back(converter.getSymbolAddress(details->symbol()));
- converter.copySymbolBinding(details->symbol(), sym);
- }
- };
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
- addOperands(*sym);
- }
-}
-
mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
std::size_t loopVarTypeSize) {
// OpenMP runtime requires 32-bit or 64-bit loop variables.
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index 4074bf73987d5b..b3a9f7f30c98bd 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -58,6 +58,8 @@ void gatherFuncAndVarSyms(
const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause);
+int64_t getCollapseValue(const List<Clause> &clauses);
+
Fortran::semantics::Symbol *
getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject);
@@ -65,10 +67,6 @@ void genObjectList(const ObjectList &objects,
Fortran::lower::AbstractConverter &converter,
llvm::SmallVectorImpl<mlir::Value> &operands);
-void genObjectList2(const Fortran::parser::OmpObjectList &objectList,
- Fortran::lower::AbstractConverter &converter,
- llvm::SmallVectorImpl<mlir::Value> &operands);
-
} // namespace omp
} // namespace lower
} // namespace Fortran
>From 291dc48d5e0b7e0ee39681a1276bd1d63f456b01 Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Mon, 1 Apr 2024 10:07:45 -0500
Subject: [PATCH 2/3] [Frontend][OpenMP] Refactor getLeafConstructs, add
getCompoundConstruct
Emit a special leaf construct table in DirectiveEmitter.cpp, which will
allow both decomposition of a construct into leaves, and composition of
constituent constructs into a single compound construct (if possible).
---
llvm/include/llvm/Frontend/OpenMP/OMP.h | 7 +
llvm/lib/Frontend/OpenMP/OMP.cpp | 64 +++++-
llvm/test/TableGen/directive1.td | 19 +-
llvm/test/TableGen/directive2.td | 19 +-
llvm/unittests/Frontend/CMakeLists.txt | 1 +
llvm/unittests/Frontend/OpenMPComposeTest.cpp | 41 ++++
llvm/utils/TableGen/DirectiveEmitter.cpp | 194 +++++++++++-------
7 files changed, 258 insertions(+), 87 deletions(-)
create mode 100644 llvm/unittests/Frontend/OpenMPComposeTest.cpp
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.h b/llvm/include/llvm/Frontend/OpenMP/OMP.h
index a85cd9d344c6d7..4ed47f15dfe59e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.h
@@ -15,4 +15,11 @@
#include "llvm/Frontend/OpenMP/OMP.h.inc"
+#include "llvm/ADT/ArrayRef.h"
+
+namespace llvm::omp {
+ArrayRef<Directive> getLeafConstructs(Directive D);
+Directive getCompoundConstruct(ArrayRef<Directive> Parts);
+} // namespace llvm::omp
+
#endif // LLVM_FRONTEND_OPENMP_OMP_H
diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp
index 4f2f95392648b3..dd99d3d074fd1e 100644
--- a/llvm/lib/Frontend/OpenMP/OMP.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMP.cpp
@@ -8,12 +8,74 @@
#include "llvm/Frontend/OpenMP/OMP.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/ErrorHandling.h"
+#include <algorithm>
+#include <iterator>
+#include <type_traits>
+
using namespace llvm;
-using namespace omp;
+using namespace llvm::omp;
#define GEN_DIRECTIVES_IMPL
#include "llvm/Frontend/OpenMP/OMP.inc"
+
+namespace llvm::omp {
+ArrayRef<Directive> getLeafConstructs(Directive D) {
+ auto Idx = static_cast<int>(D);
+ if (Idx < 0 || Idx >= static_cast<int>(Directive_enumSize))
+ return {};
+ const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]];
+ return ArrayRef(&Row[2], &Row[2] + static_cast<int>(Row[1]));
+}
+
+Directive getCompoundConstruct(ArrayRef<Directive> Parts) {
+ if (Parts.empty())
+ return OMPD_unknown;
+
+ // Parts don't have to be leafs, so expand them into leafs first.
+ // Store the expanded leafs in the same format as rows in the leaf
+ // table (generated by tablegen).
+ SmallVector<Directive> RawLeafs(2);
+ for (Directive P : Parts) {
+ ArrayRef<Directive> Ls = getLeafConstructs(P);
+ if (!Ls.empty())
+ RawLeafs.append(Ls.begin(), Ls.end());
+ else
+ RawLeafs.push_back(P);
+ }
+
+ auto GivenLeafs{ArrayRef<Directive>(RawLeafs).drop_front(2)};
+ if (GivenLeafs.size() == 1)
+ return GivenLeafs.front();
+ RawLeafs[1] = static_cast<Directive>(GivenLeafs.size());
+
+ auto Iter = llvm::lower_bound(
+ LeafConstructTable,
+ static_cast<std::decay_t<decltype(*LeafConstructTable)>>(RawLeafs.data()),
+ [](const auto *RowA, const auto *RowB) {
+ const auto *BeginA = &RowA[2];
+ const auto *EndA = BeginA + static_cast<int>(RowA[1]);
+ const auto *BeginB = &RowB[2];
+ const auto *EndB = BeginB + static_cast<int>(RowB[1]);
+ if (BeginA == EndA && BeginB == EndB)
+ return static_cast<int>(RowA[0]) < static_cast<int>(RowB[0]);
+ return std::lexicographical_compare(BeginA, EndA, BeginB, EndB);
+ });
+
+ if (Iter == std::end(LeafConstructTable))
+ return OMPD_unknown;
+
+ // Verify that we got a match.
+ Directive Found = (*Iter)[0];
+ ArrayRef<Directive> FoundLeafs = getLeafConstructs(Found);
+ if (FoundLeafs == GivenLeafs)
+ return Found;
+ return OMPD_unknown;
+}
+} // namespace llvm::omp
diff --git a/llvm/test/TableGen/directive1.td b/llvm/test/TableGen/directive1.td
index 3184f625ead928..e6150210e7e9a4 100644
--- a/llvm/test/TableGen/directive1.td
+++ b/llvm/test/TableGen/directive1.td
@@ -52,6 +52,7 @@ def TDL_DirA : Directive<"dira"> {
// CHECK-EMPTY:
// CHECK-NEXT: #include "llvm/ADT/ArrayRef.h"
// CHECK-NEXT: #include "llvm/ADT/BitmaskEnum.h"
+// CHECK-NEXT: #include <cstddef>
// CHECK-EMPTY:
// CHECK-NEXT: namespace llvm {
// CHECK-NEXT: class StringRef;
@@ -112,7 +113,7 @@ def TDL_DirA : Directive<"dira"> {
// CHECK-NEXT: /// Return true if \p C is a valid clause for \p D in version \p Version.
// CHECK-NEXT: bool isAllowedClauseForDirective(Directive D, Clause C, unsigned Version);
// CHECK-EMPTY:
-// CHECK-NEXT: llvm::ArrayRef<Directive> getLeafConstructs(Directive D);
+// CHECK-NEXT: constexpr std::size_t getMaxLeafCount() { return 0; }
// CHECK-NEXT: Association getDirectiveAssociation(Directive D);
// CHECK-NEXT: AKind getAKind(StringRef);
// CHECK-NEXT: llvm::StringRef getTdlAKindName(AKind);
@@ -359,13 +360,6 @@ def TDL_DirA : Directive<"dira"> {
// IMPL-NEXT: llvm_unreachable("Invalid Tdl Directive kind");
// IMPL-NEXT: }
// IMPL-EMPTY:
-// IMPL-NEXT: llvm::ArrayRef<llvm::tdl::Directive> llvm::tdl::getLeafConstructs(llvm::tdl::Directive Dir) {
-// IMPL-NEXT: switch (Dir) {
-// IMPL-NEXT: default:
-// IMPL-NEXT: return ArrayRef<llvm::tdl::Directive>{};
-// IMPL-NEXT: } // switch (Dir)
-// IMPL-NEXT: }
-// IMPL-EMPTY:
// IMPL-NEXT: llvm::tdl::Association llvm::tdl::getDirectiveAssociation(llvm::tdl::Directive Dir) {
// IMPL-NEXT: switch (Dir) {
// IMPL-NEXT: case llvm::tdl::Directive::TDLD_dira:
@@ -374,4 +368,13 @@ def TDL_DirA : Directive<"dira"> {
// IMPL-NEXT: llvm_unreachable("Unexpected directive");
// IMPL-NEXT: }
// IMPL-EMPTY:
+// IMPL-NEXT: static_assert(sizeof(llvm::tdl::Directive) == sizeof(int));
+// IMPL-NEXT: {{.*}} static const llvm::tdl::Directive LeafConstructTable[][2] = {
+// IMPL-NEXT: llvm::tdl::TDLD_dira, static_cast<llvm::tdl::Directive>(0),
+// IMPL-NEXT: };
+// IMPL-EMPTY:
+// IMPL-NEXT: {{.*}} static const int LeafConstructTableOrdering[] = {
+// IMPL-NEXT: 0,
+// IMPL-NEXT: };
+// IMPL-EMPTY:
// IMPL-NEXT: #endif // GEN_DIRECTIVES_IMPL
diff --git a/llvm/test/TableGen/directive2.td b/llvm/test/TableGen/directive2.td
index d6fa4835c8dfdc..1750022e1f94ea 100644
--- a/llvm/test/TableGen/directive2.td
+++ b/llvm/test/TableGen/directive2.td
@@ -45,6 +45,7 @@ def TDL_DirA : Directive<"dira"> {
// CHECK-NEXT: #define LLVM_Tdl_INC
// CHECK-EMPTY:
// CHECK-NEXT: #include "llvm/ADT/ArrayRef.h"
+// CHECK-NEXT: #include <cstddef>
// CHECK-EMPTY:
// CHECK-NEXT: namespace llvm {
// CHECK-NEXT: class StringRef;
@@ -88,7 +89,7 @@ def TDL_DirA : Directive<"dira"> {
// CHECK-NEXT: /// Return true if \p C is a valid clause for \p D in version \p Version.
// CHECK-NEXT: bool isAllowedClauseForDirective(Directive D, Clause C, unsigned Version);
// CHECK-EMPTY:
-// CHECK-NEXT: llvm::ArrayRef<Directive> getLeafConstructs(Directive D);
+// CHECK-NEXT: constexpr std::size_t getMaxLeafCount() { return 0; }
// CHECK-NEXT: Association getDirectiveAssociation(Directive D);
// CHECK-NEXT: } // namespace tdl
// CHECK-NEXT: } // namespace llvm
@@ -290,13 +291,6 @@ def TDL_DirA : Directive<"dira"> {
// IMPL-NEXT: llvm_unreachable("Invalid Tdl Directive kind");
// IMPL-NEXT: }
// IMPL-EMPTY:
-// IMPL-NEXT: llvm::ArrayRef<llvm::tdl::Directive> llvm::tdl::getLeafConstructs(llvm::tdl::Directive Dir) {
-// IMPL-NEXT: switch (Dir) {
-// IMPL-NEXT: default:
-// IMPL-NEXT: return ArrayRef<llvm::tdl::Directive>{};
-// IMPL-NEXT: } // switch (Dir)
-// IMPL-NEXT: }
-// IMPL-EMPTY:
// IMPL-NEXT: llvm::tdl::Association llvm::tdl::getDirectiveAssociation(llvm::tdl::Directive Dir) {
// IMPL-NEXT: switch (Dir) {
// IMPL-NEXT: case llvm::tdl::Directive::TDLD_dira:
@@ -305,4 +299,13 @@ def TDL_DirA : Directive<"dira"> {
// IMPL-NEXT: llvm_unreachable("Unexpected directive");
// IMPL-NEXT: }
// IMPL-EMPTY:
+// IMPL-NEXT: static_assert(sizeof(llvm::tdl::Directive) == sizeof(int));
+// IMPL-NEXT: {{.*}} static const llvm::tdl::Directive LeafConstructTable[][2] = {
+// IMPL-NEXT: llvm::tdl::TDLD_dira, static_cast<llvm::tdl::Directive>(0),
+// IMPL-NEXT: };
+// IMPL-EMPTY:
+// IMPL-NEXT: {{.*}} static const int LeafConstructTableOrdering[] = {
+// IMPL-NEXT: 0,
+// IMPL-NEXT: };
+// IMPL-EMPTY:
// IMPL-NEXT: #endif // GEN_DIRECTIVES_IMPL
diff --git a/llvm/unittests/Frontend/CMakeLists.txt b/llvm/unittests/Frontend/CMakeLists.txt
index c6f60142d6276a..ddb6a16cbb984e 100644
--- a/llvm/unittests/Frontend/CMakeLists.txt
+++ b/llvm/unittests/Frontend/CMakeLists.txt
@@ -14,6 +14,7 @@ add_llvm_unittest(LLVMFrontendTests
OpenMPContextTest.cpp
OpenMPIRBuilderTest.cpp
OpenMPParsingTest.cpp
+ OpenMPComposeTest.cpp
DEPENDS
acc_gen
diff --git a/llvm/unittests/Frontend/OpenMPComposeTest.cpp b/llvm/unittests/Frontend/OpenMPComposeTest.cpp
new file mode 100644
index 00000000000000..29b1be4eb3432c
--- /dev/null
+++ b/llvm/unittests/Frontend/OpenMPComposeTest.cpp
@@ -0,0 +1,44 @@
+//===- llvm/unittests/Frontend/OpenMPComposeTest.cpp ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Frontend/OpenMP/OMP.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using namespace llvm::omp;
+
+TEST(Composition, GetLeafConstructs) {
+  ArrayRef<Directive> L1 = getLeafConstructs(OMPD_loop);
+  ASSERT_EQ(L1, (ArrayRef<Directive>{}));
+  ArrayRef<Directive> L2 = getLeafConstructs(OMPD_parallel_for);
+  ASSERT_EQ(L2, (ArrayRef<Directive>{OMPD_parallel, OMPD_for}));
+  ArrayRef<Directive> L3 = getLeafConstructs(OMPD_parallel_for_simd);
+  ASSERT_EQ(L3, (ArrayRef<Directive>{OMPD_parallel, OMPD_for, OMPD_simd}));
+}
+
+TEST(Composition, GetCompoundConstruct) {
+  Directive C1 =
+      getCompoundConstruct({OMPD_target, OMPD_teams, OMPD_distribute});
+  ASSERT_EQ(C1, OMPD_target_teams_distribute);
+  Directive C2 = getCompoundConstruct({OMPD_target});
+  ASSERT_EQ(C2, OMPD_target);
+  // A sequence of leafs with no matching compound construct.
+  Directive C3 = getCompoundConstruct({OMPD_target, OMPD_masked});
+  ASSERT_EQ(C3, OMPD_unknown);
+  // Compound parts are expanded into leafs, whether they come last or first.
+  Directive C4 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute});
+  ASSERT_EQ(C4, OMPD_target_teams_distribute);
+  Directive C5 = getCompoundConstruct({OMPD_target_teams, OMPD_distribute});
+  ASSERT_EQ(C5, OMPD_target_teams_distribute);
+  // An empty list of parts never names a construct.
+  Directive C6 = getCompoundConstruct({});
+  ASSERT_EQ(C6, OMPD_unknown);
+  Directive C7 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd});
+  ASSERT_EQ(C7, OMPD_parallel_for_simd);
+}
diff --git a/llvm/utils/TableGen/DirectiveEmitter.cpp b/llvm/utils/TableGen/DirectiveEmitter.cpp
index e0edf1720f8ac5..2d2b7748491897 100644
--- a/llvm/utils/TableGen/DirectiveEmitter.cpp
+++ b/llvm/utils/TableGen/DirectiveEmitter.cpp
@@ -20,6 +20,9 @@
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
+#include <numeric>
+#include <vector>
+
using namespace llvm;
namespace {
@@ -39,7 +42,8 @@ class IfDefScope {
};
} // namespace
-// Generate enum class
+// Generate enum class. Entries are emitted in the order in which they appear
+// in the `Records` vector.
static void GenerateEnumClass(const std::vector<Record *> &Records,
raw_ostream &OS, StringRef Enum, StringRef Prefix,
const DirectiveLanguage &DirLang,
@@ -175,6 +179,16 @@ bool DirectiveLanguage::HasValidityErrors() const {
return HasDuplicateClausesInDirectives(getDirectives());
}
+// Count the maximum number of leaf constituents per construct.
+static size_t GetMaxLeafCount(const DirectiveLanguage &DirLang) {
+ size_t MaxCount = 0;
+ for (Record *R : DirLang.getDirectives()) {
+ size_t Count = Directive{R}.getLeafConstructs().size();
+ MaxCount = std::max(MaxCount, Count);
+ }
+ return MaxCount;
+}
+
// Generate the declaration section for the enumeration in the directive
// language
static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
@@ -189,6 +203,7 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
if (DirLang.hasEnableBitmaskEnumInNamespace())
OS << "#include \"llvm/ADT/BitmaskEnum.h\"\n";
+ OS << "#include <cstddef>\n"; // for size_t
OS << "\n";
OS << "namespace llvm {\n";
OS << "class StringRef;\n";
@@ -244,7 +259,8 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) {
OS << "bool isAllowedClauseForDirective(Directive D, "
<< "Clause C, unsigned Version);\n";
OS << "\n";
- OS << "llvm::ArrayRef<Directive> getLeafConstructs(Directive D);\n";
+ OS << "constexpr std::size_t getMaxLeafCount() { return "
+ << GetMaxLeafCount(DirLang) << "; }\n";
OS << "Association getDirectiveAssociation(Directive D);\n";
if (EnumHelperFuncs.length() > 0) {
OS << EnumHelperFuncs;
@@ -396,6 +412,19 @@ GenerateCaseForVersionedClauses(const std::vector<Record *> &Clauses,
}
}
+static std::string GetDirectiveName(const DirectiveLanguage &DirLang,
+ const Record *Rec) {
+ Directive Dir{Rec};
+ return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + "::" +
+ DirLang.getDirectivePrefix() + Dir.getFormattedName())
+ .str();
+}
+
+static std::string GetDirectiveType(const DirectiveLanguage &DirLang) {
+ return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + "::Directive")
+ .str();
+}
+
// Generate the isAllowedClauseForDirective function implementation.
static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang,
raw_ostream &OS) {
@@ -450,77 +479,102 @@ static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang,
OS << "}\n"; // End of function isAllowedClauseForDirective
}
-// Generate the getLeafConstructs function implementation.
-static void GenerateGetLeafConstructs(const DirectiveLanguage &DirLang,
- raw_ostream &OS) {
- auto getQualifiedName = [&](StringRef Formatted) -> std::string {
- return (llvm::Twine("llvm::") + DirLang.getCppNamespace() +
- "::Directive::" + DirLang.getDirectivePrefix() + Formatted)
- .str();
- };
-
- // For each list of leaves, generate a static local object, then
- // return a reference to that object for a given directive, e.g.
+static void EmitLeafTable(const DirectiveLanguage &DirLang, raw_ostream &OS,
+ StringRef TableName) {
+  // The leaf constructs are emitted in the form of a 2D table, where each
+ // row corresponds to a directive (and there is a row for each directive).
//
- // static ListTy leafConstructs_A_B = { A, B };
- // static ListTy leafConstructs_C_D_E = { C, D, E };
- // switch (Dir) {
- // case A_B:
- // return leafConstructs_A_B;
- // case C_D_E:
- // return leafConstructs_C_D_E;
- // }
-
- // Map from a record that defines a directive to the name of the
- // local object with the list of its leaves.
- DenseMap<Record *, std::string> ListNames;
-
- std::string DirectiveTypeName =
- std::string("llvm::") + DirLang.getCppNamespace().str() + "::Directive";
-
- OS << '\n';
-
- // ArrayRef<...> llvm::<ns>::GetLeafConstructs(llvm::<ns>::Directive Dir)
- OS << "llvm::ArrayRef<" << DirectiveTypeName
- << "> llvm::" << DirLang.getCppNamespace() << "::getLeafConstructs("
- << DirectiveTypeName << " Dir) ";
- OS << "{\n";
-
- // Generate the locals.
- for (Record *R : DirLang.getDirectives()) {
- Directive Dir{R};
+ // Each row consists of
+ // - the id of the directive itself,
+  // - the number of leaf constructs that will follow (0 for leaves),
+  // - the ids of the leaf constructs (none if the directive is itself a leaf).
+  // The total number of these entries is at most MaxLeafCount+2. If this
+  // number is less than that, the row is padded with -1 entries so that it
+  // occupies exactly MaxLeafCount+2 entries in memory.
+ //
+  // The rows are stored in the table in lexicographical order. This
+  // is intended to enable binary search when mapping a sequence of leaves
+  // back to the compound directive.
+  // The consequence of that is that in order to find a row corresponding
+  // to the given directive, we'd need to scan the first element of each
+  // row. To avoid this, an auxiliary ordering table is created, such that
+  // the row for Dir_A = table[auxiliary[Dir_A]].
+
+ std::vector<Record *> Directives = DirLang.getDirectives();
+ DenseMap<Record *, size_t> DirId; // Record * -> llvm::omp::Directive
+
+ for (auto [Idx, Rec] : llvm::enumerate(Directives))
+ DirId.insert(std::make_pair(Rec, Idx));
+
+ using LeafList = std::vector<int>;
+ int MaxLeafCount = GetMaxLeafCount(DirLang);
+
+  // The initial leaf table; the row order is the same as the directive order.
+ std::vector<LeafList> LeafTable(Directives.size());
+ for (auto [Idx, Rec] : llvm::enumerate(Directives)) {
+ Directive Dir{Rec};
+ std::vector<Record *> Leaves = Dir.getLeafConstructs();
+
+ auto &List = LeafTable[Idx];
+ List.resize(MaxLeafCount + 2);
+ List[0] = Idx; // The id of the directive itself.
+ List[1] = Leaves.size(); // The number of leaves to follow.
+
+ for (int I = 0; I != MaxLeafCount; ++I)
+ List[I + 2] =
+ static_cast<size_t>(I) < Leaves.size() ? DirId.at(Leaves[I]) : -1;
+ }
- std::vector<Record *> LeafConstructs = Dir.getLeafConstructs();
- if (LeafConstructs.empty())
- continue;
+  // Avoid sorting the vector<vector> array; instead, sort an index array.
+ // It will also be useful later to create the auxiliary indexing array.
+ std::vector<int> Ordering(Directives.size());
+ std::iota(Ordering.begin(), Ordering.end(), 0);
+
+ llvm::sort(Ordering, [&](int A, int B) {
+ auto &LeavesA = LeafTable[A];
+ auto &LeavesB = LeafTable[B];
+ if (LeavesA[1] == 0 && LeavesB[1] == 0)
+ return LeavesA[0] < LeavesB[0];
+ return std::lexicographical_compare(&LeavesA[2], &LeavesA[2] + LeavesA[1],
+ &LeavesB[2], &LeavesB[2] + LeavesB[1]);
+ });
- std::string ListName = "leafConstructs_" + Dir.getFormattedName();
- OS << " static const " << DirectiveTypeName << ' ' << ListName
- << "[] = {\n";
- for (Record *L : LeafConstructs) {
- Directive LeafDir{L};
- OS << " " << getQualifiedName(LeafDir.getFormattedName()) << ",\n";
+ // Emit the table
+
+ // The directives are emitted into a scoped enum, for which the underlying
+ // type is `int` (by default). The code above uses `int` to store directive
+ // ids, so make sure that we catch it when something changes in the
+ // underlying type.
+ std::string DirectiveType = GetDirectiveType(DirLang);
+ OS << "\nstatic_assert(sizeof(" << DirectiveType << ") == sizeof(int));\n";
+
+ OS << "[[maybe_unused]] static const " << DirectiveType << ' ' << TableName
+ << "[][" << MaxLeafCount + 2 << "] = {\n";
+ for (size_t I = 0, E = Directives.size(); I != E; ++I) {
+ auto &Leaves = LeafTable[Ordering[I]];
+ OS << " " << GetDirectiveName(DirLang, Directives[Leaves[0]]);
+ OS << ", static_cast<" << DirectiveType << ">(" << Leaves[1] << "),";
+ for (size_t I = 2, E = Leaves.size(); I != E; ++I) {
+ int Idx = Leaves[I];
+ if (Idx >= 0)
+ OS << ' ' << GetDirectiveName(DirLang, Directives[Leaves[I]]) << ',';
+ else
+ OS << " static_cast<" << DirectiveType << ">(-1),";
}
- OS << " };\n";
- ListNames.insert(std::make_pair(R, std::move(ListName)));
- }
-
- if (!ListNames.empty())
OS << '\n';
- OS << " switch (Dir) {\n";
- for (Record *R : DirLang.getDirectives()) {
- auto F = ListNames.find(R);
- if (F == ListNames.end())
- continue;
-
- Directive Dir{R};
- OS << " case " << getQualifiedName(Dir.getFormattedName()) << ":\n";
- OS << " return " << F->second << ";\n";
}
- OS << " default:\n";
- OS << " return ArrayRef<" << DirectiveTypeName << ">{};\n";
- OS << " } // switch (Dir)\n";
- OS << "}\n";
+ OS << "};\n\n";
+
+ // Emit the auxiliary index table: it's the inverse of the `Ordering`
+ // table above.
+ OS << "[[maybe_unused]] static const int " << TableName << "Ordering[] = {\n";
+ OS << " ";
+ std::vector<int> Reverse(Ordering.size());
+ for (int I = 0, E = Ordering.size(); I != E; ++I)
+ Reverse[Ordering[I]] = I;
+ for (int Idx : Reverse)
+ OS << ' ' << Idx << ',';
+ OS << "\n};\n";
}
static void GenerateGetDirectiveAssociation(const DirectiveLanguage &DirLang,
@@ -1105,11 +1159,11 @@ void EmitDirectivesBasicImpl(const DirectiveLanguage &DirLang,
// isAllowedClauseForDirective(Directive D, Clause C, unsigned Version)
GenerateIsAllowedClause(DirLang, OS);
- // getLeafConstructs(Directive D)
- GenerateGetLeafConstructs(DirLang, OS);
-
// getDirectiveAssociation(Directive D)
GenerateGetDirectiveAssociation(DirLang, OS);
+
+ // Leaf table for getLeafConstructs, etc.
+ EmitLeafTable(DirLang, OS, "LeafConstructTable");
}
// Generate the implemenation section for the enumeration in the directive
>From 0d92781c7a52ed2fbab33ae6e7b3dae61cfd42ae Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Tue, 2 Apr 2024 08:20:15 -0500
Subject: [PATCH 3/3] Address review comments
---
llvm/lib/Frontend/OpenMP/OMP.cpp | 10 ++++++++--
llvm/unittests/Frontend/OpenMPComposeTest.cpp | 10 ++++------
2 files changed, 12 insertions(+), 8 deletions(-)
diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp
index dd99d3d074fd1e..7504c9076fde1b 100644
--- a/llvm/lib/Frontend/OpenMP/OMP.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMP.cpp
@@ -27,8 +27,8 @@ using namespace llvm::omp;
namespace llvm::omp {
ArrayRef<Directive> getLeafConstructs(Directive D) {
- auto Idx = static_cast<int>(D);
- if (Idx < 0 || Idx >= static_cast<int>(Directive_enumSize))
+ auto Idx = static_cast<std::size_t>(D);
+ if (Idx >= Directive_enumSize)
return {};
const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]];
return ArrayRef(&Row[2], &Row[2] + static_cast<int>(Row[1]));
@@ -50,6 +50,12 @@ Directive getCompoundConstruct(ArrayRef<Directive> Parts) {
RawLeafs.push_back(P);
}
+  // RawLeafs will be used as the key in the binary search. The search doesn't
+  // guarantee that the exact same entry will be found (since RawLeafs may
+  // not correspond to any compound directive). Because of that, we will
+  // need to compare the search result with the given set of leaves.
+  // Also, if there is only one leaf in the list, it corresponds to itself
+  // and no search is necessary.
auto GivenLeafs{ArrayRef<Directive>(RawLeafs).drop_front(2)};
if (GivenLeafs.size() == 1)
return GivenLeafs.front();
diff --git a/llvm/unittests/Frontend/OpenMPComposeTest.cpp b/llvm/unittests/Frontend/OpenMPComposeTest.cpp
index 29b1be4eb3432c..c3e0880ece8641 100644
--- a/llvm/unittests/Frontend/OpenMPComposeTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPComposeTest.cpp
@@ -32,10 +32,8 @@ TEST(Composition, GetCompoundConstruct) {
ASSERT_EQ(C3, OMPD_unknown);
Directive C4 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute});
ASSERT_EQ(C4, OMPD_target_teams_distribute);
- Directive C5 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute});
- ASSERT_EQ(C5, OMPD_target_teams_distribute);
- Directive C6 = getCompoundConstruct({});
- ASSERT_EQ(C6, OMPD_unknown);
- Directive C7 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd});
- ASSERT_EQ(C7, OMPD_parallel_for_simd);
+ Directive C5 = getCompoundConstruct({});
+ ASSERT_EQ(C5, OMPD_unknown);
+ Directive C6 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd});
+ ASSERT_EQ(C6, OMPD_parallel_for_simd);
}
More information about the llvm-branch-commits
mailing list