[llvm-branch-commits] [flang] [Flang][OpenMP][Lower] Use clause operand structures (PR #86802)
Sergio Afonso via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Mar 28 09:32:07 PDT 2024
https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/86802
>From 7af7e9d13fc2134e76bb532bfa4313aa3df17924 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 26 Mar 2024 16:46:56 +0000
Subject: [PATCH] [Flang][OpenMP][Lower] Use clause operand structures
This patch updates Flang lowering to use the new set of OpenMP clause operand
structures and their groupings into directive-specific sets of clause operands.
It simplifies the passing of information from the clause processor and the
creation of operations.
The `DataSharingProcessor` is slightly modified to not hold delayed
privatization state. Instead, optional arguments are added to `processStep1`
which are only passed when delayed privatization is used. This enables using
the clause operand structure for `private` and removes the need for the ad-hoc
`DelayedPrivatizationInfo` structure.
The processing of the `schedule` clause is updated to process the `chunk`
modifier rather than requiring two separate calls to the `ClauseProcessor`.
Lowering of a block-associated `ordered` construct is updated to emit a TODO
error if the `simd` clause is specified, since it is not currently supported by
the `ClauseProcessor` or later compilation stages.
Removed processing of `schedule` from `omp.simdloop`, as it doesn't apply to
`simd` constructs.
---
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 261 +++++----
flang/lib/Lower/OpenMP/ClauseProcessor.h | 105 ++--
.../lib/Lower/OpenMP/DataSharingProcessor.cpp | 38 +-
flang/lib/Lower/OpenMP/DataSharingProcessor.h | 45 +-
flang/lib/Lower/OpenMP/OpenMP.cpp | 517 +++++++-----------
5 files changed, 428 insertions(+), 538 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 0a57a1496289f4..ee1f6c2fbc7e89 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -162,14 +162,13 @@ getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
ifVal);
}
-static void
-addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
- const omp::ObjectList &objects,
- llvm::SmallVectorImpl<mlir::Value> &operands,
- llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
- llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- &useDeviceSymbols) {
+static void addUseDeviceClause(
+ Fortran::lower::AbstractConverter &converter,
+ const omp::ObjectList &objects,
+ llvm::SmallVectorImpl<mlir::Value> &operands,
+ llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
+ llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms) {
genObjectList(objects, converter, operands);
for (mlir::Value &operand : operands) {
checkMapType(operand.getLoc(), operand.getType());
@@ -177,25 +176,24 @@ addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
useDeviceLocs.push_back(operand.getLoc());
}
for (const omp::Object &object : objects)
- useDeviceSymbols.push_back(object.id());
+ useDeviceSyms.push_back(object.id());
}
static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
mlir::Location loc,
- llvm::SmallVectorImpl<mlir::Value> &lowerBound,
- llvm::SmallVectorImpl<mlir::Value> &upperBound,
- llvm::SmallVectorImpl<mlir::Value> &step,
+ mlir::omp::CollapseClauseOps &result,
std::size_t loopVarTypeSize) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
// The types of lower bound, upper bound, and step are converted into the
// type of the loop variable if necessary.
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
- for (unsigned it = 0; it < (unsigned)lowerBound.size(); it++) {
- lowerBound[it] =
- firOpBuilder.createConvert(loc, loopVarType, lowerBound[it]);
- upperBound[it] =
- firOpBuilder.createConvert(loc, loopVarType, upperBound[it]);
- step[it] = firOpBuilder.createConvert(loc, loopVarType, step[it]);
+ for (unsigned it = 0; it < (unsigned)result.loopLBVar.size(); it++) {
+ result.loopLBVar[it] =
+ firOpBuilder.createConvert(loc, loopVarType, result.loopLBVar[it]);
+ result.loopUBVar[it] =
+ firOpBuilder.createConvert(loc, loopVarType, result.loopUBVar[it]);
+ result.loopStepVar[it] =
+ firOpBuilder.createConvert(loc, loopVarType, result.loopStepVar[it]);
}
}
@@ -205,9 +203,7 @@ static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
bool ClauseProcessor::processCollapse(
mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
- llvm::SmallVectorImpl<mlir::Value> &lowerBound,
- llvm::SmallVectorImpl<mlir::Value> &upperBound,
- llvm::SmallVectorImpl<mlir::Value> &step,
+ mlir::omp::CollapseClauseOps &result,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const {
bool found = false;
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -238,15 +234,15 @@ bool ClauseProcessor::processCollapse(
std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
assert(bounds && "Expected bounds for worksharing do loop");
Fortran::lower::StatementContext stmtCtx;
- lowerBound.push_back(fir::getBase(converter.genExprValue(
+ result.loopLBVar.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
- upperBound.push_back(fir::getBase(converter.genExprValue(
+ result.loopUBVar.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
if (bounds->step) {
- step.push_back(fir::getBase(converter.genExprValue(
+ result.loopStepVar.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
} else { // If `step` is not present, assume it as `1`.
- step.push_back(firOpBuilder.createIntegerConstant(
+ result.loopStepVar.push_back(firOpBuilder.createIntegerConstant(
currentLocation, firOpBuilder.getIntegerType(32), 1));
}
iv.push_back(bounds->name.thing.symbol);
@@ -257,8 +253,7 @@ bool ClauseProcessor::processCollapse(
&*std::next(doConstructEval->getNestedEvaluations().begin());
} while (collapseValue > 0);
- convertLoopBounds(converter, currentLocation, lowerBound, upperBound, step,
- loopVarTypeSize);
+ convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);
return found;
}
@@ -286,7 +281,7 @@ bool ClauseProcessor::processDefault() const {
}
bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
+ mlir::omp::DeviceClauseOps &result) const {
const Fortran::parser::CharBlock *source = nullptr;
if (auto *clause = findUniqueClause<omp::clause::Device>(&source)) {
mlir::Location clauseLocation = converter.genLocation(*source);
@@ -298,25 +293,26 @@ bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx,
}
}
const auto &deviceExpr = std::get<omp::SomeExpr>(clause->t);
- result = fir::getBase(converter.genExprValue(deviceExpr, stmtCtx));
+ result.deviceVar =
+ fir::getBase(converter.genExprValue(deviceExpr, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processDeviceType(
- mlir::omp::DeclareTargetDeviceType &result) const {
+ mlir::omp::DeviceTypeClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::DeviceType>()) {
// Case: declare target ... device_type(any | host | nohost)
switch (clause->v) {
case omp::clause::DeviceType::DeviceTypeDescription::Nohost:
- result = mlir::omp::DeclareTargetDeviceType::nohost;
+ result.deviceType = mlir::omp::DeclareTargetDeviceType::nohost;
break;
case omp::clause::DeviceType::DeviceTypeDescription::Host:
- result = mlir::omp::DeclareTargetDeviceType::host;
+ result.deviceType = mlir::omp::DeclareTargetDeviceType::host;
break;
case omp::clause::DeviceType::DeviceTypeDescription::Any:
- result = mlir::omp::DeclareTargetDeviceType::any;
+ result.deviceType = mlir::omp::DeclareTargetDeviceType::any;
break;
}
return true;
@@ -325,7 +321,7 @@ bool ClauseProcessor::processDeviceType(
}
bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
+ mlir::omp::FinalClauseOps &result) const {
const Fortran::parser::CharBlock *source = nullptr;
if (auto *clause = findUniqueClause<omp::clause::Final>(&source)) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -333,100 +329,108 @@ bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx,
mlir::Value finalVal =
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
- result = firOpBuilder.createConvert(clauseLocation,
- firOpBuilder.getI1Type(), finalVal);
+ result.finalVar = firOpBuilder.createConvert(
+ clauseLocation, firOpBuilder.getI1Type(), finalVal);
return true;
}
return false;
}
-bool ClauseProcessor::processHint(mlir::IntegerAttr &result) const {
+bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Hint>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
int64_t hintValue = *Fortran::evaluate::ToInt64(clause->v);
- result = firOpBuilder.getI64IntegerAttr(hintValue);
+ result.hintAttr = firOpBuilder.getI64IntegerAttr(hintValue);
return true;
}
return false;
}
-bool ClauseProcessor::processMergeable(mlir::UnitAttr &result) const {
- return markClauseOccurrence<omp::clause::Mergeable>(result);
+bool ClauseProcessor::processMergeable(
+ mlir::omp::MergeableClauseOps &result) const {
+ return markClauseOccurrence<omp::clause::Mergeable>(result.mergeableAttr);
}
-bool ClauseProcessor::processNowait(mlir::UnitAttr &result) const {
- return markClauseOccurrence<omp::clause::Nowait>(result);
+bool ClauseProcessor::processNowait(mlir::omp::NowaitClauseOps &result) const {
+ return markClauseOccurrence<omp::clause::Nowait>(result.nowaitAttr);
}
-bool ClauseProcessor::processNumTeams(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
+bool ClauseProcessor::processNumTeams(
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::NumTeamsClauseOps &result) const {
// TODO Get lower and upper bounds for num_teams when parser is updated to
// accept both.
if (auto *clause = findUniqueClause<omp::clause::NumTeams>()) {
// auto lowerBound = std::get<std::optional<ExprTy>>(clause->t);
auto &upperBound = std::get<ExprTy>(clause->t);
- result = fir::getBase(converter.genExprValue(upperBound, stmtCtx));
+ result.numTeamsUpperVar =
+ fir::getBase(converter.genExprValue(upperBound, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processNumThreads(
- Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::NumThreadsClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
// OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
- result = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+ result.numThreadsVar =
+ fir::getBase(converter.genExprValue(clause->v, stmtCtx));
return true;
}
return false;
}
-bool ClauseProcessor::processOrdered(mlir::IntegerAttr &result) const {
+bool ClauseProcessor::processOrdered(
+ mlir::omp::OrderedClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Ordered>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
int64_t orderedClauseValue = 0l;
if (clause->v.has_value())
orderedClauseValue = *Fortran::evaluate::ToInt64(*clause->v);
- result = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
+ result.orderedAttr = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
return true;
}
return false;
}
-bool ClauseProcessor::processPriority(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
+bool ClauseProcessor::processPriority(
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::PriorityClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Priority>()) {
- result = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+ result.priorityVar =
+ fir::getBase(converter.genExprValue(clause->v, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processProcBind(
- mlir::omp::ClauseProcBindKindAttr &result) const {
+ mlir::omp::ProcBindClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::ProcBind>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- result = genProcBindKindAttr(firOpBuilder, *clause);
+ result.procBindKindAttr = genProcBindKindAttr(firOpBuilder, *clause);
return true;
}
return false;
}
-bool ClauseProcessor::processSafelen(mlir::IntegerAttr &result) const {
+bool ClauseProcessor::processSafelen(
+ mlir::omp::SafelenClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Safelen>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
const std::optional<std::int64_t> safelenVal =
Fortran::evaluate::ToInt64(clause->v);
- result = firOpBuilder.getI64IntegerAttr(*safelenVal);
+ result.safelenAttr = firOpBuilder.getI64IntegerAttr(*safelenVal);
return true;
}
return false;
}
bool ClauseProcessor::processSchedule(
- mlir::omp::ClauseScheduleKindAttr &valAttr,
- mlir::omp::ScheduleModifierAttr &modifierAttr,
- mlir::UnitAttr &simdModifierAttr) const {
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::ScheduleClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Schedule>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::MLIRContext *context = firOpBuilder.getContext();
@@ -451,53 +455,51 @@ bool ClauseProcessor::processSchedule(
break;
}
- mlir::omp::ScheduleModifier scheduleModifier = getScheduleModifier(*clause);
+ result.scheduleValAttr =
+ mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);
+ mlir::omp::ScheduleModifier scheduleModifier = getScheduleModifier(*clause);
if (scheduleModifier != mlir::omp::ScheduleModifier::none)
- modifierAttr =
+ result.scheduleModAttr =
mlir::omp::ScheduleModifierAttr::get(context, scheduleModifier);
if (getSimdModifier(*clause) != mlir::omp::ScheduleModifier::none)
- simdModifierAttr = firOpBuilder.getUnitAttr();
+ result.scheduleSimdAttr = firOpBuilder.getUnitAttr();
- valAttr = mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);
- return true;
- }
- return false;
-}
-
-bool ClauseProcessor::processScheduleChunk(
- Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
- if (auto *clause = findUniqueClause<omp::clause::Schedule>()) {
if (const auto &chunkExpr = std::get<omp::MaybeExpr>(clause->t))
- result = fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx));
+ result.scheduleChunkVar =
+ fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx));
+
return true;
}
return false;
}
-bool ClauseProcessor::processSimdlen(mlir::IntegerAttr &result) const {
+bool ClauseProcessor::processSimdlen(
+ mlir::omp::SimdlenClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Simdlen>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
const std::optional<std::int64_t> simdlenVal =
Fortran::evaluate::ToInt64(clause->v);
- result = firOpBuilder.getI64IntegerAttr(*simdlenVal);
+ result.simdlenAttr = firOpBuilder.getI64IntegerAttr(*simdlenVal);
return true;
}
return false;
}
bool ClauseProcessor::processThreadLimit(
- Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::ThreadLimitClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::ThreadLimit>()) {
- result = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+ result.threadLimitVar =
+ fir::getBase(converter.genExprValue(clause->v, stmtCtx));
return true;
}
return false;
}
-bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const {
- return markClauseOccurrence<omp::clause::Untied>(result);
+bool ClauseProcessor::processUntied(mlir::omp::UntiedClauseOps &result) const {
+ return markClauseOccurrence<omp::clause::Untied>(result.untiedAttr);
}
//===----------------------------------------------------------------------===//
@@ -505,13 +507,12 @@ bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const {
//===----------------------------------------------------------------------===//
bool ClauseProcessor::processAllocate(
- llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
- llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const {
+ mlir::omp::AllocateClauseOps &result) const {
return findRepeatableClause<omp::clause::Allocate>(
[&](const omp::clause::Allocate &clause,
const Fortran::parser::CharBlock &) {
- genAllocateClause(converter, clause, allocatorOperands,
- allocateOperands);
+ genAllocateClause(converter, clause, result.allocatorVars,
+ result.allocateVars);
});
}
@@ -660,10 +661,9 @@ createCopyFunc(mlir::Location loc, Fortran::lower::AbstractConverter &converter,
return funcOp;
}
-bool ClauseProcessor::processCopyPrivate(
+bool ClauseProcessor::processCopyprivate(
mlir::Location currentLocation,
- llvm::SmallVectorImpl<mlir::Value> &copyPrivateVars,
- llvm::SmallVectorImpl<mlir::Attribute> &copyPrivateFuncs) const {
+ mlir::omp::CopyprivateClauseOps &result) const {
auto addCopyPrivateVar = [&](Fortran::semantics::Symbol *sym) {
mlir::Value symVal = converter.getSymbolAddress(*sym);
auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>();
@@ -690,10 +690,10 @@ bool ClauseProcessor::processCopyPrivate(
cpVar = alloca;
}
- copyPrivateVars.push_back(cpVar);
+ result.copyprivateVars.push_back(cpVar);
mlir::func::FuncOp funcOp =
createCopyFunc(currentLocation, converter, cpVar.getType(), attrs);
- copyPrivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp));
+ result.copyprivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp));
};
bool hasCopyPrivate = findRepeatableClause<clause::Copyprivate>(
@@ -714,9 +714,7 @@ bool ClauseProcessor::processCopyPrivate(
return hasCopyPrivate;
}
-bool ClauseProcessor::processDepend(
- llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
- llvm::SmallVectorImpl<mlir::Value> &dependOperands) const {
+bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
return findRepeatableClause<omp::clause::Depend>(
@@ -731,7 +729,7 @@ bool ClauseProcessor::processDepend(
mlir::omp::ClauseTaskDependAttr dependTypeOperand =
genDependKindAttr(firOpBuilder, kind);
- dependTypeOperands.append(objects.size(), dependTypeOperand);
+ result.dependTypeAttrs.append(objects.size(), dependTypeOperand);
for (const omp::Object &object : objects) {
assert(object.ref() && "Expecting designator");
@@ -746,14 +744,14 @@ bool ClauseProcessor::processDepend(
Fortran::semantics::Symbol *sym = object.id();
const mlir::Value variable = converter.getSymbolAddress(*sym);
- dependOperands.push_back(variable);
+ result.dependVars.push_back(variable);
}
});
}
bool ClauseProcessor::processIf(
omp::clause::If::DirectiveNameModifier directiveName,
- mlir::Value &result) const {
+ mlir::omp::IfClauseOps &result) const {
bool found = false;
findRepeatableClause<omp::clause::If>(
[&](const omp::clause::If &clause,
@@ -764,7 +762,7 @@ bool ClauseProcessor::processIf(
// Assume that, at most, a single 'if' clause will be applicable to the
// given directive.
if (operand) {
- result = operand;
+ result.ifVar = operand;
found = true;
}
});
@@ -807,12 +805,10 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
bool ClauseProcessor::processMap(
mlir::Location currentLocation, const llvm::omp::Directive &directive,
- Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &mapOperands,
- llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
+ Fortran::lower::StatementContext &stmtCtx, mlir::omp::MapClauseOps &result,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
- const {
+ llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
return findRepeatableClause<omp::clause::Map>(
[&](const omp::clause::Map &clause,
@@ -887,25 +883,23 @@ bool ClauseProcessor::processMap(
mapTypeBits),
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
- mapOperands.push_back(mapOp);
- if (mapSymTypes)
- mapSymTypes->push_back(symAddr.getType());
+ result.mapVars.push_back(mapOp);
+
+ if (mapSyms)
+ mapSyms->push_back(object.id());
if (mapSymLocs)
mapSymLocs->push_back(symAddr.getLoc());
-
- if (mapSymbols)
- mapSymbols->push_back(object.id());
+ if (mapSymTypes)
+ mapSymTypes->push_back(symAddr.getType());
}
});
}
bool ClauseProcessor::processReduction(
- mlir::Location currentLocation,
- llvm::SmallVectorImpl<mlir::Value> &outReductionVars,
- llvm::SmallVectorImpl<mlir::Type> &outReductionTypes,
- llvm::SmallVectorImpl<mlir::Attribute> &outReductionDeclSymbols,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- *outReductionSymbols) const {
+ mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
+ llvm::SmallVectorImpl<mlir::Type> *outReductionTypes,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *outReductionSyms)
+ const {
return findRepeatableClause<omp::clause::Reduction>(
[&](const omp::clause::Reduction &clause,
const Fortran::parser::CharBlock &) {
@@ -915,30 +909,31 @@ bool ClauseProcessor::processReduction(
// whether to do the reduction byref.
llvm::SmallVector<mlir::Value> reductionVars;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
- llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
rp.addDeclareReduction(currentLocation, converter, clause,
reductionVars, reductionDeclSymbols,
- outReductionSymbols ? &reductionSymbols
- : nullptr);
+ outReductionSyms ? &reductionSyms : nullptr);
// Copy local lists into the output.
- llvm::copy(reductionVars, std::back_inserter(outReductionVars));
+ llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
llvm::copy(reductionDeclSymbols,
- std::back_inserter(outReductionDeclSymbols));
- if (outReductionSymbols)
- llvm::copy(reductionSymbols,
- std::back_inserter(*outReductionSymbols));
-
- outReductionTypes.reserve(outReductionTypes.size() +
- reductionVars.size());
- llvm::transform(reductionVars, std::back_inserter(outReductionTypes),
- [](mlir::Value v) { return v.getType(); });
+ std::back_inserter(result.reductionDeclSymbols));
+
+ if (outReductionTypes) {
+ outReductionTypes->reserve(outReductionTypes->size() +
+ reductionVars.size());
+ llvm::transform(reductionVars, std::back_inserter(*outReductionTypes),
+ [](mlir::Value v) { return v.getType(); });
+ }
+
+ if (outReductionSyms)
+ llvm::copy(reductionSyms, std::back_inserter(*outReductionSyms));
});
}
bool ClauseProcessor::processSectionsReduction(
- mlir::Location currentLocation) const {
+ mlir::Location currentLocation, mlir::omp::ReductionClauseOps &) const {
return findRepeatableClause<omp::clause::Reduction>(
[&](const omp::clause::Reduction &, const Fortran::parser::CharBlock &) {
TODO(currentLocation, "OMPC_Reduction");
@@ -967,30 +962,30 @@ bool ClauseProcessor::processEnter(
}
bool ClauseProcessor::processUseDeviceAddr(
- llvm::SmallVectorImpl<mlir::Value> &operands,
+ mlir::omp::UseDeviceClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms)
const {
return findRepeatableClause<omp::clause::UseDeviceAddr>(
[&](const omp::clause::UseDeviceAddr &clause,
const Fortran::parser::CharBlock &) {
- addUseDeviceClause(converter, clause.v, operands, useDeviceTypes,
- useDeviceLocs, useDeviceSymbols);
+ addUseDeviceClause(converter, clause.v, result.useDeviceAddrVars,
+ useDeviceTypes, useDeviceLocs, useDeviceSyms);
});
}
bool ClauseProcessor::processUseDevicePtr(
- llvm::SmallVectorImpl<mlir::Value> &operands,
+ mlir::omp::UseDeviceClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSymbols)
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms)
const {
return findRepeatableClause<omp::clause::UseDevicePtr>(
[&](const omp::clause::UseDevicePtr &clause,
const Fortran::parser::CharBlock &) {
- addUseDeviceClause(converter, clause.v, operands, useDeviceTypes,
- useDeviceLocs, useDeviceSymbols);
+ addUseDeviceClause(converter, clause.v, result.useDevicePtrVars,
+ useDeviceTypes, useDeviceLocs, useDeviceSyms);
});
}
} // namespace omp
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index c0c603feb296af..d933e0a913d2bc 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -37,7 +37,7 @@ namespace omp {
/// corresponding clause if it is present in the clause list. Otherwise, they
/// will return `false` to signal that the clause was not found.
///
-/// The intended use is of this class is to move clause processing outside of
+/// The intended use of this class is to move clause processing outside of
/// construct processing, since the same clauses can appear attached to
/// different constructs and constructs can be combined, so that code
/// duplication is minimized.
@@ -56,94 +56,83 @@ class ClauseProcessor {
// 'Unique' clauses: They can appear at most once in the clause list.
bool processCollapse(
mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
- llvm::SmallVectorImpl<mlir::Value> &lowerBound,
- llvm::SmallVectorImpl<mlir::Value> &upperBound,
- llvm::SmallVectorImpl<mlir::Value> &step,
+ mlir::omp::CollapseClauseOps &result,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const;
bool processDefault() const;
bool processDevice(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processDeviceType(mlir::omp::DeclareTargetDeviceType &result) const;
+ mlir::omp::DeviceClauseOps &result) const;
+ bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
bool processFinal(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processHint(mlir::IntegerAttr &result) const;
- bool processMergeable(mlir::UnitAttr &result) const;
- bool processNowait(mlir::UnitAttr &result) const;
+ mlir::omp::FinalClauseOps &result) const;
+ bool processHint(mlir::omp::HintClauseOps &result) const;
+ bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
+ bool processNowait(mlir::omp::NowaitClauseOps &result) const;
bool processNumTeams(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
+ mlir::omp::NumTeamsClauseOps &result) const;
bool processNumThreads(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processOrdered(mlir::IntegerAttr &result) const;
+ mlir::omp::NumThreadsClauseOps &result) const;
+ bool processOrdered(mlir::omp::OrderedClauseOps &result) const;
bool processPriority(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processProcBind(mlir::omp::ClauseProcBindKindAttr &result) const;
- bool processSafelen(mlir::IntegerAttr &result) const;
- bool processSchedule(mlir::omp::ClauseScheduleKindAttr &valAttr,
- mlir::omp::ScheduleModifierAttr &modifierAttr,
- mlir::UnitAttr &simdModifierAttr) const;
- bool processScheduleChunk(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processSimdlen(mlir::IntegerAttr &result) const;
+ mlir::omp::PriorityClauseOps &result) const;
+ bool processProcBind(mlir::omp::ProcBindClauseOps &result) const;
+ bool processSafelen(mlir::omp::SafelenClauseOps &result) const;
+ bool processSchedule(Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::ScheduleClauseOps &result) const;
+ bool processSimdlen(mlir::omp::SimdlenClauseOps &result) const;
bool processThreadLimit(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const;
- bool processUntied(mlir::UnitAttr &result) const;
+ mlir::omp::ThreadLimitClauseOps &result) const;
+ bool processUntied(mlir::omp::UntiedClauseOps &result) const;
// 'Repeatable' clauses: They can appear multiple times in the clause list.
- bool
- processAllocate(llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
- llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const;
+ bool processAllocate(mlir::omp::AllocateClauseOps &result) const;
bool processCopyin() const;
- bool processCopyPrivate(
- mlir::Location currentLocation,
- llvm::SmallVectorImpl<mlir::Value> &copyPrivateVars,
- llvm::SmallVectorImpl<mlir::Attribute> &copyPrivateFuncs) const;
- bool processDepend(llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
- llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
+ bool processCopyprivate(mlir::Location currentLocation,
+ mlir::omp::CopyprivateClauseOps &result) const;
+ bool processDepend(mlir::omp::DependClauseOps &result) const;
bool
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
- mlir::Value &result) const;
+ mlir::omp::IfClauseOps &result) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
// This method is used to process a map clause.
- // The optional parameters - mapSymTypes, mapSymLocs & mapSymbols are used to
+ // The optional parameters - mapSymTypes, mapSymLocs & mapSyms are used to
// store the original type, location and Fortran symbol for the map operands.
// They may be used later on to create the block_arguments for some of the
// target directives that require it.
- bool processMap(mlir::Location currentLocation,
- const llvm::omp::Directive &directive,
- Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &mapOperands,
- llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr,
- llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- *mapSymbols = nullptr) const;
- bool
- processReduction(mlir::Location currentLocation,
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
- llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- *reductionSymbols = nullptr) const;
- bool processSectionsReduction(mlir::Location currentLocation) const;
+ bool processMap(
+ mlir::Location currentLocation, const llvm::omp::Directive &directive,
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::MapClauseOps &result,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms = // optional outputs, ordered syms/locs/types like genBodyOfTargetOp's params
+ nullptr,
+ llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
+ llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
+ bool processReduction( // reductionTypes/reductionSyms are optional outputs (may be null)
+ mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
+ llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSyms =
+ nullptr) const;
+ bool processSectionsReduction(mlir::Location currentLocation,
+ mlir::omp::ReductionClauseOps &result) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool
- processUseDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands,
+ processUseDeviceAddr(mlir::omp::UseDeviceClauseOps &result, // same operand struct as processUseDevicePtr
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- &useDeviceSymbols) const;
+ &useDeviceSyms) const;
bool
- processUseDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands,
+ processUseDevicePtr(mlir::omp::UseDeviceClauseOps &result, // same operand struct as processUseDeviceAddr
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- &useDeviceSymbols) const;
+ &useDeviceSyms) const;
template <typename T>
bool processMotionClauses(Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &mapOperands);
+ mlir::omp::MapClauseOps &result);
// Call this method for these clauses that should be supported but are not
// implemented yet. It triggers a compilation error if any of the given
@@ -185,7 +174,7 @@ class ClauseProcessor {
template <typename T>
bool ClauseProcessor::processMotionClauses(
Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
+ mlir::omp::MapClauseOps &result) { // each generated map op is appended to result.mapVars
return findRepeatableClause<T>(
[&](const T &clause, const Fortran::parser::CharBlock &source) {
mlir::Location clauseLocation = converter.genLocation(source);
@@ -227,7 +216,7 @@ bool ClauseProcessor::processMotionClauses(
mapTypeBits),
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
- mapOperands.push_back(mapOp);
+ result.mapVars.push_back(mapOp);
}
});
}
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index e114ab9f4548ab..5a42e6a6aa4175 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -23,11 +23,13 @@ namespace Fortran {
namespace lower {
namespace omp {
-void DataSharingProcessor::processStep1() {
+void DataSharingProcessor::processStep1( // outputs may be null; forwarded unchanged to privatize/defaultPrivatize
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms) {
collectSymbolsForPrivatization();
collectDefaultSymbols();
- privatize();
- defaultPrivatize();
+ privatize(clauseOps, privateSyms);
+ defaultPrivatize(clauseOps, privateSyms);
insertBarrier();
}
@@ -299,14 +301,16 @@ void DataSharingProcessor::collectDefaultSymbols() {
}
}
-void DataSharingProcessor::privatize() {
+void DataSharingProcessor::privatize( // privatizes explicitly listed symbols
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms) {
for (const Fortran::semantics::Symbol *sym : privatizedSymbols) {
if (const auto *commonDet =
sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
for (const auto &mem : commonDet->objects())
- doPrivatize(&*mem);
+ doPrivatize(&*mem, clauseOps, privateSyms); // each common-block member separately
} else
- doPrivatize(sym);
+ doPrivatize(sym, clauseOps, privateSyms);
}
}
@@ -323,7 +327,9 @@ void DataSharingProcessor::copyLastPrivatize(mlir::Operation *op) {
}
}
-void DataSharingProcessor::defaultPrivatize() {
+void DataSharingProcessor::defaultPrivatize( // privatizes symbols collected by collectDefaultSymbols
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms) {
for (const Fortran::semantics::Symbol *sym : defaultSymbols) {
if (!Fortran::semantics::IsProcedure(*sym) &&
!sym->GetUltimate().has<Fortran::semantics::DerivedTypeDetails>() &&
@@ -331,11 +337,14 @@ void DataSharingProcessor::defaultPrivatize() {
!symbolsInNestedRegions.contains(sym) &&
!symbolsInParentRegions.contains(sym) &&
!privatizedSymbols.contains(sym))
- doPrivatize(sym);
+ doPrivatize(sym, clauseOps, privateSyms); // explicit ones were already handled by privatize()
}
}
-void DataSharingProcessor::doPrivatize(const Fortran::semantics::Symbol *sym) {
+void DataSharingProcessor::doPrivatize(
+ const Fortran::semantics::Symbol *sym,
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms) {
if (!useDelayedPrivatization) {
cloneSymbol(sym);
copyFirstPrivateSymbol(sym);
@@ -419,10 +428,13 @@ void DataSharingProcessor::doPrivatize(const Fortran::semantics::Symbol *sym) {
return result;
}();
- delayedPrivatizationInfo.privatizers.push_back(
- mlir::SymbolRefAttr::get(privatizerOp));
- delayedPrivatizationInfo.originalAddresses.push_back(hsb.getAddr());
- delayedPrivatizationInfo.symbols.push_back(sym);
+ if (clauseOps) { // each output is recorded only if the caller supplied it
+ clauseOps->privatizers.push_back(mlir::SymbolRefAttr::get(privatizerOp));
+ clauseOps->privateVars.push_back(hsb.getAddr()); // original (outside-region) address
+ }
+
+ if (privateSyms)
+ privateSyms->push_back(sym);
}
} // namespace omp
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.h b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
index 226abe96705e35..9724b3d5ed02fe 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.h
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.h
@@ -19,28 +19,17 @@
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/symbol.h"
+namespace mlir {
+namespace omp {
+struct PrivateClauseOps; // forward-declared: the declarations below only take pointers to it
+} // namespace omp
+} // namespace mlir
+
namespace Fortran {
namespace lower {
namespace omp {
class DataSharingProcessor {
-public:
- /// Collects all the information needed for delayed privatization. This can be
- /// used by ops with data-sharing clauses to properly generate their regions
- /// (e.g. add region arguments) and map the original SSA values to their
- /// corresponding OMP region operands.
- struct DelayedPrivatizationInfo {
- // The list of symbols referring to delayed privatizer ops (i.e.
- // `omp.private` ops).
- llvm::SmallVector<mlir::SymbolRefAttr> privatizers;
- // SSA values that correspond to "original" values being privatized.
- // "Original" here means the SSA value outside the OpenMP region from which
- // a clone is created inside the region.
- llvm::SmallVector<mlir::Value> originalAddresses;
- // Fortran symbols corresponding to the above SSA values.
- llvm::SmallVector<const Fortran::semantics::Symbol *> symbols;
- };
-
private:
bool hasLastPrivateOp;
mlir::OpBuilder::InsertPoint lastPrivIP;
@@ -57,7 +46,6 @@ class DataSharingProcessor {
Fortran::lower::pft::Evaluation &eval;
bool useDelayedPrivatization;
Fortran::lower::SymMap *symTable;
- DelayedPrivatizationInfo delayedPrivatizationInfo;
bool needBarrier();
void collectSymbols(Fortran::semantics::Symbol::Flag flag);
@@ -67,9 +55,16 @@ class DataSharingProcessor {
void collectSymbolsForPrivatization();
void insertBarrier();
void collectDefaultSymbols();
- void privatize();
- void defaultPrivatize();
- void doPrivatize(const Fortran::semantics::Symbol *sym);
+ void privatize(
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms);
+ void defaultPrivatize(
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms);
+ void doPrivatize( // appends to clauseOps/privateSyms only when they are non-null
+ const Fortran::semantics::Symbol *sym,
+ mlir::omp::PrivateClauseOps *clauseOps,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *privateSyms);
void copyLastPrivatize(mlir::Operation *op);
void insertLastPrivateCompare(mlir::Operation *op);
void cloneSymbol(const Fortran::semantics::Symbol *sym);
@@ -103,17 +98,15 @@ class DataSharingProcessor {
// Step2 performs the copying for lastprivates and requires knowledge of the
// MLIR operation to insert the last private update. Step2 adds
// dealocation code as well.
- void processStep1();
+ void processStep1(mlir::omp::PrivateClauseOps *clauseOps = nullptr, // optional outputs; passed only when delayed privatization is used
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+ *privateSyms = nullptr);
void processStep2(mlir::Operation *op, bool isLoop);
void setLoopIV(mlir::Value iv) {
assert(!loopIV && "Loop iteration variable already set");
loopIV = iv;
}
-
- const DelayedPrivatizationInfo &getDelayedPrivatizationInfo() const {
- return delayedPrivatizationInfo;
- }
};
} // namespace omp
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 0cf2a8f97040a8..d67060d1cce72b 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -523,19 +523,25 @@ genMasterOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation) {
return genOpWithBody<mlir::omp::MasterOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
- .setGenNested(genNested),
- /*resultTypes=*/mlir::TypeRange());
+ .setGenNested(genNested)); // MasterOp is now built without explicit result types
}
static mlir::omp::OrderedRegionOp
genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
- mlir::Location currentLocation) {
+ mlir::Location currentLocation,
+ const Fortran::parser::OmpClauseList &clauseList) {
+ mlir::omp::OrderedRegionClauseOps clauseOps; // no clause is processed into this yet
+
+ ClauseProcessor cp(converter, semaCtx, clauseList);
+ cp.processTODO<clause::Simd>(currentLocation, // 'ordered simd' is not supported yet
+ llvm::omp::Directive::OMPD_ordered);
+
return genOpWithBody<mlir::omp::OrderedRegionOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested),
- /*simd=*/false);
+ clauseOps);
}
static mlir::omp::ParallelOp
@@ -546,77 +552,62 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList,
bool outerCombined = false) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); // hoisted; reused by both region-entry callbacks
Fortran::lower::StatementContext stmtCtx;
- mlir::Value ifClauseOperand, numThreadsClauseOperand;
- mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
- llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
- reductionVars;
+ mlir::omp::ParallelClauseOps clauseOps;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> privateSyms;
llvm::SmallVector<mlir::Type> reductionTypes;
- llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
- llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_parallel, ifClauseOperand);
- cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
- cp.processProcBind(procBindKindAttr);
+ cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps);
+ cp.processNumThreads(stmtCtx, clauseOps);
+ cp.processProcBind(clauseOps);
cp.processDefault();
- cp.processAllocate(allocatorOperands, allocateOperands);
+ cp.processAllocate(clauseOps);
+
if (!outerCombined)
- cp.processReduction(currentLocation, reductionVars, reductionTypes,
- reductionDeclSymbols, &reductionSymbols);
+ cp.processReduction(currentLocation, clauseOps, &reductionTypes,
+ &reductionSyms);
+
+ if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars)) // must follow processReduction
+ clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr();
auto reductionCallback = [&](mlir::Operation *op) {
- llvm::SmallVector<mlir::Location> locs(reductionVars.size(),
+ llvm::SmallVector<mlir::Location> locs(clauseOps.reductionVars.size(),
currentLocation);
- auto *block = converter.getFirOpBuilder().createBlock(&op->getRegion(0), {},
- reductionTypes, locs);
+ auto *block =
+ firOpBuilder.createBlock(&op->getRegion(0), {}, reductionTypes, locs);
for (auto [arg, prv] :
- llvm::zip_equal(reductionSymbols, block->getArguments())) {
+ llvm::zip_equal(reductionSyms, block->getArguments())) {
converter.bindSymbol(*arg, prv);
}
- return reductionSymbols;
+ return reductionSyms;
};
- mlir::UnitAttr byrefAttr;
- if (ReductionProcessor::doReductionByRef(reductionVars))
- byrefAttr = converter.getFirOpBuilder().getUnitAttr();
-
OpWithBodyGenInfo genInfo =
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested)
.setOuterCombined(outerCombined)
.setClauses(&clauseList)
- .setReductions(&reductionSymbols, &reductionTypes)
+ .setReductions(&reductionSyms, &reductionTypes)
.setGenRegionEntryCb(reductionCallback);
- if (!enableDelayedPrivatization) {
- return genOpWithBody<mlir::omp::ParallelOp>(
- genInfo,
- /*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
- numThreadsClauseOperand, allocateOperands, allocatorOperands,
- reductionVars,
- reductionDeclSymbols.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- reductionDeclSymbols),
- procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
- /*privatizers=*/nullptr, byrefAttr);
- }
+ if (!enableDelayedPrivatization)
+ return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps);
bool privatize = !outerCombined;
DataSharingProcessor dsp(converter, semaCtx, clauseList, eval,
/*useDelayedPrivatization=*/true, &symTable);
if (privatize)
- dsp.processStep1();
-
- const auto &delayedPrivatizationInfo = dsp.getDelayedPrivatizationInfo();
+ dsp.processStep1(&clauseOps, &privateSyms); // fills clauseOps.privatizers/privateVars and privateSyms
auto genRegionEntryCB = [&](mlir::Operation *op) {
auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op);
- llvm::SmallVector<mlir::Location> reductionLocs(reductionVars.size(),
- currentLocation);
+ llvm::SmallVector<mlir::Location> reductionLocs(
+ clauseOps.reductionVars.size(), currentLocation);
mlir::OperandRange privateVars = parallelOp.getPrivateVars();
mlir::Region &region = parallelOp.getRegion();
@@ -631,12 +622,12 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
llvm::transform(privateVars, std::back_inserter(privateVarLocs),
[](mlir::Value v) { return v.getLoc(); });
- converter.getFirOpBuilder().createBlock(&region, /*insertPt=*/{},
- privateVarTypes, privateVarLocs);
+ firOpBuilder.createBlock(&region, /*insertPt=*/{}, privateVarTypes,
+ privateVarLocs);
llvm::SmallVector<const Fortran::semantics::Symbol *> allSymbols =
- reductionSymbols;
- allSymbols.append(delayedPrivatizationInfo.symbols);
+ reductionSyms;
+ allSymbols.append(privateSyms); // region args are bound as: reduction syms first, then private syms
for (auto [arg, prv] : llvm::zip_equal(allSymbols, region.getArguments())) {
converter.bindSymbol(*arg, prv);
}
@@ -646,26 +637,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
// TODO Merge with the reduction CB.
genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp);
-
- llvm::SmallVector<mlir::Attribute> privatizers(
- delayedPrivatizationInfo.privatizers.begin(),
- delayedPrivatizationInfo.privatizers.end());
-
- return genOpWithBody<mlir::omp::ParallelOp>(
- genInfo,
- /*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
- numThreadsClauseOperand, allocateOperands, allocatorOperands,
- reductionVars,
- reductionDeclSymbols.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- reductionDeclSymbols),
- procBindKindAttr, delayedPrivatizationInfo.originalAddresses,
- delayedPrivatizationInfo.privatizers.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- privatizers),
- byrefAttr);
+ return genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps);
}
static mlir::omp::SectionOp
@@ -689,28 +661,21 @@ genSingleOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &beginClauseList,
const Fortran::parser::OmpClauseList &endClauseList) {
- llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
- llvm::SmallVector<mlir::Value> copyPrivateVars;
- llvm::SmallVector<mlir::Attribute> copyPrivateFuncs;
- mlir::UnitAttr nowaitAttr;
+ mlir::omp::SingleClauseOps clauseOps; // filled from both the begin and the end clause lists
ClauseProcessor cp(converter, semaCtx, beginClauseList);
- cp.processAllocate(allocatorOperands, allocateOperands);
+ cp.processAllocate(clauseOps);
+ // TODO Support delayed privatization.
ClauseProcessor ecp(converter, semaCtx, endClauseList);
- ecp.processNowait(nowaitAttr);
- ecp.processCopyPrivate(currentLocation, copyPrivateVars, copyPrivateFuncs);
+ ecp.processNowait(clauseOps); // nowait and copyprivate come from the end directive
+ ecp.processCopyprivate(currentLocation, clauseOps);
return genOpWithBody<mlir::omp::SingleOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested)
.setClauses(&beginClauseList),
- allocateOperands, allocatorOperands, copyPrivateVars,
- copyPrivateFuncs.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- copyPrivateFuncs),
- nowaitAttr);
+ clauseOps);
}
static mlir::omp::TaskOp
@@ -720,21 +685,19 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList) {
Fortran::lower::StatementContext stmtCtx;
- mlir::Value ifClauseOperand, finalClauseOperand, priorityClauseOperand;
- mlir::UnitAttr untiedAttr, mergeableAttr;
- llvm::SmallVector<mlir::Attribute> dependTypeOperands;
- llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
- dependOperands;
+ mlir::omp::TaskClauseOps clauseOps; // in_reduction operands are left empty (see processTODO)
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_task, ifClauseOperand);
- cp.processAllocate(allocatorOperands, allocateOperands);
+ cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
+ cp.processAllocate(clauseOps);
cp.processDefault();
- cp.processFinal(stmtCtx, finalClauseOperand);
- cp.processUntied(untiedAttr);
- cp.processMergeable(mergeableAttr);
- cp.processPriority(stmtCtx, priorityClauseOperand);
- cp.processDepend(dependTypeOperands, dependOperands);
+ cp.processFinal(stmtCtx, clauseOps);
+ cp.processUntied(clauseOps);
+ cp.processMergeable(clauseOps);
+ cp.processPriority(stmtCtx, clauseOps);
+ cp.processDepend(clauseOps);
+ // TODO Support delayed privatization.
+
cp.processTODO<clause::InReduction, clause::Detach, clause::Affinity>(
currentLocation, llvm::omp::Directive::OMPD_task);
@@ -742,14 +705,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested)
.setClauses(&clauseList),
- ifClauseOperand, finalClauseOperand, untiedAttr, mergeableAttr,
- /*in_reduction_vars=*/mlir::ValueRange(),
- /*in_reductions=*/nullptr, priorityClauseOperand,
- dependTypeOperands.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- dependTypeOperands),
- dependOperands, allocateOperands, allocatorOperands);
+ clauseOps);
}
static mlir::omp::TaskgroupOp
@@ -758,17 +714,18 @@ genTaskgroupOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval, bool genNested,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList) {
- llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
+ mlir::omp::TaskgroupClauseOps clauseOps; // only allocate is processed; task_reduction is a TODO
+
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processAllocate(allocatorOperands, allocateOperands);
+ cp.processAllocate(clauseOps);
cp.processTODO<clause::TaskReduction>(currentLocation,
llvm::omp::Directive::OMPD_taskgroup);
+
return genOpWithBody<mlir::omp::TaskgroupOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(genNested)
.setClauses(&clauseList),
- /*task_reduction_vars=*/mlir::ValueRange(),
- /*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
+ clauseOps);
}
// This helper function implements the functionality of "promoting"
@@ -789,8 +746,7 @@ genTaskgroupOp(Fortran::lower::AbstractConverter &converter,
// clause. Support for such list items in a use_device_ptr clause
// is deprecated."
static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(
- llvm::SmallVectorImpl<mlir::Value> &devicePtrOperands,
- llvm::SmallVectorImpl<mlir::Value> &deviceAddrOperands,
+ mlir::omp::UseDeviceClauseOps &clauseOps, // holds both the use_device_ptr and use_device_addr vars
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
@@ -803,9 +759,10 @@ static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(
// Iterate over our use_device_ptr list and shift all non-cptr arguments into
// use_device_addr.
- for (auto *it = devicePtrOperands.begin(); it != devicePtrOperands.end();) {
+ for (auto *it = clauseOps.useDevicePtrVars.begin(); // `it` advances only when nothing was erased
+ it != clauseOps.useDevicePtrVars.end();) {
if (!fir::isa_builtin_cptr_type(fir::unwrapRefType(it->getType()))) {
- deviceAddrOperands.push_back(*it);
+ clauseOps.useDeviceAddrVars.push_back(*it);
// We have to shuffle the symbols around as well, to maintain
// the correct Input -> BlockArg for use_device_ptr/use_device_addr.
// NOTE: However, as map's do not seem to be included currently
@@ -813,11 +770,11 @@ static void promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(
// future alterations. I believe the reason they are not currently
// is that the BlockArg assign/lowering needs to be extended
// to a greater set of types.
- auto idx = std::distance(devicePtrOperands.begin(), it);
+ auto idx = std::distance(clauseOps.useDevicePtrVars.begin(), it);
moveElementToBack(idx, useDeviceTypes);
moveElementToBack(idx, useDeviceLocs);
moveElementToBack(idx, useDeviceSymbols);
- it = devicePtrOperands.erase(it);
+ it = clauseOps.useDevicePtrVars.erase(it);
continue;
}
++it;
@@ -831,20 +788,19 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter,
mlir::Location currentLocation,
const Fortran::parser::OmpClauseList &clauseList) {
Fortran::lower::StatementContext stmtCtx;
- mlir::Value ifClauseOperand, deviceOperand;
- llvm::SmallVector<mlir::Value> mapOperands, devicePtrOperands,
- deviceAddrOperands;
+ mlir::omp::TargetDataClauseOps clauseOps; // if/device/use_device_*/map operands for the op
llvm::SmallVector<mlir::Type> useDeviceTypes;
llvm::SmallVector<mlir::Location> useDeviceLocs;
- llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSyms;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_target_data, ifClauseOperand);
- cp.processDevice(stmtCtx, deviceOperand);
- cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
- useDeviceSymbols);
- cp.processUseDeviceAddr(deviceAddrOperands, useDeviceTypes, useDeviceLocs,
- useDeviceSymbols);
+ cp.processIf(llvm::omp::Directive::OMPD_target_data, clauseOps);
+ cp.processDevice(stmtCtx, clauseOps);
+ cp.processUseDevicePtr(clauseOps, useDeviceTypes, useDeviceLocs, // ptr clause is processed before addr
+ useDeviceSyms);
+ cp.processUseDeviceAddr(clauseOps, useDeviceTypes, useDeviceLocs,
+ useDeviceSyms);
+
// This function implements the deprecated functionality of use_device_ptr
// that allows users to provide non-CPTR arguments to it with the caveat
// that the compiler will treat them as use_device_addr. A lot of legacy
@@ -856,17 +812,16 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter,
// ordering.
// TODO: Perhaps create a user provideable compiler option that will
// re-introduce a hard-error rather than a warning in these cases.
- promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(
- devicePtrOperands, deviceAddrOperands, useDeviceTypes, useDeviceLocs,
- useDeviceSymbols);
+ promoteNonCPtrUseDevicePtrArgsToUseDeviceAddr(clauseOps, useDeviceTypes,
+ useDeviceLocs, useDeviceSyms);
cp.processMap(currentLocation, llvm::omp::Directive::OMPD_target_data,
- stmtCtx, mapOperands);
+ stmtCtx, clauseOps);
auto dataOp = converter.getFirOpBuilder().create<mlir::omp::TargetDataOp>(
- currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
- deviceAddrOperands, mapOperands);
+ currentLocation, clauseOps);
+
genBodyOfTargetDataOp(converter, semaCtx, eval, genNested, dataOp,
- useDeviceTypes, useDeviceLocs, useDeviceSymbols,
+ useDeviceTypes, useDeviceLocs, useDeviceSyms,
currentLocation);
return dataOp;
}
@@ -879,10 +834,7 @@ static OpTy genTargetEnterExitDataUpdateOp(
const Fortran::parser::OmpClauseList &clauseList) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Fortran::lower::StatementContext stmtCtx;
- mlir::Value ifClauseOperand, deviceOperand;
- mlir::UnitAttr nowaitAttr;
- llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
- llvm::SmallVector<mlir::Attribute> dependTypeOperands;
+ mlir::omp::TargetEnterExitUpdateDataClauseOps clauseOps; // one operand struct for enter data, exit data and update
// GCC 9.3.0 emits a (probably) bogus warning about an unused variable.
[[maybe_unused]] llvm::omp::Directive directive;
@@ -897,25 +849,19 @@ static OpTy genTargetEnterExitDataUpdateOp(
}
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(directive, ifClauseOperand);
- cp.processDevice(stmtCtx, deviceOperand);
- cp.processDepend(dependTypeOperands, dependOperands);
- cp.processNowait(nowaitAttr);
+ cp.processIf(directive, clauseOps);
+ cp.processDevice(stmtCtx, clauseOps);
+ cp.processDepend(clauseOps);
+ cp.processNowait(clauseOps);
if constexpr (std::is_same_v<OpTy, mlir::omp::TargetUpdateOp>) {
- cp.processMotionClauses<clause::To>(stmtCtx, mapOperands);
- cp.processMotionClauses<clause::From>(stmtCtx, mapOperands);
+ cp.processMotionClauses<clause::To>(stmtCtx, clauseOps); // to and from both append to clauseOps.mapVars
+ cp.processMotionClauses<clause::From>(stmtCtx, clauseOps);
} else {
- cp.processMap(currentLocation, directive, stmtCtx, mapOperands);
+ cp.processMap(currentLocation, directive, stmtCtx, clauseOps);
}
- return firOpBuilder.create<OpTy>(
- currentLocation, ifClauseOperand, deviceOperand,
- dependTypeOperands.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- dependTypeOperands),
- dependOperands, nowaitAttr, mapOperands);
+ return firOpBuilder.create<OpTy>(currentLocation, clauseOps);
}
// This functions creates a block for the body of the targetOp's region. It adds
@@ -925,9 +871,9 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval, bool genNested,
mlir::omp::TargetOp &targetOp,
- llvm::ArrayRef<mlir::Type> mapSymTypes,
+ llvm::ArrayRef<const Fortran::semantics::Symbol *> mapSyms, // same syms/locs/types order as processMap's optional outputs
llvm::ArrayRef<mlir::Location> mapSymLocs,
- llvm::ArrayRef<const Fortran::semantics::Symbol *> mapSymbols,
+ llvm::ArrayRef<mlir::Type> mapSymTypes,
const mlir::Location &currentLocation) {
assert(mapSymTypes.size() == mapSymLocs.size());
@@ -956,7 +902,7 @@ genBodyOfTargetOp(Fortran::lower::AbstractConverter &converter,
};
// Bind the symbols to their corresponding block arguments.
- for (auto [argIndex, argSymbol] : llvm::enumerate(mapSymbols)) {
+ for (auto [argIndex, argSymbol] : llvm::enumerate(mapSyms)) {
const mlir::BlockArgument &arg = region.getArgument(argIndex);
// Avoid capture of a reference to a structured binding.
const Fortran::semantics::Symbol *sym = argSymbol;
@@ -1080,22 +1026,20 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpClauseList &clauseList,
llvm::omp::Directive directive, bool outerCombined = false) {
Fortran::lower::StatementContext stmtCtx;
- mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand;
- mlir::UnitAttr nowaitAttr;
- llvm::SmallVector<mlir::Attribute> dependTypeOperands;
- llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
- llvm::SmallVector<mlir::Type> mapSymTypes;
+ mlir::omp::TargetClauseOps clauseOps;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> mapSyms;
llvm::SmallVector<mlir::Location> mapSymLocs;
- llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
+ llvm::SmallVector<mlir::Type> mapSymTypes;
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_target, ifClauseOperand);
- cp.processDevice(stmtCtx, deviceOperand);
- cp.processThreadLimit(stmtCtx, threadLimitOperand);
- cp.processDepend(dependTypeOperands, dependOperands);
- cp.processNowait(nowaitAttr);
- cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymTypes,
- &mapSymLocs, &mapSymbols);
+ cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
+ cp.processDevice(stmtCtx, clauseOps);
+ cp.processThreadLimit(stmtCtx, clauseOps);
+ cp.processDepend(clauseOps);
+ cp.processNowait(clauseOps);
+ cp.processMap(currentLocation, directive, stmtCtx, clauseOps, &mapSyms, // explicit maps; implicit captures are appended later
+ &mapSymLocs, &mapSymTypes);
+ // TODO Support delayed privatization.
cp.processTODO<clause::Private, clause::Firstprivate, clause::IsDevicePtr,
clause::HasDeviceAddr, clause::Reduction, clause::InReduction,
@@ -1107,7 +1051,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
// symbols used inside the region that have not been explicitly mapped using
// the map clause.
auto captureImplicitMap = [&](const Fortran::semantics::Symbol &sym) {
- if (llvm::find(mapSymbols, &sym) == mapSymbols.end()) {
+ if (llvm::find(mapSyms, &sym) == mapSyms.end()) {
mlir::Value baseOp = converter.getSymbolAddress(sym);
if (!baseOp)
if (const auto *details = sym.template detailsIf<
@@ -1178,25 +1122,20 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
mapFlag),
captureKind, baseOp.getType());
- mapOperands.push_back(mapOp);
- mapSymTypes.push_back(baseOp.getType());
+ clauseOps.mapVars.push_back(mapOp); // keeps mapVars and the mapSym* vectors index-aligned
+ mapSyms.push_back(&sym);
mapSymLocs.push_back(baseOp.getLoc());
- mapSymbols.push_back(&sym);
+ mapSymTypes.push_back(baseOp.getType());
}
}
};
Fortran::lower::pft::visitAllSymbols(eval, captureImplicitMap);
auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
- currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
- dependTypeOperands.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- dependTypeOperands),
- dependOperands, nowaitAttr, mapOperands);
+ currentLocation, clauseOps);
- genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes,
- mapSymLocs, mapSymbols, currentLocation);
+ genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSyms,
+ mapSymLocs, mapSymTypes, currentLocation);
return targetOp;
}
@@ -1209,17 +1148,16 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpClauseList &clauseList,
bool outerCombined = false) {
Fortran::lower::StatementContext stmtCtx;
- mlir::Value numTeamsClauseOperand, ifClauseOperand, threadLimitClauseOperand;
- llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
- reductionVars;
- llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
+ mlir::omp::TeamsClauseOps clauseOps; // reduction operands stay empty (see processTODO)
ClauseProcessor cp(converter, semaCtx, clauseList);
- cp.processIf(llvm::omp::Directive::OMPD_teams, ifClauseOperand);
- cp.processAllocate(allocatorOperands, allocateOperands);
+ cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
+ cp.processAllocate(clauseOps);
cp.processDefault();
- cp.processNumTeams(stmtCtx, numTeamsClauseOperand);
- cp.processThreadLimit(stmtCtx, threadLimitClauseOperand);
+ cp.processNumTeams(stmtCtx, clauseOps); // NOTE(review): num_teams_lower was passed as null explicitly before; assumes it stays unset here — confirm
+ cp.processThreadLimit(stmtCtx, clauseOps);
+ // TODO Support delayed privatization.
+
cp.processTODO<clause::Reduction>(currentLocation,
llvm::omp::Directive::OMPD_teams);
@@ -1228,30 +1166,20 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
.setGenNested(genNested)
.setOuterCombined(outerCombined)
.setClauses(&clauseList),
- /*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand,
- threadLimitClauseOperand, allocateOperands, allocatorOperands,
- reductionVars,
- reductionDeclSymbols.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- reductionDeclSymbols));
+ clauseOps);
}
-/// Extract the list of function and variable symbols affected by the given
-/// 'declare target' directive and return the intended device type for them.
-static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
+/// Extract the list of function and variable symbols affected by the given
+/// 'declare target' directive; the device type is stored in `clauseOps`.
+static void getDeclareTargetInfo(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
+ mlir::omp::DeclareTargetClauseOps &clauseOps, // NOTE(review): the explicit 'any' default was removed; assumes clauseOps.deviceType defaults to DeclareTargetDeviceType::any when no device_type clause is processed — confirm
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
-
- // The default capture type
- mlir::omp::DeclareTargetDeviceType deviceType =
- mlir::omp::DeclareTargetDeviceType::any;
const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
declareTargetConstruct.t);
-
if (const auto *objectList{
Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
ObjectList objects{makeList(*objectList, semaCtx)};
@@ -1272,12 +1200,10 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
cp.processTo(symbolAndClause);
cp.processEnter(symbolAndClause);
cp.processLink(symbolAndClause);
- cp.processDeviceType(deviceType);
+ cp.processDeviceType(clauseOps);
cp.processTODO<clause::Indirect>(converter.getCurrentLocation(),
llvm::omp::Directive::OMPD_declare_target);
}
-
- return deviceType;
}
static void collectDeferredDeclareTargets(
@@ -1287,9 +1213,10 @@ static void collectDeferredDeclareTargets(
const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
llvm::SmallVectorImpl<Fortran::lower::OMPDeferredDeclareTargetInfo>
&deferredDeclareTarget) {
+ mlir::omp::DeclareTargetClauseOps clauseOps;
llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
- mlir::omp::DeclareTargetDeviceType devType = getDeclareTargetInfo(
- converter, semaCtx, eval, declareTargetConstruct, symbolAndClause);
+ getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
+ clauseOps, symbolAndClause);
// Return the device type only if at least one of the targets for the
// directive is a function or subroutine
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
@@ -1299,8 +1226,9 @@ static void collectDeferredDeclareTargets(
std::get<const Fortran::semantics::Symbol &>(symClause)));
if (!op) {
- deferredDeclareTarget.push_back(
- {std::get<0>(symClause), devType, std::get<1>(symClause)});
+ deferredDeclareTarget.push_back({std::get<0>(symClause),
+ clauseOps.deviceType,
+ std::get<1>(symClause)});
}
}
}
@@ -1312,9 +1240,10 @@ getDeclareTargetFunctionDevice(
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPDeclareTargetConstruct
&declareTargetConstruct) {
+ mlir::omp::DeclareTargetClauseOps clauseOps;
llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
- mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo(
- converter, semaCtx, eval, declareTargetConstruct, symbolAndClause);
+ getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
+ clauseOps, symbolAndClause);
// Return the device type only if at least one of the targets for the
// directive is a function or subroutine
@@ -1324,7 +1253,7 @@ getDeclareTargetFunctionDevice(
std::get<const Fortran::semantics::Symbol &>(symClause)));
if (mlir::isa_and_nonnull<mlir::func::FuncOp>(op))
- return deviceType;
+ return clauseOps.deviceType;
}
return std::nullopt;
@@ -1354,12 +1283,14 @@ genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter,
case llvm::omp::Directive::OMPD_barrier:
firOpBuilder.create<mlir::omp::BarrierOp>(currentLocation);
break;
- case llvm::omp::Directive::OMPD_taskwait:
- ClauseProcessor(converter, semaCtx, opClauseList)
- .processTODO<clause::Depend, clause::Nowait>(
- currentLocation, llvm::omp::Directive::OMPD_taskwait);
- firOpBuilder.create<mlir::omp::TaskwaitOp>(currentLocation);
+ case llvm::omp::Directive::OMPD_taskwait: {
+ mlir::omp::TaskwaitClauseOps clauseOps;
+ ClauseProcessor cp(converter, semaCtx, opClauseList);
+ cp.processTODO<clause::Depend, clause::Nowait>(
+ currentLocation, llvm::omp::Directive::OMPD_taskwait);
+ firOpBuilder.create<mlir::omp::TaskwaitOp>(currentLocation, clauseOps);
break;
+ }
case llvm::omp::Directive::OMPD_taskyield:
firOpBuilder.create<mlir::omp::TaskyieldOp>(currentLocation);
break;
@@ -1494,32 +1425,21 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
dsp.processStep1();
Fortran::lower::StatementContext stmtCtx;
- mlir::Value scheduleChunkClauseOperand, ifClauseOperand;
- llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, reductionVars;
- llvm::SmallVector<mlir::Value> alignedVars, nontemporalVars;
+ mlir::omp::SimdLoopClauseOps clauseOps;
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
- llvm::SmallVector<mlir::Type> reductionTypes;
- llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
- mlir::omp::ClauseOrderKindAttr orderClauseOperand;
- mlir::IntegerAttr simdlenClauseOperand, safelenClauseOperand;
ClauseProcessor cp(converter, semaCtx, loopOpClauseList);
- cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv);
- cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
- cp.processReduction(loc, reductionVars, reductionTypes, reductionDeclSymbols);
- cp.processIf(llvm::omp::Directive::OMPD_simd, ifClauseOperand);
- cp.processSimdlen(simdlenClauseOperand);
- cp.processSafelen(safelenClauseOperand);
+ cp.processCollapse(loc, eval, clauseOps, iv);
+ cp.processReduction(loc, clauseOps);
+ cp.processIf(llvm::omp::Directive::OMPD_simd, clauseOps);
+ cp.processSimdlen(clauseOps);
+ cp.processSafelen(clauseOps);
+ clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr();
+ // TODO Support delayed privatization.
+
cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear,
clause::Nontemporal, clause::Order>(loc, ompDirective);
- mlir::TypeRange resultType;
- auto simdLoopOp = firOpBuilder.create<mlir::omp::SimdLoopOp>(
- loc, resultType, lowerBound, upperBound, step, alignedVars,
- /*alignment_values=*/nullptr, ifClauseOperand, nontemporalVars,
- orderClauseOperand, simdlenClauseOperand, safelenClauseOperand,
- /*inclusive=*/firOpBuilder.getUnitAttr());
-
auto *nestedEval = getCollapsedLoopEval(
eval, Fortran::lower::getCollapseValue(loopOpClauseList));
@@ -1527,11 +1447,12 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
return genLoopVars(op, converter, loc, iv);
};
- createBodyOfOp<mlir::omp::SimdLoopOp>(
- simdLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
- .setClauses(&loopOpClauseList)
- .setDataSharingProcessor(&dsp)
- .setGenRegionEntryCb(ivCallback));
+ genOpWithBody<mlir::omp::SimdLoopOp>(
+ OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
+ .setClauses(&loopOpClauseList)
+ .setDataSharingProcessor(&dsp)
+ .setGenRegionEntryCb(ivCallback),
+ clauseOps);
}
static void createWsloop(Fortran::lower::AbstractConverter &converter,
@@ -1546,77 +1467,50 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter,
dsp.processStep1();
Fortran::lower::StatementContext stmtCtx;
- mlir::Value scheduleChunkClauseOperand;
- llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, reductionVars;
- llvm::SmallVector<mlir::Value> linearVars, linearStepVars;
+ mlir::omp::WsloopClauseOps clauseOps;
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
llvm::SmallVector<mlir::Type> reductionTypes;
- llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
- llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
- mlir::omp::ClauseOrderKindAttr orderClauseOperand;
- mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
- mlir::UnitAttr nowaitClauseOperand, byrefOperand, scheduleSimdClauseOperand;
- mlir::IntegerAttr orderedClauseOperand;
- mlir::omp::ScheduleModifierAttr scheduleModClauseOperand;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
ClauseProcessor cp(converter, semaCtx, beginClauseList);
- cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv);
- cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
- cp.processReduction(loc, reductionVars, reductionTypes, reductionDeclSymbols,
- &reductionSymbols);
- cp.processTODO<clause::Linear, clause::Order>(loc, ompDirective);
-
- if (ReductionProcessor::doReductionByRef(reductionVars))
- byrefOperand = firOpBuilder.getUnitAttr();
-
- auto wsLoopOp = firOpBuilder.create<mlir::omp::WsloopOp>(
- loc, lowerBound, upperBound, step, linearVars, linearStepVars,
- reductionVars,
- reductionDeclSymbols.empty()
- ? nullptr
- : mlir::ArrayAttr::get(firOpBuilder.getContext(),
- reductionDeclSymbols),
- scheduleValClauseOperand, scheduleChunkClauseOperand,
- /*schedule_modifiers=*/nullptr,
- /*simd_modifier=*/nullptr, nowaitClauseOperand, byrefOperand,
- orderedClauseOperand, orderClauseOperand,
- /*inclusive=*/firOpBuilder.getUnitAttr());
-
- // Handle attribute based clauses.
- if (cp.processOrdered(orderedClauseOperand))
- wsLoopOp.setOrderedValAttr(orderedClauseOperand);
-
- if (cp.processSchedule(scheduleValClauseOperand, scheduleModClauseOperand,
- scheduleSimdClauseOperand)) {
- wsLoopOp.setScheduleValAttr(scheduleValClauseOperand);
- wsLoopOp.setScheduleModifierAttr(scheduleModClauseOperand);
- wsLoopOp.setSimdModifierAttr(scheduleSimdClauseOperand);
- }
+ cp.processCollapse(loc, eval, clauseOps, iv);
+ cp.processSchedule(stmtCtx, clauseOps);
+ cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms);
+ cp.processOrdered(clauseOps);
+ clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr();
+ // TODO Support delayed privatization.
+
+ if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars))
+ clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr();
+
+ cp.processTODO<clause::Allocate, clause::Linear, clause::Order>(loc,
+ ompDirective);
+
// In FORTRAN `nowait` clause occur at the end of `omp do` directive.
// i.e
// !$omp do
// <...>
// !$omp end do nowait
if (endClauseList) {
- if (ClauseProcessor(converter, semaCtx, *endClauseList)
- .processNowait(nowaitClauseOperand))
- wsLoopOp.setNowaitAttr(nowaitClauseOperand);
+ ClauseProcessor ecp(converter, semaCtx, *endClauseList);
+ ecp.processNowait(clauseOps);
}
auto *nestedEval = getCollapsedLoopEval(
eval, Fortran::lower::getCollapseValue(beginClauseList));
auto ivCallback = [&](mlir::Operation *op) {
- return genLoopAndReductionVars(op, converter, loc, iv, reductionSymbols,
+ return genLoopAndReductionVars(op, converter, loc, iv, reductionSyms,
reductionTypes);
};
- createBodyOfOp<mlir::omp::WsloopOp>(
- wsLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
- .setClauses(&beginClauseList)
- .setDataSharingProcessor(&dsp)
- .setReductions(&reductionSymbols, &reductionTypes)
- .setGenRegionEntryCb(ivCallback));
+ genOpWithBody<mlir::omp::WsloopOp>(
+ OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
+ .setClauses(&beginClauseList)
+ .setDataSharingProcessor(&dsp)
+ .setReductions(&reductionSyms, &reductionTypes)
+ .setGenRegionEntryCb(ivCallback),
+ clauseOps);
}
static void createSimdWsloop(
@@ -1704,10 +1598,11 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPDeclareTargetConstruct
&declareTargetConstruct) {
+ mlir::omp::DeclareTargetClauseOps clauseOps;
llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
- mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo(
- converter, semaCtx, eval, declareTargetConstruct, symbolAndClause);
+ getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
+ clauseOps, symbolAndClause);
for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
mlir::Operation *op = mod.lookupSymbol(converter.mangleName(
@@ -1721,7 +1616,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
markDeclareTarget(
op, converter,
- std::get<mlir::omp::DeclareTargetCaptureClause>(symClause), deviceType);
+ std::get<mlir::omp::DeclareTargetCaptureClause>(symClause),
+ clauseOps.deviceType);
}
}
@@ -1853,7 +1749,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
!std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) &&
!std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) &&
- !std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u)) {
+ !std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u) &&
+ !std::get_if<Fortran::parser::OmpClause::Simd>(&clause.u)) {
TODO(clauseLocation, "OpenMP Block construct clause");
}
}
@@ -1873,7 +1770,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
break;
case llvm::omp::Directive::OMPD_ordered:
genOrderedRegionOp(converter, semaCtx, eval, /*genNested=*/true,
- currentLocation);
+ currentLocation, beginClauseList);
break;
case llvm::omp::Directive::OMPD_parallel:
genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/true,
@@ -1964,7 +1861,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::Location currentLocation = converter.getCurrentLocation();
- mlir::IntegerAttr hintClauseOp;
std::string name;
const Fortran::parser::OmpCriticalDirective &cd =
std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t);
@@ -1973,21 +1869,28 @@ genOMP(Fortran::lower::AbstractConverter &converter,
std::get<std::optional<Fortran::parser::Name>>(cd.t).value().ToString();
}
- const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t);
- ClauseProcessor(converter, semaCtx, clauseList).processHint(hintClauseOp);
-
mlir::omp::CriticalOp criticalOp = [&]() {
if (name.empty()) {
return firOpBuilder.create<mlir::omp::CriticalOp>(
currentLocation, mlir::FlatSymbolRefAttr());
}
+
mlir::ModuleOp module = firOpBuilder.getModule();
mlir::OpBuilder modBuilder(module.getBodyRegion());
auto global = module.lookupSymbol<mlir::omp::CriticalDeclareOp>(name);
- if (!global)
- global = modBuilder.create<mlir::omp::CriticalDeclareOp>(
- currentLocation,
- mlir::StringAttr::get(firOpBuilder.getContext(), name), hintClauseOp);
+ if (!global) {
+ mlir::omp::CriticalClauseOps clauseOps;
+ const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t);
+
+ ClauseProcessor cp(converter, semaCtx, clauseList);
+ cp.processHint(clauseOps);
+ clauseOps.nameAttr =
+ mlir::StringAttr::get(firOpBuilder.getContext(), name);
+
+ global = modBuilder.create<mlir::omp::CriticalDeclareOp>(currentLocation,
+ clauseOps);
+ }
+
return firOpBuilder.create<mlir::omp::CriticalOp>(
currentLocation, mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(),
global.getSymName()));
@@ -2104,8 +2007,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPSectionsConstruct &sectionsConstruct) {
mlir::Location currentLocation = converter.getCurrentLocation();
- llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands;
- mlir::UnitAttr nowaitClauseOperand;
+ mlir::omp::SectionsClauseOps clauseOps;
const auto &beginSectionsDirective =
std::get<Fortran::parser::OmpBeginSectionsDirective>(sectionsConstruct.t);
const auto &sectionsClauseList =
@@ -2114,8 +2016,9 @@ genOMP(Fortran::lower::AbstractConverter &converter,
// Process clauses before optional omp.parallel, so that new variables are
// allocated outside of the parallel region
ClauseProcessor cp(converter, semaCtx, sectionsClauseList);
- cp.processSectionsReduction(currentLocation);
- cp.processAllocate(allocatorOperands, allocateOperands);
+ cp.processSectionsReduction(currentLocation, clauseOps);
+ cp.processAllocate(clauseOps);
+ // TODO Support delayed privatization.
llvm::omp::Directive dir =
std::get<Fortran::parser::OmpSectionsDirective>(beginSectionsDirective.t)
@@ -2132,16 +2035,14 @@ genOMP(Fortran::lower::AbstractConverter &converter,
const auto &endSectionsClauseList =
std::get<Fortran::parser::OmpClauseList>(endSectionsDirective.t);
ClauseProcessor(converter, semaCtx, endSectionsClauseList)
- .processNowait(nowaitClauseOperand);
+ .processNowait(clauseOps);
}
// SECTIONS construct
genOpWithBody<mlir::omp::SectionsOp>(
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
.setGenNested(false),
- /*reduction_vars=*/mlir::ValueRange(),
- /*reductions=*/nullptr, allocateOperands, allocatorOperands,
- nowaitClauseOperand);
+ clauseOps);
const auto &sectionBlocks =
std::get<Fortran::parser::OmpSectionBlocks>(sectionsConstruct.t);
More information about the llvm-branch-commits
mailing list