[flang-commits] [flang] 9bf5093 - [flang][openmp] Parallel reduction FIR lowering
Ethan Luis McDonough via flang-commits
flang-commits at lists.llvm.org
Fri Jun 30 13:36:34 PDT 2023
Author: Ethan Luis McDonough
Date: 2023-06-30T15:36:27-05:00
New Revision: 9bf50936237ded3fc324f4686dc4c1c5a9fb16eb
URL: https://github.com/llvm/llvm-project/commit/9bf50936237ded3fc324f4686dc4c1c5a9fb16eb
DIFF: https://github.com/llvm/llvm-project/commit/9bf50936237ded3fc324f4686dc4c1c5a9fb16eb.diff
LOG: [flang][openmp] Parallel reduction FIR lowering
This patch extends the logic for lowering loop construct reductions to parallel block reductions.
Reviewed By: kiranchandramohan
Differential Revision: https://reviews.llvm.org/D154182
Added:
flang/test/Lower/OpenMP/parallel-reduction-add.f90
Modified:
flang/lib/Lower/Bridge.cpp
flang/lib/Lower/OpenMP.cpp
Removed:
flang/test/Lower/OpenMP/Todo/parallel-reduction.f90
################################################################################
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 087d32032278c3..3fee12b8b6e0ee 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2172,6 +2172,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
const Fortran::parser::OpenMPLoopConstruct *ompLoop =
std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u);
+ const Fortran::parser::OpenMPBlockConstruct *ompBlock =
+ std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u);
// If loop is part of an OpenMP Construct then the OpenMP dialect
// workshare loop operation has already been created. Only the
@@ -2196,8 +2198,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
genFIR(e);
- if (ompLoop)
+ if (ompLoop) {
genOpenMPReduction(*this, *loopOpClauseList);
+ } else if (ompBlock) {
+ const auto &blockStart =
+ std::get<Fortran::parser::OmpBeginBlockDirective>(ompBlock->t);
+ const auto &blockClauses =
+ std::get<Fortran::parser::OmpClauseList>(blockStart.t);
+ genOpenMPReduction(*this, blockClauses);
+ }
localSymbols.popScope();
builder->restoreInsertionPoint(insertPt);
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 67b88d2033aa25..61eb1cc7d2a957 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1154,209 +1154,6 @@ createCombinedParallelOp(Fortran::lower::AbstractConverter &converter,
/*isCombined=*/true);
}
-static void
-genOMP(Fortran::lower::AbstractConverter &converter,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
- const auto &beginBlockDirective =
- std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
- const auto &blockDirective =
- std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
- const auto &endBlockDirective =
- std::get<Fortran::parser::OmpEndBlockDirective>(blockConstruct.t);
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- mlir::Location currentLocation = converter.genLocation(blockDirective.source);
-
- Fortran::lower::StatementContext stmtCtx;
- llvm::ArrayRef<mlir::Type> argTy;
- mlir::Value ifClauseOperand, numThreadsClauseOperand, finalClauseOperand,
- priorityClauseOperand;
- mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
- SmallVector<Value> allocateOperands, allocatorOperands, dependOperands;
- SmallVector<Attribute> dependTypeOperands;
- mlir::UnitAttr nowaitAttr, untiedAttr, mergeableAttr;
-
- const auto &opClauseList =
- std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
- for (const auto &clause : opClauseList.v) {
- mlir::Location clauseLocation = converter.genLocation(clause.source);
- if (const auto &ifClause =
- std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
- ifClauseOperand =
- getIfClauseOperand(converter, stmtCtx, ifClause, clauseLocation);
- } else if (const auto &numThreadsClause =
- std::get_if<Fortran::parser::OmpClause::NumThreads>(
- &clause.u)) {
- // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`.
- numThreadsClauseOperand = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
- } else if (const auto &procBindClause =
- std::get_if<Fortran::parser::OmpClause::ProcBind>(
- &clause.u)) {
- procBindKindAttr = genProcBindKindAttr(firOpBuilder, procBindClause);
- } else if (const auto &allocateClause =
- std::get_if<Fortran::parser::OmpClause::Allocate>(
- &clause.u)) {
- genAllocateClause(converter, allocateClause->v, allocatorOperands,
- allocateOperands);
- } else if (std::get_if<Fortran::parser::OmpClause::Private>(&clause.u) ||
- std::get_if<Fortran::parser::OmpClause::Firstprivate>(
- &clause.u) ||
- std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u)) {
- // Privatisation and copyin clauses are handled elsewhere.
- continue;
- } else if (std::get_if<Fortran::parser::OmpClause::Shared>(&clause.u)) {
- // Shared is the default behavior in the IR, so no handling is required.
- continue;
- } else if (const auto &defaultClause =
- std::get_if<Fortran::parser::OmpClause::Default>(
- &clause.u)) {
- if ((defaultClause->v.v ==
- Fortran::parser::OmpDefaultClause::Type::Shared) ||
- (defaultClause->v.v ==
- Fortran::parser::OmpDefaultClause::Type::None)) {
- // Default clause with shared or none do not require any handling since
- // Shared is the default behavior in the IR and None is only required
- // for semantic checks.
- continue;
- }
- } else if (std::get_if<Fortran::parser::OmpClause::Threads>(&clause.u)) {
- // Nothing needs to be done for threads clause.
- continue;
- } else if (std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
- // Map clause is exclusive to Target Data directives. It is handled
- // as part of the TargetOp creation.
- continue;
- } else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
- &clause.u)) {
- // UseDevicePtr clause is exclusive to Target Data directives. It is
- // handled as part of the TargetOp creation.
- continue;
- } else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
- &clause.u)) {
- // UseDeviceAddr clause is exclusive to Target Data directives. It is
- // handled as part of the TargetOp creation.
- continue;
- } else if (std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
- &clause.u)) {
- // Handled as part of TargetOp creation.
- continue;
- } else if (const auto &finalClause =
- std::get_if<Fortran::parser::OmpClause::Final>(&clause.u)) {
- mlir::Value finalVal = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(finalClause->v), stmtCtx));
- finalClauseOperand = firOpBuilder.createConvert(
- currentLocation, firOpBuilder.getI1Type(), finalVal);
- } else if (std::get_if<Fortran::parser::OmpClause::Untied>(&clause.u)) {
- untiedAttr = firOpBuilder.getUnitAttr();
- } else if (std::get_if<Fortran::parser::OmpClause::Mergeable>(&clause.u)) {
- mergeableAttr = firOpBuilder.getUnitAttr();
- } else if (const auto &priorityClause =
- std::get_if<Fortran::parser::OmpClause::Priority>(
- &clause.u)) {
- priorityClauseOperand = fir::getBase(converter.genExprValue(
- *Fortran::semantics::GetExpr(priorityClause->v), stmtCtx));
- } else if (std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) {
- TODO(currentLocation,
- "Reduction in OpenMP " +
- llvm::omp::getOpenMPDirectiveName(blockDirective.v) +
- " construct");
- } else if (const auto &dependClause =
- std::get_if<Fortran::parser::OmpClause::Depend>(&clause.u)) {
- const std::list<Fortran::parser::Designator> &depVal =
- std::get<std::list<Fortran::parser::Designator>>(
- std::get<Fortran::parser::OmpDependClause::InOut>(
- dependClause->v.u)
- .t);
- omp::ClauseTaskDependAttr dependTypeOperand =
- genDependKindAttr(firOpBuilder, dependClause);
- dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(),
- dependTypeOperand);
- for (const Fortran::parser::Designator &ompObject : depVal) {
- Fortran::semantics::Symbol *sym = nullptr;
- std::visit(
- Fortran::common::visitors{
- [&](const Fortran::parser::DataRef &designator) {
- if (const Fortran::parser::Name *name =
- std::get_if<Fortran::parser::Name>(&designator.u)) {
- sym = name->symbol;
- } else if (std::get_if<Fortran::common::Indirection<
- Fortran::parser::ArrayElement>>(
- &designator.u)) {
- TODO(converter.getCurrentLocation(),
- "array sections not supported for task depend");
- }
- },
- [&](const Fortran::parser::Substring &designator) {
- TODO(converter.getCurrentLocation(),
- "substring not supported for task depend");
- }},
- (ompObject).u);
- const mlir::Value variable = converter.getSymbolAddress(*sym);
- dependOperands.push_back(((variable)));
- }
- } else {
- TODO(converter.getCurrentLocation(), "OpenMP Block construct clause");
- }
- }
-
- for (const auto &clause :
- std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t).v) {
- if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u))
- nowaitAttr = firOpBuilder.getUnitAttr();
- }
-
- if (blockDirective.v == llvm::omp::OMPD_parallel) {
- // Create and insert the operation.
- auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
- currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
- allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(),
- /*reductions=*/nullptr, procBindKindAttr);
- createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation,
- eval, &opClauseList);
- } else if (blockDirective.v == llvm::omp::OMPD_master) {
- auto masterOp =
- firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy);
- createBodyOfOp<omp::MasterOp>(masterOp, converter, currentLocation, eval);
- } else if (blockDirective.v == llvm::omp::OMPD_single) {
- auto singleOp = firOpBuilder.create<mlir::omp::SingleOp>(
- currentLocation, allocateOperands, allocatorOperands, nowaitAttr);
- createBodyOfOp<omp::SingleOp>(singleOp, converter, currentLocation, eval,
- &opClauseList);
- } else if (blockDirective.v == llvm::omp::OMPD_ordered) {
- auto orderedOp = firOpBuilder.create<mlir::omp::OrderedRegionOp>(
- currentLocation, /*simd=*/false);
- createBodyOfOp<omp::OrderedRegionOp>(orderedOp, converter, currentLocation,
- eval);
- } else if (blockDirective.v == llvm::omp::OMPD_task) {
- auto taskOp = firOpBuilder.create<mlir::omp::TaskOp>(
- currentLocation, ifClauseOperand, finalClauseOperand, untiedAttr,
- mergeableAttr, /*in_reduction_vars=*/ValueRange(),
- /*in_reductions=*/nullptr, priorityClauseOperand,
- dependTypeOperands.empty()
- ? nullptr
- : mlir::ArrayAttr::get(firOpBuilder.getContext(),
- dependTypeOperands),
- dependOperands, allocateOperands, allocatorOperands);
- createBodyOfOp(taskOp, converter, currentLocation, eval, &opClauseList);
- } else if (blockDirective.v == llvm::omp::OMPD_taskgroup) {
- // TODO: Add task_reduction support
- auto taskGroupOp = firOpBuilder.create<mlir::omp::TaskGroupOp>(
- currentLocation, /*task_reduction_vars=*/ValueRange(),
- /*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
- createBodyOfOp(taskGroupOp, converter, currentLocation, eval,
- &opClauseList);
- } else if (blockDirective.v == llvm::omp::OMPD_target) {
- createTargetOp(converter, opClauseList, blockDirective.v, currentLocation,
- &eval);
- } else if (blockDirective.v == llvm::omp::OMPD_target_data) {
- createTargetOp(converter, opClauseList, blockDirective.v, currentLocation,
- &eval);
- } else {
- TODO(currentLocation, "Unhandled block directive");
- }
-}
-
/// This function returns the identity value of the operator \p reductionOpName.
/// For example:
/// 0 + x = x,
@@ -1691,6 +1488,97 @@ static std::string getReductionName(
return getReductionName(reductionName, ty);
}
+/// Creates a reduction declaration for each object of a reduction clause and
+/// appends the reduced variables and declaration symbols to the output lists.
+static void
+addReductionDecl(mlir::Location currentLocation,
+ Fortran::lower::AbstractConverter &converter,
+ const Fortran::parser::OmpReductionClause &reduction,
+ SmallVector<Value> &reductionVars,
+ SmallVector<Attribute> &reductionDeclSymbols) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ omp::ReductionDeclareOp decl;
+ const auto &redOperator{
+ std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
+ const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};
+ if (const auto &redDefinedOp =
+ std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
+ const auto &intrinsicOp{
+ std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
+ redDefinedOp->u)};
+ switch (intrinsicOp) {
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ break;
+
+ default:
+ TODO(currentLocation,
+ "Reduction of some intrinsic operators is not supported");
+ break;
+ }
+ for (const auto &ompObject : objectList.v) {
+ if (const auto *name{
+ Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
+ if (const auto *symbol{name->symbol}) {
+ mlir::Value symVal = converter.getSymbolAddress(*symbol);
+ mlir::Type redType =
+ symVal.getType().cast<fir::ReferenceType>().getEleTy();
+ reductionVars.push_back(symVal);
+ if (redType.isa<fir::LogicalType>())
+ decl = createReductionDecl(
+ firOpBuilder,
+ getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
+ intrinsicOp, redType, currentLocation);
+ else if (redType.isIntOrIndexOrFloat()) {
+ decl = createReductionDecl(firOpBuilder,
+ getReductionName(intrinsicOp, redType),
+ intrinsicOp, redType, currentLocation);
+ } else {
+ TODO(currentLocation, "Reduction of some types is not supported");
+ }
+ reductionDeclSymbols.push_back(
+ SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName()));
+ }
+ }
+ }
+ } else if (auto reductionIntrinsic =
+ std::get_if<Fortran::parser::ProcedureDesignator>(
+ &redOperator.u)) {
+ if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
+ reductionIntrinsic)}) {
+ if ((name->source != "max") && (name->source != "min") &&
+ (name->source != "ior") && (name->source != "ieor") &&
+ (name->source != "iand")) {
+ TODO(currentLocation,
+ "Reduction of intrinsic procedures is not supported");
+ }
+ std::string intrinsicOp = name->ToString();
+ for (const auto &ompObject : objectList.v) {
+ if (const auto *name{
+ Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
+ if (const auto *symbol{name->symbol}) {
+ mlir::Value symVal = converter.getSymbolAddress(*symbol);
+ mlir::Type redType =
+ symVal.getType().cast<fir::ReferenceType>().getEleTy();
+ reductionVars.push_back(symVal);
+ assert(redType.isIntOrIndexOrFloat() &&
+ "Unsupported reduction type");
+ decl = createReductionDecl(
+ firOpBuilder, getReductionName(intrinsicOp, redType),
+ *reductionIntrinsic, redType, currentLocation);
+ reductionDeclSymbols.push_back(SymbolRefAttr::get(
+ firOpBuilder.getContext(), decl.getSymName()));
+ }
+ }
+ }
+ }
+ }
+}
+
static void genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
@@ -1786,88 +1674,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
} else if (const auto &reductionClause =
std::get_if<Fortran::parser::OmpClause::Reduction>(
&clause.u)) {
- omp::ReductionDeclareOp decl;
- const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>(
- reductionClause->v.t)};
- const auto &objectList{
- std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)};
- if (const auto &redDefinedOp =
- std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
- const auto &intrinsicOp{
- std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
- redDefinedOp->u)};
- switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
- break;
-
- default:
- TODO(currentLocation,
- "Reduction of some intrinsic operators is not supported");
- break;
- }
- for (const auto &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const auto *symbol{name->symbol}) {
- mlir::Value symVal = converter.getSymbolAddress(*symbol);
- mlir::Type redType =
- symVal.getType().cast<fir::ReferenceType>().getEleTy();
- reductionVars.push_back(symVal);
- if (redType.isa<fir::LogicalType>())
- decl = createReductionDecl(
- firOpBuilder,
- getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
- intrinsicOp, redType, currentLocation);
- else if (redType.isIntOrIndexOrFloat()) {
- decl = createReductionDecl(
- firOpBuilder, getReductionName(intrinsicOp, redType),
- intrinsicOp, redType, currentLocation);
- } else {
- TODO(currentLocation,
- "Reduction of some types is not supported");
- }
- reductionDeclSymbols.push_back(SymbolRefAttr::get(
- firOpBuilder.getContext(), decl.getSymName()));
- }
- }
- }
- } else if (auto reductionIntrinsic =
- std::get_if<Fortran::parser::ProcedureDesignator>(
- &redOperator.u)) {
- if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
- reductionIntrinsic)}) {
- if ((name->source != "max") && (name->source != "min") &&
- (name->source != "ior") && (name->source != "ieor") &&
- (name->source != "iand")) {
- TODO(currentLocation,
- "Reduction of intrinsic procedures is not supported");
- }
- std::string intrinsicOp = name->ToString();
- for (const auto &ompObject : objectList.v) {
- if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
- ompObject)}) {
- if (const auto *symbol{name->symbol}) {
- mlir::Value symVal = converter.getSymbolAddress(*symbol);
- mlir::Type redType =
- symVal.getType().cast<fir::ReferenceType>().getEleTy();
- reductionVars.push_back(symVal);
- assert(redType.isIntOrIndexOrFloat() &&
- "Unsupported reduction type");
- decl = createReductionDecl(
- firOpBuilder, getReductionName(intrinsicOp, redType),
- *reductionIntrinsic, redType, currentLocation);
- reductionDeclSymbols.push_back(SymbolRefAttr::get(
- firOpBuilder.getContext(), decl.getSymName()));
- }
- }
- }
- }
- }
+ addReductionDecl(currentLocation, converter, reductionClause->v,
+ reductionVars, reductionDeclSymbols);
} else if (const auto &simdlenClause =
std::get_if<Fortran::parser::OmpClause::Simdlen>(
&clause.u)) {
@@ -2003,6 +1811,214 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
&loopOpClauseList, iv, /*outer=*/false, &dsp);
}
+static void
+genOMP(Fortran::lower::AbstractConverter &converter,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
+ const auto &beginBlockDirective =
+ std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
+ const auto &blockDirective =
+ std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
+ const auto &endBlockDirective =
+ std::get<Fortran::parser::OmpEndBlockDirective>(blockConstruct.t);
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ mlir::Location currentLocation = converter.genLocation(blockDirective.source);
+
+ Fortran::lower::StatementContext stmtCtx;
+ llvm::ArrayRef<mlir::Type> argTy;
+ mlir::Value ifClauseOperand, numThreadsClauseOperand, finalClauseOperand,
+ priorityClauseOperand;
+ mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
+ SmallVector<Value> allocateOperands, allocatorOperands, dependOperands,
+ reductionVars;
+ SmallVector<Attribute> dependTypeOperands, reductionDeclSymbols;
+ mlir::UnitAttr nowaitAttr, untiedAttr, mergeableAttr;
+
+ const auto &opClauseList =
+ std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
+ for (const auto &clause : opClauseList.v) {
+ mlir::Location clauseLocation = converter.genLocation(clause.source);
+ if (const auto &ifClause =
+ std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
+ ifClauseOperand =
+ getIfClauseOperand(converter, stmtCtx, ifClause, clauseLocation);
+ } else if (const auto &numThreadsClause =
+ std::get_if<Fortran::parser::OmpClause::NumThreads>(
+ &clause.u)) {
+ // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
+ numThreadsClauseOperand = fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
+ } else if (const auto &procBindClause =
+ std::get_if<Fortran::parser::OmpClause::ProcBind>(
+ &clause.u)) {
+ procBindKindAttr = genProcBindKindAttr(firOpBuilder, procBindClause);
+ } else if (const auto &allocateClause =
+ std::get_if<Fortran::parser::OmpClause::Allocate>(
+ &clause.u)) {
+ genAllocateClause(converter, allocateClause->v, allocatorOperands,
+ allocateOperands);
+ } else if (std::get_if<Fortran::parser::OmpClause::Private>(&clause.u) ||
+ std::get_if<Fortran::parser::OmpClause::Firstprivate>(
+ &clause.u) ||
+ std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u)) {
+ // Privatisation and copyin clauses are handled elsewhere.
+ continue;
+ } else if (std::get_if<Fortran::parser::OmpClause::Shared>(&clause.u)) {
+ // Shared is the default behavior in the IR, so no handling is required.
+ continue;
+ } else if (const auto &defaultClause =
+ std::get_if<Fortran::parser::OmpClause::Default>(
+ &clause.u)) {
+ if ((defaultClause->v.v ==
+ Fortran::parser::OmpDefaultClause::Type::Shared) ||
+ (defaultClause->v.v ==
+ Fortran::parser::OmpDefaultClause::Type::None)) {
+ // Default clause with shared or none does not require any handling
+ // since Shared is the default behavior in the IR and None is only
+ // required for semantic checks.
+ continue;
+ }
+ } else if (std::get_if<Fortran::parser::OmpClause::Threads>(&clause.u)) {
+ // Nothing needs to be done for threads clause.
+ continue;
+ } else if (std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
+ // Map clause is exclusive to Target Data directives. It is handled
+ // as part of the TargetOp creation.
+ continue;
+ } else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
+ &clause.u)) {
+ // UseDevicePtr clause is exclusive to Target Data directives. It is
+ // handled as part of the TargetOp creation.
+ continue;
+ } else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
+ &clause.u)) {
+ // UseDeviceAddr clause is exclusive to Target Data directives. It is
+ // handled as part of the TargetOp creation.
+ continue;
+ } else if (std::get_if<Fortran::parser::OmpClause::ThreadLimit>(
+ &clause.u)) {
+ // Handled as part of TargetOp creation.
+ continue;
+ } else if (const auto &finalClause =
+ std::get_if<Fortran::parser::OmpClause::Final>(&clause.u)) {
+ mlir::Value finalVal = fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(finalClause->v), stmtCtx));
+ finalClauseOperand = firOpBuilder.createConvert(
+ currentLocation, firOpBuilder.getI1Type(), finalVal);
+ } else if (std::get_if<Fortran::parser::OmpClause::Untied>(&clause.u)) {
+ untiedAttr = firOpBuilder.getUnitAttr();
+ } else if (std::get_if<Fortran::parser::OmpClause::Mergeable>(&clause.u)) {
+ mergeableAttr = firOpBuilder.getUnitAttr();
+ } else if (const auto &priorityClause =
+ std::get_if<Fortran::parser::OmpClause::Priority>(
+ &clause.u)) {
+ priorityClauseOperand = fir::getBase(converter.genExprValue(
+ *Fortran::semantics::GetExpr(priorityClause->v), stmtCtx));
+ } else if (const auto &reductionClause =
+ std::get_if<Fortran::parser::OmpClause::Reduction>(
+ &clause.u)) {
+ addReductionDecl(currentLocation, converter, reductionClause->v,
+ reductionVars, reductionDeclSymbols);
+ } else if (const auto &dependClause =
+ std::get_if<Fortran::parser::OmpClause::Depend>(&clause.u)) {
+ const std::list<Fortran::parser::Designator> &depVal =
+ std::get<std::list<Fortran::parser::Designator>>(
+ std::get<Fortran::parser::OmpDependClause::InOut>(
+ dependClause->v.u)
+ .t);
+ omp::ClauseTaskDependAttr dependTypeOperand =
+ genDependKindAttr(firOpBuilder, dependClause);
+ dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(),
+ dependTypeOperand);
+ for (const Fortran::parser::Designator &ompObject : depVal) {
+ Fortran::semantics::Symbol *sym = nullptr;
+ std::visit(
+ Fortran::common::visitors{
+ [&](const Fortran::parser::DataRef &designator) {
+ if (const Fortran::parser::Name *name =
+ std::get_if<Fortran::parser::Name>(&designator.u)) {
+ sym = name->symbol;
+ } else if (std::get_if<Fortran::common::Indirection<
+ Fortran::parser::ArrayElement>>(
+ &designator.u)) {
+ TODO(converter.getCurrentLocation(),
+ "array sections not supported for task depend");
+ }
+ },
+ [&](const Fortran::parser::Substring &designator) {
+ TODO(converter.getCurrentLocation(),
+ "substring not supported for task depend");
+ }},
+ (ompObject).u);
+ const mlir::Value variable = converter.getSymbolAddress(*sym);
+ dependOperands.push_back(((variable)));
+ }
+ } else {
+ TODO(converter.getCurrentLocation(), "OpenMP Block construct clause");
+ }
+ }
+
+ for (const auto &clause :
+ std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t).v) {
+ if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u))
+ nowaitAttr = firOpBuilder.getUnitAttr();
+ }
+
+ if (blockDirective.v == llvm::omp::OMPD_parallel) {
+ // Create and insert the operation.
+ auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
+ currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
+ allocateOperands, allocatorOperands, reductionVars,
+ reductionDeclSymbols.empty()
+ ? nullptr
+ : mlir::ArrayAttr::get(firOpBuilder.getContext(),
+ reductionDeclSymbols),
+ procBindKindAttr);
+ createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation,
+ eval, &opClauseList);
+ } else if (blockDirective.v == llvm::omp::OMPD_master) {
+ auto masterOp =
+ firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy);
+ createBodyOfOp<omp::MasterOp>(masterOp, converter, currentLocation, eval);
+ } else if (blockDirective.v == llvm::omp::OMPD_single) {
+ auto singleOp = firOpBuilder.create<mlir::omp::SingleOp>(
+ currentLocation, allocateOperands, allocatorOperands, nowaitAttr);
+ createBodyOfOp<omp::SingleOp>(singleOp, converter, currentLocation, eval,
+ &opClauseList);
+ } else if (blockDirective.v == llvm::omp::OMPD_ordered) {
+ auto orderedOp = firOpBuilder.create<mlir::omp::OrderedRegionOp>(
+ currentLocation, /*simd=*/false);
+ createBodyOfOp<omp::OrderedRegionOp>(orderedOp, converter, currentLocation,
+ eval);
+ } else if (blockDirective.v == llvm::omp::OMPD_task) {
+ auto taskOp = firOpBuilder.create<mlir::omp::TaskOp>(
+ currentLocation, ifClauseOperand, finalClauseOperand, untiedAttr,
+ mergeableAttr, /*in_reduction_vars=*/ValueRange(),
+ /*in_reductions=*/nullptr, priorityClauseOperand,
+ dependTypeOperands.empty()
+ ? nullptr
+ : mlir::ArrayAttr::get(firOpBuilder.getContext(),
+ dependTypeOperands),
+ dependOperands, allocateOperands, allocatorOperands);
+ createBodyOfOp(taskOp, converter, currentLocation, eval, &opClauseList);
+ } else if (blockDirective.v == llvm::omp::OMPD_taskgroup) {
+ // TODO: Add task_reduction support
+ auto taskGroupOp = firOpBuilder.create<mlir::omp::TaskGroupOp>(
+ currentLocation, /*task_reduction_vars=*/ValueRange(),
+ /*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
+ createBodyOfOp(taskGroupOp, converter, currentLocation, eval,
+ &opClauseList);
+ } else if (blockDirective.v == llvm::omp::OMPD_target) {
+ createTargetOp(converter, opClauseList, blockDirective.v, currentLocation,
+ &eval);
+ } else if (blockDirective.v == llvm::omp::OMPD_target_data) {
+ createTargetOp(converter, opClauseList, blockDirective.v, currentLocation,
+ &eval);
+ } else {
+ TODO(currentLocation, "Unhandled block directive");
+ }
+}
+
static void
genOMP(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
diff --git a/flang/test/Lower/OpenMP/Todo/parallel-reduction.f90 b/flang/test/Lower/OpenMP/Todo/parallel-reduction.f90
deleted file mode 100644
index df13f1faefe806..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/parallel-reduction.f90
+++ /dev/null
@@ -1,11 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Reduction in OpenMP parallel construct
-subroutine reduction_parallel
- integer :: x
- !$omp parallel reduction(+:x)
- x = x + i
- !$omp end parallel
- print *, x
-end subroutine
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-add.f90 b/flang/test/Lower/OpenMP/parallel-reduction-add.f90
new file mode 100644
index 00000000000000..cab076f817cadd
--- /dev/null
+++ b/flang/test/Lower/OpenMP/parallel-reduction-add.f90
@@ -0,0 +1,97 @@
+! RUN: bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.reduction.declare
+!CHECK-SAME: @[[RED_F32_NAME:.*]] : f32 init {
+!CHECK: ^bb0(%{{.*}}: f32):
+!CHECK: %[[C0_1:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK: omp.yield(%[[C0_1]] : f32)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32):
+!CHECK: %[[RES:.*]] = arith.addf %[[ARG0]], %[[ARG1]] {{.*}}: f32
+!CHECK: omp.yield(%[[RES]] : f32)
+!CHECK: }
+
+!CHECK-LABEL: omp.reduction.declare
+!CHECK-SAME: @[[RED_I32_NAME:.*]] : i32 init {
+!CHECK: ^bb0(%{{.*}}: i32):
+!CHECK: %[[C0_1:.*]] = arith.constant 0 : i32
+!CHECK: omp.yield(%[[C0_1]] : i32)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32):
+!CHECK: %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
+!CHECK: omp.yield(%[[RES]] : i32)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_int_add
+!CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_int_addEi"}
+!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
+!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
+!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>) {
+!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
+!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: return
+subroutine simple_int_add
+ integer :: i
+ i = 0
+
+ !$omp parallel reduction(+:i)
+ i = i + 1
+ !$omp end parallel
+
+ print *, i
+end subroutine
+
+!CHECK-LABEL: func.func @_QPsimple_real_add
+!CHECK: %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFsimple_real_addEr"}
+!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
+!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
+!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
+!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: return
+subroutine simple_real_add
+ real :: r
+ r = 0.0
+
+ !$omp parallel reduction(+:r)
+ r = r + 1.5
+ !$omp end parallel
+
+ print *, r
+end subroutine
+
+!CHECK-LABEL: func.func @_QPint_real_add
+!CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFint_real_addEi"}
+!CHECK: %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFint_real_addEr"}
+!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
+!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
+!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
+!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
+!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
+!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
+!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
+!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
+!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: return
+subroutine int_real_add
+ real :: r
+ integer :: i
+
+ r = 0.0
+ i = 0
+
+ !$omp parallel reduction(+:i,r)
+ r = 1.5 + r
+ i = i + 3
+ !$omp end parallel
+
+ print *, r
+ print *, i
+end subroutine
More information about the flang-commits
mailing list