[llvm-branch-commits] [flang] [Flang][OpenMP][Lower] Use clause operand structures (PR #86802)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Mar 27 05:55:40 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Sergio Afonso (skatrak)
<details>
<summary>Changes</summary>
This patch updates Flang lowering to use the new set of OpenMP clause operand structures and their groupings into directive-specific sets of clause operands.
It simplifies the passing of information from the clause processor and the creation of operations.
The `DataSharingProcessor` is slightly modified to not hold delayed privatization state. Instead, optional arguments are added to `processStep1` which are only passed when delayed privatization is used. This enables using the clause operand structure for `private` and removes the need for the ad-hoc `DelayedPrivatizationInfo` structure.
The processing of the `schedule` clause is updated to also process the `chunk` modifier in the same call, rather than requiring two separate calls to the `ClauseProcessor`.
Lowering of a block-associated `ordered` construct is updated to emit a TODO error if the `simd` clause is specified, since it is not currently supported by the `ClauseProcessor` or later compilation stages.
Removed processing of `schedule` from `omp.simdloop`, as it doesn't apply to `simd` constructs.
---
Patch is 85.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86802.diff
5 Files Affected:
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+128-133)
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+47-58)
- (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+25-13)
- (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.h (+19-26)
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+209-308)
``````````diff
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 0a57a1496289f4..ee1f6c2fbc7e89 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -162,14 +162,13 @@ getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
ifVal);
}
-static void
-addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
- const omp::ObjectList &objects,
- llvm::SmallVectorImpl<mlir::Value> &operands,
- llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
- llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
- &useDeviceSymbols) {
+static void addUseDeviceClause(
+ Fortran::lower::AbstractConverter &converter,
+ const omp::ObjectList &objects,
+ llvm::SmallVectorImpl<mlir::Value> &operands,
+ llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
+ llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms) {
genObjectList(objects, converter, operands);
for (mlir::Value &operand : operands) {
checkMapType(operand.getLoc(), operand.getType());
@@ -177,25 +176,24 @@ addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
useDeviceLocs.push_back(operand.getLoc());
}
for (const omp::Object &object : objects)
- useDeviceSymbols.push_back(object.id());
+ useDeviceSyms.push_back(object.id());
}
static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
mlir::Location loc,
- llvm::SmallVectorImpl<mlir::Value> &lowerBound,
- llvm::SmallVectorImpl<mlir::Value> &upperBound,
- llvm::SmallVectorImpl<mlir::Value> &step,
+ mlir::omp::CollapseClauseOps &result,
std::size_t loopVarTypeSize) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
// The types of lower bound, upper bound, and step are converted into the
// type of the loop variable if necessary.
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
- for (unsigned it = 0; it < (unsigned)lowerBound.size(); it++) {
- lowerBound[it] =
- firOpBuilder.createConvert(loc, loopVarType, lowerBound[it]);
- upperBound[it] =
- firOpBuilder.createConvert(loc, loopVarType, upperBound[it]);
- step[it] = firOpBuilder.createConvert(loc, loopVarType, step[it]);
+ for (unsigned it = 0; it < (unsigned)result.loopLBVar.size(); it++) {
+ result.loopLBVar[it] =
+ firOpBuilder.createConvert(loc, loopVarType, result.loopLBVar[it]);
+ result.loopUBVar[it] =
+ firOpBuilder.createConvert(loc, loopVarType, result.loopUBVar[it]);
+ result.loopStepVar[it] =
+ firOpBuilder.createConvert(loc, loopVarType, result.loopStepVar[it]);
}
}
@@ -205,9 +203,7 @@ static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
bool ClauseProcessor::processCollapse(
mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
- llvm::SmallVectorImpl<mlir::Value> &lowerBound,
- llvm::SmallVectorImpl<mlir::Value> &upperBound,
- llvm::SmallVectorImpl<mlir::Value> &step,
+ mlir::omp::CollapseClauseOps &result,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const {
bool found = false;
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -238,15 +234,15 @@ bool ClauseProcessor::processCollapse(
std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
assert(bounds && "Expected bounds for worksharing do loop");
Fortran::lower::StatementContext stmtCtx;
- lowerBound.push_back(fir::getBase(converter.genExprValue(
+ result.loopLBVar.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
- upperBound.push_back(fir::getBase(converter.genExprValue(
+ result.loopUBVar.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
if (bounds->step) {
- step.push_back(fir::getBase(converter.genExprValue(
+ result.loopStepVar.push_back(fir::getBase(converter.genExprValue(
*Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
} else { // If `step` is not present, assume it as `1`.
- step.push_back(firOpBuilder.createIntegerConstant(
+ result.loopStepVar.push_back(firOpBuilder.createIntegerConstant(
currentLocation, firOpBuilder.getIntegerType(32), 1));
}
iv.push_back(bounds->name.thing.symbol);
@@ -257,8 +253,7 @@ bool ClauseProcessor::processCollapse(
&*std::next(doConstructEval->getNestedEvaluations().begin());
} while (collapseValue > 0);
- convertLoopBounds(converter, currentLocation, lowerBound, upperBound, step,
- loopVarTypeSize);
+ convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);
return found;
}
@@ -286,7 +281,7 @@ bool ClauseProcessor::processDefault() const {
}
bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
+ mlir::omp::DeviceClauseOps &result) const {
const Fortran::parser::CharBlock *source = nullptr;
if (auto *clause = findUniqueClause<omp::clause::Device>(&source)) {
mlir::Location clauseLocation = converter.genLocation(*source);
@@ -298,25 +293,26 @@ bool ClauseProcessor::processDevice(Fortran::lower::StatementContext &stmtCtx,
}
}
const auto &deviceExpr = std::get<omp::SomeExpr>(clause->t);
- result = fir::getBase(converter.genExprValue(deviceExpr, stmtCtx));
+ result.deviceVar =
+ fir::getBase(converter.genExprValue(deviceExpr, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processDeviceType(
- mlir::omp::DeclareTargetDeviceType &result) const {
+ mlir::omp::DeviceTypeClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::DeviceType>()) {
// Case: declare target ... device_type(any | host | nohost)
switch (clause->v) {
case omp::clause::DeviceType::DeviceTypeDescription::Nohost:
- result = mlir::omp::DeclareTargetDeviceType::nohost;
+ result.deviceType = mlir::omp::DeclareTargetDeviceType::nohost;
break;
case omp::clause::DeviceType::DeviceTypeDescription::Host:
- result = mlir::omp::DeclareTargetDeviceType::host;
+ result.deviceType = mlir::omp::DeclareTargetDeviceType::host;
break;
case omp::clause::DeviceType::DeviceTypeDescription::Any:
- result = mlir::omp::DeclareTargetDeviceType::any;
+ result.deviceType = mlir::omp::DeclareTargetDeviceType::any;
break;
}
return true;
@@ -325,7 +321,7 @@ bool ClauseProcessor::processDeviceType(
}
bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
+ mlir::omp::FinalClauseOps &result) const {
const Fortran::parser::CharBlock *source = nullptr;
if (auto *clause = findUniqueClause<omp::clause::Final>(&source)) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -333,100 +329,108 @@ bool ClauseProcessor::processFinal(Fortran::lower::StatementContext &stmtCtx,
mlir::Value finalVal =
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
- result = firOpBuilder.createConvert(clauseLocation,
- firOpBuilder.getI1Type(), finalVal);
+ result.finalVar = firOpBuilder.createConvert(
+ clauseLocation, firOpBuilder.getI1Type(), finalVal);
return true;
}
return false;
}
-bool ClauseProcessor::processHint(mlir::IntegerAttr &result) const {
+bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Hint>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
int64_t hintValue = *Fortran::evaluate::ToInt64(clause->v);
- result = firOpBuilder.getI64IntegerAttr(hintValue);
+ result.hintAttr = firOpBuilder.getI64IntegerAttr(hintValue);
return true;
}
return false;
}
-bool ClauseProcessor::processMergeable(mlir::UnitAttr &result) const {
- return markClauseOccurrence<omp::clause::Mergeable>(result);
+bool ClauseProcessor::processMergeable(
+ mlir::omp::MergeableClauseOps &result) const {
+ return markClauseOccurrence<omp::clause::Mergeable>(result.mergeableAttr);
}
-bool ClauseProcessor::processNowait(mlir::UnitAttr &result) const {
- return markClauseOccurrence<omp::clause::Nowait>(result);
+bool ClauseProcessor::processNowait(mlir::omp::NowaitClauseOps &result) const {
+ return markClauseOccurrence<omp::clause::Nowait>(result.nowaitAttr);
}
-bool ClauseProcessor::processNumTeams(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
+bool ClauseProcessor::processNumTeams(
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::NumTeamsClauseOps &result) const {
// TODO Get lower and upper bounds for num_teams when parser is updated to
// accept both.
if (auto *clause = findUniqueClause<omp::clause::NumTeams>()) {
// auto lowerBound = std::get<std::optional<ExprTy>>(clause->t);
auto &upperBound = std::get<ExprTy>(clause->t);
- result = fir::getBase(converter.genExprValue(upperBound, stmtCtx));
+ result.numTeamsUpperVar =
+ fir::getBase(converter.genExprValue(upperBound, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processNumThreads(
- Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::NumThreadsClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
// OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
- result = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+ result.numThreadsVar =
+ fir::getBase(converter.genExprValue(clause->v, stmtCtx));
return true;
}
return false;
}
-bool ClauseProcessor::processOrdered(mlir::IntegerAttr &result) const {
+bool ClauseProcessor::processOrdered(
+ mlir::omp::OrderedClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Ordered>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
int64_t orderedClauseValue = 0l;
if (clause->v.has_value())
orderedClauseValue = *Fortran::evaluate::ToInt64(*clause->v);
- result = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
+ result.orderedAttr = firOpBuilder.getI64IntegerAttr(orderedClauseValue);
return true;
}
return false;
}
-bool ClauseProcessor::processPriority(Fortran::lower::StatementContext &stmtCtx,
- mlir::Value &result) const {
+bool ClauseProcessor::processPriority(
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::PriorityClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Priority>()) {
- result = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+ result.priorityVar =
+ fir::getBase(converter.genExprValue(clause->v, stmtCtx));
return true;
}
return false;
}
bool ClauseProcessor::processProcBind(
- mlir::omp::ClauseProcBindKindAttr &result) const {
+ mlir::omp::ProcBindClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::ProcBind>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- result = genProcBindKindAttr(firOpBuilder, *clause);
+ result.procBindKindAttr = genProcBindKindAttr(firOpBuilder, *clause);
return true;
}
return false;
}
-bool ClauseProcessor::processSafelen(mlir::IntegerAttr &result) const {
+bool ClauseProcessor::processSafelen(
+ mlir::omp::SafelenClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Safelen>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
const std::optional<std::int64_t> safelenVal =
Fortran::evaluate::ToInt64(clause->v);
- result = firOpBuilder.getI64IntegerAttr(*safelenVal);
+ result.safelenAttr = firOpBuilder.getI64IntegerAttr(*safelenVal);
return true;
}
return false;
}
bool ClauseProcessor::processSchedule(
- mlir::omp::ClauseScheduleKindAttr &valAttr,
- mlir::omp::ScheduleModifierAttr &modifierAttr,
- mlir::UnitAttr &simdModifierAttr) const {
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::ScheduleClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Schedule>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::MLIRContext *context = firOpBuilder.getContext();
@@ -451,53 +455,51 @@ bool ClauseProcessor::processSchedule(
break;
}
- mlir::omp::ScheduleModifier scheduleModifier = getScheduleModifier(*clause);
+ result.scheduleValAttr =
+ mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);
+ mlir::omp::ScheduleModifier scheduleModifier = getScheduleModifier(*clause);
if (scheduleModifier != mlir::omp::ScheduleModifier::none)
- modifierAttr =
+ result.scheduleModAttr =
mlir::omp::ScheduleModifierAttr::get(context, scheduleModifier);
if (getSimdModifier(*clause) != mlir::omp::ScheduleModifier::none)
- simdModifierAttr = firOpBuilder.getUnitAttr();
+ result.scheduleSimdAttr = firOpBuilder.getUnitAttr();
- valAttr = mlir::omp::ClauseScheduleKindAttr::get(context, scheduleKind);
- return true;
- }
- return false;
-}
-
-bool ClauseProcessor::processScheduleChunk(
- Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
- if (auto *clause = findUniqueClause<omp::clause::Schedule>()) {
if (const auto &chunkExpr = std::get<omp::MaybeExpr>(clause->t))
- result = fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx));
+ result.scheduleChunkVar =
+ fir::getBase(converter.genExprValue(*chunkExpr, stmtCtx));
+
return true;
}
return false;
}
-bool ClauseProcessor::processSimdlen(mlir::IntegerAttr &result) const {
+bool ClauseProcessor::processSimdlen(
+ mlir::omp::SimdlenClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Simdlen>()) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
const std::optional<std::int64_t> simdlenVal =
Fortran::evaluate::ToInt64(clause->v);
- result = firOpBuilder.getI64IntegerAttr(*simdlenVal);
+ result.simdlenAttr = firOpBuilder.getI64IntegerAttr(*simdlenVal);
return true;
}
return false;
}
bool ClauseProcessor::processThreadLimit(
- Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
+ Fortran::lower::StatementContext &stmtCtx,
+ mlir::omp::ThreadLimitClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::ThreadLimit>()) {
- result = fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+ result.threadLimitVar =
+ fir::getBase(converter.genExprValue(clause->v, stmtCtx));
return true;
}
return false;
}
-bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const {
- return markClauseOccurrence<omp::clause::Untied>(result);
+bool ClauseProcessor::processUntied(mlir::omp::UntiedClauseOps &result) const {
+ return markClauseOccurrence<omp::clause::Untied>(result.untiedAttr);
}
//===----------------------------------------------------------------------===//
@@ -505,13 +507,12 @@ bool ClauseProcessor::processUntied(mlir::UnitAttr &result) const {
//===----------------------------------------------------------------------===//
bool ClauseProcessor::processAllocate(
- llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
- llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const {
+ mlir::omp::AllocateClauseOps &result) const {
return findRepeatableClause<omp::clause::Allocate>(
[&](const omp::clause::Allocate &clause,
const Fortran::parser::CharBlock &) {
- genAllocateClause(converter, clause, allocatorOperands,
- allocateOperands);
+ genAllocateClause(converter, clause, result.allocatorVars,
+ result.allocateVars);
});
}
@@ -660,10 +661,9 @@ createCopyFunc(mlir::Location loc, Fortran::lower::AbstractConverter &converter,
return funcOp;
}
-bool ClauseProcessor::processCopyPrivate(
+bool ClauseProcessor::processCopyprivate(
mlir::Location currentLocation,
- llvm::SmallVectorImpl<mlir::Value> &copyPrivateVars,
- llvm::SmallVectorImpl<mlir::Attribute> &copyPrivateFuncs) const {
+ mlir::omp::CopyprivateClauseOps &result) const {
auto addCopyPrivateVar = [&](Fortran::semantics::Symbol *sym) {
mlir::Value symVal = converter.getSymbolAddress(*sym);
auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>();
@@ -690,10 +690,10 @@ bool ClauseProcessor::processCopyPrivate(
cpVar = alloca;
}
- copyPrivateVars.push_back(cpVar);
+ result.copyprivateVars.push_back(cpVar);
mlir::func::FuncOp funcOp =
createCopyFunc(currentLocation, converter, cpVar.getType(), attrs);
- copyPrivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp));
+ result.copyprivateFuncs.push_back(mlir::SymbolRefAttr::get(funcOp));
};
bool hasCopyPrivate = findRepeatableClause<clause::Copyprivate>(
@@ -714,9 +714,7 @@ bool ClauseProcessor::processCopyPrivate(
return hasCopyPrivate;
}
-bool ClauseProcessor::processDepend(
- llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
- llvm::SmallVectorImpl<mlir::Value> &dependOperands) const {
+bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
return findRepeatableClause<omp::clause::Depend>(
@@ -731,7 +729,7 @@ bool ClauseProcessor::processDepend(
mlir::omp::ClauseTaskDependAttr dependTypeOperand =
genDependKindAttr(firOpBuilder, kind);
- dependTypeOperands.append(objects.size(), dependTypeOperand);
+ result.dependTypeAttrs.append(objects.size(), dependTypeOperand);
for (const omp::Object &object : objects) {
assert(object.ref() && "Expecting designator");
@@ -746,14 +744,14 @@ bool ClauseProcessor::processDepend(
Fortran::semantics::Symbol *sym = object.id();
const mlir::Value variable = converter.getSymbolAddress(*sym);
- dependOperands.push_back(variable);
+ result.dependVars.push_back(variable);
}
});
}
bool ClauseProcessor::processIf(
omp::clause::If::DirectiveNameModifier directiveName,
- mlir::Value &result) const {
+ mlir::omp::IfClauseOps &result) const {
bool found = false;
findRepeatableClause<omp::clause::If>(
[&](const omp::clause::If &clause,
@@ -764,7 +762,7 @@ bool ClauseProcessor::processIf(
// Assume that, at most, a single 'if' clause will be applicable to the
// given directive.
if (operand) {
- result = operand;
+ result.ifVar = operand;
found = true;
}
});
@@ -807,12 +805,10 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
bool ClauseProcessor::processMap(
mlir::Location currentLocation, const llvm::omp::Directive &directive,
- Fortran::lower::StatementContext &stmtCtx,
- llvm::SmallVectorImpl<mlir::Value> &mapOperands,
- llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
+ Fortran::lower::StatementContext &stmtCtx, mlir::omp::MapClauseOps &result,
+ llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSyms,
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
- const {
+ llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
fir::FirOpBuilder &firOpB...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/86802
More information about the llvm-branch-commits
mailing list